import torch
import torch.nn as nn
from absl import logging
import numpy as np
import math
from tqdm import tqdm
def get_sde(name, **kwargs):
    r"""Build the forward SDE selected by ``name``.

    Args:
        name: SDE identifier, either ``'vpsde'`` or ``'vpsde_cosine'``.
        **kwargs: Forwarded unchanged to the selected SDE constructor.

    Returns:
        A ``VPSDE`` or ``VPSDECosine`` instance.

    Raises:
        NotImplementedError: If ``name`` is not one of the identifiers above.
            The message contains the offending name.
    """
    if name == 'vpsde':
        return VPSDE(**kwargs)
    if name == 'vpsde_cosine':
        return VPSDECosine(**kwargs)
    # Report the offending value so a config typo is easy to spot.
    raise NotImplementedError(f'unknown sde name: {name!r}')
def stp(s, ts: torch.Tensor):
    r"""Scale ``ts`` by per-sample coefficients ``s`` ("scalar-tensor product").

    ``s`` is viewed as ``(-1, 1, ..., 1)`` with one trailing singleton axis
    for every non-leading axis of ``ts`` and then multiplied with ``ts``
    under the usual broadcasting rules.

    Args:
        s: Coefficients as a ``torch.Tensor``, or a ``np.ndarray`` which is
            converted and cast to the type of ``ts``.
        ts: Tensor to be scaled.

    Returns:
        The broadcast product of the reshaped ``s`` and ``ts``.
    """
    coeff = torch.from_numpy(s).type_as(ts) if isinstance(s, np.ndarray) else s
    # One singleton axis per non-leading axis of ``ts``.
    trailing = [1] * (ts.dim() - 1)
    return coeff.view(-1, *trailing) * ts
def mos(a, start_dim=1):
    r"""Mean of squares of ``a`` over the axes from ``start_dim`` onward.

    The squared tensor is flattened from ``start_dim`` to the last axis and
    averaged along that flattened (last) axis.
    """
    squared = a.pow(2)
    flat = squared.flatten(start_dim=start_dim)
    return flat.mean(dim=-1)
def duplicate(tensor, *size):
    r"""Return an expanded view of ``tensor`` with leading axes ``size``.

    A singleton axis is inserted in front of ``tensor`` and the result is
    expanded to shape ``(*size, *tensor.shape)``.
    """
    target_shape = (*size, *tensor.shape)
    with_leading_axis = tensor.unsqueeze(0)
    return with_leading_axis.expand(target_shape)
class SDE(object):
    r"""
    Forward diffusion dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
    f(x, t) is the drift
    g(t) is the diffusion

    Subclasses supply the schedule through ``drift``, ``diffusion``,
    ``cum_alpha``, ``cum_beta``, ``snr`` and ``nsr``; this base class builds
    the perturbation kernel and training samples on top of them.
    """

    def drift(self, x, t):
        """Drift f(x, t); must be overridden."""
        raise NotImplementedError

    def diffusion(self, t):
        """Diffusion g(t); must be overridden."""
        raise NotImplementedError

    def cum_beta(self, t):
        """Squared std of x_t given x_0 (see ``marginal_prob``); must be overridden."""
        raise NotImplementedError

    def cum_alpha(self, t):
        """Squared scale applied to x_0 in the mean of x_t (see ``marginal_prob``); must be overridden."""
        raise NotImplementedError

    def snr(self, t):
        """Signal-to-noise ratio at time t; must be overridden."""
        raise NotImplementedError

    def nsr(self, t):
        """Noise-to-signal ratio at time t; must be overridden."""
        raise NotImplementedError

    def marginal_prob(self, x0, t):
        """Return ``(mean, std)`` of x_t given ``x0``.

        The mean is ``sqrt(cum_alpha(t)) * x0`` (scaled per sample through
        ``stp``) and the std is ``sqrt(cum_beta(t))``.
        """
        alpha_sqrt = self.cum_alpha(t) ** 0.5
        std = self.cum_beta(t) ** 0.5
        return stp(alpha_sqrt, x0), std

    def sample(self, x0, t_init=0):
        """Draw a random time and the matching noised sample for each row of ``x0``.

        Returns:
            ``(t, eps, xt)`` where ``t`` is uniform on ``[t_init, 1)`` with one
            entry per leading-axis element of ``x0``, ``eps`` is standard
            normal noise shaped like ``x0`` and ``xt = mean + std * eps``.
        """
        batch = x0.shape[0]
        uniform = torch.rand(batch, device=x0.device)
        # Rescale the [0, 1) draw onto [t_init, 1).
        t = uniform * (1. - t_init) + t_init
        mean, std = self.marginal_prob(x0, t)
        eps = torch.randn_like(x0)
        return t, eps, mean + stp(std, eps)
