From 621a1e45d361ed66cd40517c5541f6a4ebad30d8 Mon Sep 17 00:00:00 2001
From: Aayushongit <141538111+Aayushongit@users.noreply.github.com>
Date: Sun, 20 Apr 2025 23:29:26 +0530
Subject: [PATCH] Add cosine beta schedule, SNR table and mse/hybrid losses

- get_beta_schedule: add the 'cosine' schedule from Nichol & Dhariwal
  (2021), "Improved Denoising Diffusion Probabilistic Models".
- GaussianDiffusion.__init__: precompute self.snr = alpha_bar / (1 - alpha_bar)
  and initialise self.rescale_timesteps to False.
- GaussianDiffusion: add get_timestep_embedding and _scale_timesteps as
  methods of the class.
- p_losses: add an 'mse' loss (model predicts x_0) and a 'hybrid' loss
  (model predicts eps; eps-space and x_0-space errors weighted per
  example by 1 / (1 + SNR) and SNR / (1 + SNR)). Both reuse the model
  prediction x_recon used by the 'noisepred' branch.
---
 diffusion_tf/diffusion_utils.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/diffusion_tf/diffusion_utils.py b/diffusion_tf/diffusion_utils.py
index f41111e..2eddd67 100644
--- a/diffusion_tf/diffusion_utils.py
+++ b/diffusion_tf/diffusion_utils.py
@@ -32,6 +32,17 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
   elif beta_schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
     betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
+  elif beta_schedule == 'cosine':
+    # Cosine schedule (Nichol & Dhariwal, 2021):
+    # alpha_bar(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), normalised so alpha_bar(0) = 1.
+    s = 0.008  # offset that keeps beta_t from being too small near t = 0
+    steps = num_diffusion_timesteps + 1
+    x = np.linspace(0, num_diffusion_timesteps, steps, dtype=np.float64)
+    alphas_cumprod = np.cos(((x / num_diffusion_timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    # Clip to avoid the singularity near t = T, where alpha_bar approaches 0.
+    betas = np.clip(betas, 0, 0.999)
   else:
     raise NotImplementedError(beta_schedule)
   assert betas.shape == (num_diffusion_timesteps,)
@@ -84,6 +95,42 @@ def __init__(self, *, betas, loss_type, tf_dtype=tf.float32):
       betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod), dtype=tf_dtype)
     self.posterior_mean_coef2 = tf.constant(
       (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod), dtype=tf_dtype)
+
+    # Signal-to-noise ratio of q(x_t | x_0) per timestep: alpha_bar_t / (1 - alpha_bar_t).
+    self.snr = tf.constant(alphas_cumprod / (1. - alphas_cumprod), dtype=tf_dtype)
+
+    # Read by _scale_timesteps. Defaults to False (no rescaling); set it to True
+    # on the instance to enable rescaling.
+    self.rescale_timesteps = False
+
+  def get_timestep_embedding(self, timesteps, embedding_dim, max_positions=10000):
+    """
+    Sinusoidal embedding of a batch of timesteps.
+
+    timesteps: 1-D tensor of timesteps, shape [B].
+    embedding_dim: size of the returned embedding; odd sizes are zero-padded.
+    max_positions: frequencies are geometrically spaced from 1 down to 1 / max_positions.
+    Returns a float32 tensor of shape [B, embedding_dim].
+    """
+    half_dim = embedding_dim // 2
+    # max(..., 1) avoids dividing by zero when half_dim == 1 (embedding_dim of 2 or 3).
+    emb = tf.math.log(tf.cast(max_positions, tf.float32)) / max(half_dim - 1, 1)
+    emb = tf.exp(tf.range(half_dim, dtype=tf.float32) * -emb)
+    emb = tf.cast(timesteps, tf.float32)[:, None] * emb[None, :]
+    emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1)
+    if embedding_dim % 2 == 1:
+      emb = tf.pad(emb, [[0, 0], [0, 1]])
+    return emb
+
+  def _scale_timesteps(self, t):
+    """
+    Cast timesteps to float32. If self.rescale_timesteps is set, also multiply by
+    1000 / num_timesteps, so t in [0, num_timesteps) maps into [0, 1000).
+    """
+    t = tf.cast(t, tf.float32)
+    if self.rescale_timesteps:
+      return t * (1000.0 / self.num_timesteps)
+    return t
 
   @staticmethod
   def _extract(a, t, x_shape):
@@ -156,6 +203,25 @@ def p_losses(self, denoise_fn, x_start, t, noise=None):
       # predict the noise instead of x_start. seems to be weighted naturally like SNR
       assert x_recon.shape == x_start.shape
       losses = nn.meanflat(tf.squared_difference(noise, x_recon))
+    elif self.loss_type == 'mse':
+      # predict x_start directly: x_recon is the model's estimate of x_0
+      assert x_recon.shape == x_start.shape
+      losses = nn.meanflat(tf.squared_difference(x_start, x_recon))
+    elif self.loss_type == 'hybrid':
+      # x_recon is a noise prediction (as for 'noisepred'); combine its eps-space
+      # error with the error of the x_0 estimate reconstructed from it.
+      assert x_recon.shape == x_start.shape
+      x_start_pred = self.predict_start_from_noise(x_noisy, t=t, noise=x_recon)
+      noise_loss = nn.meanflat(tf.squared_difference(noise, x_recon))
+      data_loss = nn.meanflat(tf.squared_difference(x_start, x_start_pred))
+      # Weight by SNR(t). tf.gather yields one weight per example (shape [B]) so it
+      # lines up with the per-example losses, rather than a weight shaped to
+      # broadcast against x_start, which would expand the sum to a higher rank.
+      # NOTE: if predict_start_from_noise exactly inverts q_sample, data_loss ==
+      # noise_loss / snr, so this equals 2 * noise_loss / (1 + snr).
+      snr = tf.cast(tf.gather(self.snr, t), noise_loss.dtype)
+      weights = 1. / (1. + snr)
+      losses = weights * noise_loss + (1. - weights) * data_loss
     else:
       raise NotImplementedError(self.loss_type)
 