这是indexloc提供的服务,不要输入任何密码
Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion diffusion_tf/diffusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
elif beta_schedule == 'cosine':
s = 0.008
steps = num_diffusion_timesteps + 1
x = np.linspace(0, num_diffusion_timesteps, steps, dtype=np.float64)
alphas_cumprod = np.cos(((x / num_diffusion_timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
betas = np.clip(betas, 0, 0.999)

else:
raise NotImplementedError(beta_schedule)
assert betas.shape == (num_diffusion_timesteps,)
Expand Down Expand Up @@ -84,7 +93,25 @@ def __init__(self, *, betas, loss_type, tf_dtype=tf.float32):
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod), dtype=tf_dtype)
self.posterior_mean_coef2 = tf.constant(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod), dtype=tf_dtype)


self.snr = tf.constant(alphas_cumprod / (1.0 - alphas_cumprod), dtype=tf_dtype)


def get_timestep_embedding(self, timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: tensor of timestep values; it is cast to float32 and
            indexed as ``timesteps[:, None]`` (presumably a 1-D ``[batch]``
            tensor -- confirm against callers).
        embedding_dim: Python int, number of columns in the result.
        max_positions: controls the frequency range; for ``half_dim > 1``
            the frequencies fall geometrically from 1 to ``1 / max_positions``.

    Returns:
        float32 tensor whose second axis has ``embedding_dim`` entries: the
        sines of ``timestep * frequency`` followed by the cosines, plus one
        trailing zero column when ``embedding_dim`` is odd.
    """
    half_dim = embedding_dim // 2
    # Clamp the denominator: for embedding_dim in (2, 3) `half_dim - 1` is 0,
    # which made the log-scale inf and the single frequency exp(0 * -inf) = NaN.
    # With the clamp that case gets a single frequency of exp(0) = 1; every
    # other embedding_dim is computed exactly as before.
    denominator = max(half_dim - 1, 1)
    log_scale = tf.math.log(tf.cast(max_positions, tf.float32)) / denominator
    frequencies = tf.exp(tf.range(half_dim, dtype=tf.float32) * -log_scale)
    # Outer product: one row per timestep, one column per frequency.
    angles = tf.cast(timesteps, tf.float32)[:, None] * frequencies[None, :]
    emb = tf.concat([tf.sin(angles), tf.cos(angles)], axis=1)
    if embedding_dim % 2 == 1:
        # sin/cos halves only cover 2 * half_dim columns; zero-pad the last one.
        emb = tf.pad(emb, [[0, 0], [0, 1]])
    return emb

def _scale_timesteps(self, t):
    """Return ``t`` as float32, optionally rescaled.

    When ``self.rescale_timesteps`` is truthy the float timesteps are
    multiplied by ``1000.0 / self.num_timesteps``; otherwise they are
    returned unscaled.
    """
    t_float = tf.cast(t, tf.float32)
    if not self.rescale_timesteps:
        return t_float
    scale = 1000.0 / self.num_timesteps
    return t_float * scale

@staticmethod
def _extract(a, t, x_shape):
"""
Expand Down Expand Up @@ -156,6 +183,35 @@ def p_losses(self, denoise_fn, x_start, t, noise=None):
# predict the noise instead of x_start. seems to be weighted naturally like SNR
assert x_recon.shape == x_start.shape
losses = nn.meanflat(tf.squared_difference(noise, x_recon))
elif self.loss_type == 'mse':
# predict x_start directly
if model_output.shape.as_list()[-1] == x_start.shape.as_list()[-1] * 2 and self.learned_variance:
model_output, _ = tf.split(model_output, 2, axis=-1)
assert model_output.shape == x_start.shape
x_recon = model_output
if clip_denoised:
x_recon = tf.clip_by_value(x_recon, -1., 1.)
losses = nn.meanflat(tf.squared_difference(x_start, x_recon))
elif self.loss_type == 'hybrid':
# Combination of noise prediction and data prediction
if model_output.shape.as_list()[-1] == x_start.shape.as_list()[-1] * 2 and self.learned_variance:
model_output, _ = tf.split(model_output, 2, axis=-1)
assert model_output.shape == x_start.shape

# Reconstruct x_start from predicted noise
x_recon = self.predict_start_from_noise(x_noisy, t=t, noise=model_output)
if clip_denoised:
x_recon = tf.clip_by_value(x_recon, -1., 1.)

# Compute both losses
noise_loss = nn.meanflat(tf.squared_difference(noise, model_output))
data_loss = nn.meanflat(tf.squared_difference(x_start, x_recon))

# Weight losses by SNR at timestep t
snr = self._extract(self.snr, t, x_start.shape)
weights = 1.0 / (1.0 + snr)
losses = weights * noise_loss + (1.0 - weights) * data_loss

else:
raise NotImplementedError(self.loss_type)

Expand Down