class VPSDE(SDE):
    r"""
    SDE whose squared diffusion g(t)^2 is linear in t, running from
    ``beta_min`` at t = 0 to ``beta_max`` at t = 1.

    ``cum_alpha(t) + cum_beta(t) = 1`` for every t.
    """

    def __init__(self, beta_min=0.1, beta_max=20):
        self.beta_0, self.beta_1 = beta_min, beta_max

    def drift(self, x, t):
        # f(x, t) = -0.5 * g(t)^2 * x
        g_squared = self.squared_diffusion(t)
        return -0.5 * stp(g_squared, x)

    def diffusion(self, t):
        g_squared = self.squared_diffusion(t)
        return g_squared ** 0.5

    def squared_diffusion(self, t):
        """g(t)^2: linear interpolation between ``beta_0`` and ``beta_1``."""
        return self.beta_0 + t * (self.beta_1 - self.beta_0)

    def squared_diffusion_integral(self, s, t):
        """Closed-form integral of g(u)^2 for u from ``s`` to ``t``."""
        constant_part = self.beta_0 * (t - s)
        linear_part = (self.beta_1 - self.beta_0) * (t ** 2 - s ** 2) * 0.5
        return constant_part + linear_part

    def skip_beta(self, s, t):
        """Complement of ``skip_alpha`` between times ``s`` and ``t``."""
        return 1. - self.skip_alpha(s, t)

    def skip_alpha(self, s, t):
        """exp(-integral of g^2 from ``s`` to ``t``)."""
        exponent = -self.squared_diffusion_integral(s, t)
        return exponent.exp()

    def cum_beta(self, t):
        return self.skip_beta(0, t)

    def cum_alpha(self, t):
        return self.skip_alpha(0, t)

    def nsr(self, t):
        # exp(integral) - 1 == cum_beta / cum_alpha
        integral = self.squared_diffusion_integral(0, t)
        return integral.expm1()

    def snr(self, t):
        return 1. / self.nsr(t)

    def __str__(self):
        return f'vpsde beta_0={self.beta_0} beta_1={self.beta_1}'

    # Both representations are the same text.
    __repr__ = __str__
class VPSDECosine(SDE):
    r"""
    dx = f(x, t)dt + g(t) dw with 0 <= t <= 1
    f(x, t) is the drift
    g(t) is the diffusion

    Cosine schedule with offset ``s``:
    ``F(t) = cos(((t + s) / (1 + s)) * pi / 2) ** 2`` and
    ``cum_alpha(t) = F(t) / F(0)``, ``cum_beta(t) = 1 - cum_alpha(t)``.

    ``F`` is a regular method rather than a lambda stored on the instance,
    so instances can be pickled and ``F`` always uses the same ``self.s``
    as ``drift`` and ``diffusion``. ``F0`` is computed once in the
    constructor from the initial ``s``.
    """

    def __init__(self, s=0.008):
        self.s = s
        self.F0 = math.cos(s / (1 + s) * math.pi / 2) ** 2

    def _angle(self, t):
        """Phase ((t + s) / (1 + s)) * pi / 2 shared by F, drift and diffusion."""
        return (t + self.s) / (1 + self.s) * math.pi / 2

    def F(self, t):
        """Unnormalised schedule value cos(angle(t))^2 for a tensor ``t``."""
        return torch.cos(self._angle(t)) ** 2

    def drift(self, x, t):
        # f(x, t) = -0.5 * g(t)^2 * x, with g(t)^2 as in ``diffusion``.
        ft = - torch.tan(self._angle(t)) / (1 + self.s) * math.pi / 2
        return stp(ft, x)

    def diffusion(self, t):
        # g(t) = sqrt(tan(angle(t)) * pi / (1 + s))
        return (torch.tan(self._angle(t)) / (1 + self.s) * math.pi) ** 0.5

    def cum_beta(self, t):
        return 1 - self.cum_alpha(t)

    def cum_alpha(self, t):
        return self.F(t) / self.F0

    def snr(self, t):
        # cum_alpha / cum_beta expressed through F: F / (F0 - F)
        Ft = self.F(t)
        return Ft / (self.F0 - Ft)

    def nsr(self, t):
        # cum_beta / cum_alpha expressed through F: F0 / F - 1
        Ft = self.F(t)
        return self.F0 / Ft - 1

    def __str__(self):
        return 'vpsde_cosine'

    def __repr__(self):
        return 'vpsde_cosine'
class ScoreModel(object):
    r"""
    The forward process is q(x_[0,T])

    Wraps a network ``nnet`` whose raw output is interpreted according to
    ``pred`` (``'noise_pred'`` or ``'x0_pred'``) and converts between noise
    prediction, x0 prediction and score using the schedule of ``sde``.

    Every public method accepts ``t`` as a Python number or as a tensor on
    any device; it is converted to a tensor on ``xt``'s device before it is
    handed to the SDE or the network.
    """

    def __init__(self, nnet: nn.Module, pred: str, sde: SDE, T=1):
        assert T == 1
        self.nnet = nnet
        self.pred = pred
        self.sde = sde
        self.T = T
        print(f'ScoreModel with pred={pred}, sde={sde}, T={T}')

    @staticmethod
    def _time_tensor(t, xt):
        """Return ``t`` as a tensor located on the device of ``xt``."""
        if not isinstance(t, torch.Tensor):
            t = torch.tensor(t)
        return t.to(xt.device)

    def predict(self, xt, t, **kwargs):
        """Run the network on ``xt`` at time ``t`` and return its raw output.

        A 0-dim ``t`` is repeated once per leading-axis element of ``xt``.
        Extra keyword arguments are forwarded to the network.
        """
        t = self._time_tensor(t, xt)
        if t.dim() == 0:
            t = duplicate(t, xt.size(0))
        # NOTE(review): the factor 999 presumably maps t in [0, 1] onto the
        # timestep range the network was built for — confirm against nnet.
        return self.nnet(xt, t * 999, **kwargs)

    def noise_pred(self, xt, t, **kwargs):
        """Return the predicted noise for ``xt`` at time ``t``.

        Raises:
            NotImplementedError: If ``self.pred`` is not a supported mode.
        """
        # Normalise once so the SDE schedule below receives a tensor too.
        t = self._time_tensor(t, xt)
        pred = self.predict(xt, t, **kwargs)
        if self.pred == 'noise_pred':
            noise_pred = pred
        elif self.pred == 'x0_pred':
            # eps = xt / sqrt(cum_beta) - sqrt(snr) * x0
            noise_pred = - stp(self.sde.snr(t).sqrt(), pred) + stp(self.sde.cum_beta(t).rsqrt(), xt)
        else:
            raise NotImplementedError
        return noise_pred

    def x0_pred(self, xt, t, **kwargs):
        """Return the predicted clean sample for ``xt`` at time ``t``.

        Raises:
            NotImplementedError: If ``self.pred`` is not a supported mode.
        """
        t = self._time_tensor(t, xt)
        pred = self.predict(xt, t, **kwargs)
        if self.pred == 'noise_pred':
            # x0 = xt / sqrt(cum_alpha) - sqrt(nsr) * eps
            x0_pred = stp(self.sde.cum_alpha(t).rsqrt(), xt) - stp(self.sde.nsr(t).sqrt(), pred)
        elif self.pred == 'x0_pred':
            x0_pred = pred
        else:
            raise NotImplementedError
        return x0_pred

    def score(self, xt, t, **kwargs):
        """Return the score estimate ``-noise_pred / sqrt(cum_beta(t))``."""
        t = self._time_tensor(t, xt)
        cum_beta = self.sde.cum_beta(t)
        noise_pred = self.noise_pred(xt, t, **kwargs)
        return stp(-cum_beta.rsqrt(), noise_pred)
class ReverseSDE(object):
    r"""
    Reverse-time SDE
    dx = [f(x, t) - g(t)^2 s(x, t)] dt + g(t) dw
    where f and g come from the score model's forward SDE and s is the
    score returned by the score model.
    """

    def __init__(self, score_model):
        self.sde = score_model.sde
        self.score_model = score_model

    def drift(self, x, t, **kwargs):
        """Reverse drift ``f(x, t) - g(t)^2 * score(x, t)``.

        Extra keyword arguments are forwarded to the score model.
        """
        forward_drift = self.sde.drift(x, t)
        g = self.sde.diffusion(t)
        s = self.score_model.score(x, t, **kwargs)
        return forward_drift - stp(g ** 2, s)

    def diffusion(self, t):
        """Diffusion g(t), taken unchanged from the forward SDE."""
        return self.sde.diffusion(t)
class ODE(object):
    r"""
    Deterministic counterpart of the reverse SDE. ``drift`` applies a
    factor 0.5 to the score term:
    dx = [f(x, t) - 0.5 g(t)^2 s(x, t)] dt
    """

    def __init__(self, score_model):
        self.sde = score_model.sde
        self.score_model = score_model

    def drift(self, x, t, **kwargs):
        """Drift ``f(x, t) - 0.5 * g(t)^2 * score(x, t)``.

        Extra keyword arguments are forwarded to the score model.
        """
        forward_drift = self.sde.drift(x, t)
        g = self.sde.diffusion(t)
        s = self.score_model.score(x, t, **kwargs)
        return forward_drift - 0.5 * stp(g ** 2, s)

    def diffusion(self, t):
        """The ODE has no noise term, so this is always 0."""
        return 0
def dct2str(dct):
    """Render ``dct`` as a string with every value formatted as ``.6g``."""
    formatted = {}
    for key, value in dct.items():
        formatted[key] = format(value, '.6g')
    return str(formatted)
@torch.no_grad()
def euler_maruyama(rsde, x_init, sample_steps, eps=1e-3, T=1, trace=None, verbose=False, **kwargs):
    r"""
    The Euler Maruyama sampler for reverse SDE / ODE
    See `Score-Based Generative Modeling through Stochastic Differential Equations`

    Integrates from t = T down to t = 0 over the grid
    ``[0, linspace(eps, T, sample_steps)]``.

    Args:
        rsde: A ``ReverseSDE`` or ``ODE`` providing ``drift`` and ``diffusion``.
        x_init: Starting state at t = T; the time grid is moved to its
            dtype and device.
        sample_steps: Number of grid points between ``eps`` and ``T``.
        eps: Smallest non-zero time on the grid.
        T: Final time of the forward process.
        trace: Optional list; receives ``x_init`` and the state after every step.
        verbose: Show a tqdm progress bar when true.
        **kwargs: Forwarded to ``rsde.drift``.

    Returns:
        The state after the final step to t = 0.
    """
    assert isinstance(rsde, (ReverseSDE, ODE))
    print(f"euler_maruyama with sample_steps={sample_steps}")

    grid = np.append(0., np.linspace(eps, T, sample_steps))
    grid = torch.tensor(grid).to(x_init)
    # (earlier, later) neighbours on the grid, visited from the T end first.
    intervals = list(zip(grid, grid[1:]))
    intervals.reverse()

    x = x_init
    if trace is not None:
        trace.append(x)
    for t_next, t_cur in tqdm(intervals, disable=not verbose, desc='euler_maruyama'):
        drift = rsde.drift(x, t_cur, **kwargs)
        diffusion = rsde.diffusion(t_cur)
        dt = t_next - t_cur  # negative: time runs backwards
        mean = x + drift * dt
        sigma = diffusion * (-dt).sqrt()
        if t_next != 0:
            x = mean + stp(sigma, torch.randn_like(x))
        else:
            # No noise is injected on the final step to t = 0.
            x = mean
        if trace is not None:
            trace.append(x)
        logging.debug(dct2str(dict(s=t_next, t=t_cur, sigma=sigma.item())))
    return x
def LSimple(score_model: ScoreModel, x0, pred='noise_pred', **kwargs):
    r"""Per-sample mean-squared-error training loss.

    A random time and noised sample are drawn from ``score_model.sde``; the
    loss compares either the true noise with the predicted noise
    (``pred='noise_pred'``) or ``x0`` with the predicted x0
    (``pred='x0_pred'``).

    Args:
        score_model: Model providing ``sde``, ``noise_pred`` and ``x0_pred``.
        x0: Clean data batch.
        pred: Which quantity the loss is measured on.
        **kwargs: Forwarded to the score model's prediction method.

    Returns:
        ``mos`` of the residual, i.e. one value per sample.

    Raises:
        NotImplementedError: If ``pred`` is not a supported mode.
    """
    t, noise, xt = score_model.sde.sample(x0)
    if pred == 'noise_pred':
        target, estimate = noise, score_model.noise_pred(xt, t, **kwargs)
    elif pred == 'x0_pred':
        target, estimate = x0, score_model.x0_pred(xt, t, **kwargs)
    else:
        raise NotImplementedError(pred)
    return mos(target - estimate)