From 3fbda4da50bdd730d6721fdedb4717ec12da1cad Mon Sep 17 00:00:00 2001
From: Aayushongit <141538111+Aayushongit@users.noreply.github.com>
Date: Sun, 20 Apr 2025 23:30:24 +0530
Subject: [PATCH] Update diffusion_utils.py

---
 diffusion_tf/diffusion_utils.py | 303 +++++++++++++++++++++++++++++---
 1 file changed, 282 insertions(+), 21 deletions(-)

diff --git a/diffusion_tf/diffusion_utils.py b/diffusion_tf/diffusion_utils.py
index f41111e..366b954 100644
--- a/diffusion_tf/diffusion_utils.py
+++ b/diffusion_tf/diffusion_utils.py
@@ -32,6 +32,14 @@ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_time
     betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
   elif beta_schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
     betas = 1. / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
+  elif beta_schedule == 'cosine':
+    s = 0.008
+    steps = num_diffusion_timesteps + 1
+    x = np.linspace(0, num_diffusion_timesteps, steps, dtype=np.float64)
+    alphas_cumprod = np.cos(((x / num_diffusion_timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    betas = np.clip(betas, 0, 0.999)
   else:
     raise NotImplementedError(beta_schedule)
   assert betas.shape == (num_diffusion_timesteps,)
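For intuition, the new schedule can be inspected offline. A minimal numpy sketch mirroring the branch above (T=1000 is an assumed example value, not part of the patch):

    import numpy as np

    T, s = 1000, 0.008
    x = np.linspace(0, T, T + 1)
    alphas_cumprod = np.cos(((x / T) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod /= alphas_cumprod[0]
    betas = np.clip(1 - alphas_cumprod[1:] / alphas_cumprod[:-1], 0, 0.999)
    print(betas[0], betas[-1])  # tiny near t=0, clipped to 0.999 at the end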
@@ -49,8 +57,10 @@ class GaussianDiffusion:
   Contains utilities for the diffusion model.
   """

-  def __init__(self, *, betas, loss_type, tf_dtype=tf.float32):
+  def __init__(self, *, betas, loss_type, tf_dtype=tf.float32, learned_variance=False, rescale_timesteps=False):
     self.loss_type = loss_type
+    self.learned_variance = learned_variance
+    self.rescale_timesteps = rescale_timesteps

     assert isinstance(betas, np.ndarray)
     self.np_betas = betas = betas.astype(np.float64)  # computations here in float64 for accuracy
@@ -85,6 +95,24 @@ def __init__(self, *, betas, loss_type, tf_dtype=tf.float32):
     self.posterior_mean_coef2 = tf.constant(
         (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod), dtype=tf_dtype)

+    # Store SNR values for diagnostics
+    self.snr = tf.constant(alphas_cumprod / (1.0 - alphas_cumprod), dtype=tf_dtype)
+
+  def get_timestep_embedding(self, timesteps, embedding_dim, max_positions=10000):
+    half_dim = embedding_dim // 2
+    emb = tf.math.log(tf.cast(max_positions, tf.float32)) / (half_dim - 1)
+    emb = tf.exp(tf.range(half_dim, dtype=tf.float32) * -emb)
+    emb = tf.cast(timesteps, tf.float32)[:, None] * emb[None, :]
+    emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=1)
+    if embedding_dim % 2 == 1:
+      emb = tf.pad(emb, [[0, 0], [0, 1]])
+    return emb
+
+  def _scale_timesteps(self, t):
+    if self.rescale_timesteps:
+      return tf.cast(t, tf.float32) * (1000.0 / self.num_timesteps)
+    return tf.cast(t, tf.float32)
+
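get_timestep_embedding is the standard transformer-style sinusoidal embedding. A quick shape check in TF1 style; `diffusion` is an assumed GaussianDiffusion instance and the values are illustrative:

    t = tf.constant([0, 10, 999], dtype=tf.int32)
    emb = diffusion.get_timestep_embedding(t, embedding_dim=128)
    # emb has shape [3, 128]: the first 64 channels are sines, the last 64
    # cosines, with geometrically spaced frequencies down to 1/max_positions.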
   @staticmethod
   def _extract(a, t, x_shape):
     """
@@ -137,7 +165,7 @@ def q_posterior(self, x_start, x_t, t):
             x_start.shape[0])
     return posterior_mean, posterior_variance, posterior_log_variance_clipped

-  def p_losses(self, denoise_fn, x_start, t, noise=None):
+  def p_losses(self, denoise_fn, x_start, t, noise=None, clip_denoised=True):
     """
     Training loss calculation
     """
@@ -148,46 +176,98 @@ def p_losses(self, denoise_fn, x_start, t, noise=None):
       noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype)
     assert noise.shape == x_start.shape and noise.dtype == x_start.dtype
     x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
-    x_recon = denoise_fn(x_noisy, t)
-    assert x_noisy.shape == x_start.shape
-    assert x_recon.shape[:3] == [B, H, W] and len(x_recon.shape) == 4
+
+    model_output = denoise_fn(x_noisy, self._scale_timesteps(t))

     if self.loss_type == 'noisepred':
       # predict the noise instead of x_start. seems to be weighted naturally like SNR
-      assert x_recon.shape == x_start.shape
-      losses = nn.meanflat(tf.squared_difference(noise, x_recon))
+      assert model_output.shape == x_start.shape
+      losses = nn.meanflat(tf.squared_difference(noise, model_output))
+    elif self.loss_type == 'mse':
+      # predict x_start directly
+      if model_output.shape.as_list()[-1] == x_start.shape.as_list()[-1] * 2 and self.learned_variance:
+        model_output, _ = tf.split(model_output, 2, axis=-1)
+      assert model_output.shape == x_start.shape
+      x_recon = model_output
+      if clip_denoised:
+        x_recon = tf.clip_by_value(x_recon, -1., 1.)
+      losses = nn.meanflat(tf.squared_difference(x_start, x_recon))
+    elif self.loss_type == 'hybrid':
+      # Combination of noise prediction and data prediction
+      if model_output.shape.as_list()[-1] == x_start.shape.as_list()[-1] * 2 and self.learned_variance:
+        model_output, _ = tf.split(model_output, 2, axis=-1)
+      assert model_output.shape == x_start.shape
+
+      # Reconstruct x_start from predicted noise
+      x_recon = self.predict_start_from_noise(x_noisy, t=t, noise=model_output)
+      if clip_denoised:
+        x_recon = tf.clip_by_value(x_recon, -1., 1.)
+
+      # Compute both losses
+      noise_loss = nn.meanflat(tf.squared_difference(noise, model_output))
+      data_loss = nn.meanflat(tf.squared_difference(x_start, x_recon))
+
+      # Weight losses by SNR at timestep t. Gather to per-example shape [B]
+      # so the weights broadcast correctly against the [B]-shaped losses
+      # (self._extract would return [B, 1, 1, 1] and break the assert below).
+      snr = tf.gather(self.snr, t)
+      weights = 1.0 / (1.0 + snr)
+      losses = weights * noise_loss + (1.0 - weights) * data_loss
     else:
       raise NotImplementedError(self.loss_type)

     assert losses.shape == [B]
     return losses

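Since 1/(1+SNR) equals 1 - alpha_bar_t, the hybrid loss leans on the data term at low-noise (early) timesteps and on the noise term at high-noise (late) timesteps. A numpy check, assuming a linear schedule:

    import numpy as np

    betas = np.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = np.cumprod(1. - betas)
    weights = 1. / (1. + alphas_cumprod / (1. - alphas_cumprod))  # == 1 - alphas_cumprod
    print(weights[0], weights[-1])  # ~1e-4 at t=0, ~1.0 at t=T-1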
-  def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool):
+  def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool, return_pred_xstart=False, model_kwargs=None):
+    if model_kwargs is None:
+      model_kwargs = {}
+
+    model_output = denoise_fn(x, self._scale_timesteps(t), **model_kwargs)
+
+    if self.learned_variance and model_output.shape.as_list()[-1] == x.shape.as_list()[-1] * 2:
+      model_output, model_var_values = tf.split(model_output, 2, axis=-1)
+      # Learn the variance via the variational bound: interpolate between
+      # the two fixed extremes in the log domain.
+      min_log = self._extract(self.posterior_log_variance_clipped, t, x.shape)
+      max_log = self._extract(tf.math.log(self.betas), t, x.shape)
+      # model_var_values is in [-1, 1] for [min_var, max_var].
+      frac = (model_var_values + 1) / 2
+      model_log_variance = frac * max_log + (1 - frac) * min_log
+      model_variance = tf.exp(model_log_variance)
+    else:
+      model_variance = self._extract(self.posterior_variance, t, x.shape)
+      model_log_variance = self._extract(self.posterior_log_variance_clipped, t, x.shape)
+
     if self.loss_type == 'noisepred':
-      x_recon = self.predict_start_from_noise(x, t=t, noise=denoise_fn(x, t))
+      x_recon = self.predict_start_from_noise(x, t=t, noise=model_output)
+    elif self.loss_type == 'hybrid':
+      # hybrid also trains the network as a noise predictor (see p_losses)
+      x_recon = self.predict_start_from_noise(x, t=t, noise=model_output)
+    elif self.loss_type == 'mse':
+      x_recon = model_output
     else:
       raise NotImplementedError(self.loss_type)

     if clip_denoised:
       x_recon = tf.clip_by_value(x_recon, -1., 1.)

-    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
-    assert model_mean.shape == x_recon.shape == x.shape
-    assert posterior_variance.shape == posterior_log_variance.shape == [x.shape[0], 1, 1, 1]
-    return model_mean, posterior_variance, posterior_log_variance
+    model_mean, _, _ = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+    # model_log_variance is [B, 1, 1, 1] in the fixed-variance branch and
+    # broadcasts against x; only the mean must match x exactly.
+    assert model_mean.shape == x.shape
+
+    if return_pred_xstart:
+      return model_mean, model_variance, model_log_variance, x_recon
+    else:
+      return model_mean, model_variance, model_log_variance

-  def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, repeat_noise=False):
+  def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, repeat_noise=False, model_kwargs=None):
     """
     Sample from the model
     """
-    model_mean, _, model_log_variance = self.p_mean_variance(denoise_fn, x=x, t=t, clip_denoised=clip_denoised)
+    model_mean, _, model_log_variance = self.p_mean_variance(
+      denoise_fn, x=x, t=t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+    )
     noise = noise_like(x.shape, noise_fn, repeat_noise)
     assert noise.shape == x.shape
     # no noise when t == 0
     nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1))
     return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise

-  def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
+  def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal, model_kwargs=None):
     """
     Generate samples
     """
@@ -198,7 +278,71 @@ def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
       cond=lambda i_, _: tf.greater_equal(i_, 0),
       body=lambda i_, img_: [
         i_ - 1,
-        self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn)
+        self.p_sample(denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_),
+                      noise_fn=noise_fn, model_kwargs=model_kwargs)
+      ],
+      loop_vars=[i_0, img_0],
+      shape_invariants=[i_0.shape, img_0.shape],
+      back_prop=False
+    )
+    assert img_final.shape == shape
+    return img_final
+
+  def ddim_sample(self, denoise_fn, *, x, t, clip_denoised=True, eta=0.0, model_kwargs=None):
+    """
+    Sample x_{t-1} from the model using DDIM (Song et al., 2020).
+    Same usage as p_sample. eta=0 is fully deterministic; eta=1 recovers
+    the ancestral (DDPM) noise level.
+    """
+    if model_kwargs is None:
+      model_kwargs = {}
+
+    _, _, _, pred_xstart = self.p_mean_variance(
+      denoise_fn,
+      x=x,
+      t=t,
+      clip_denoised=clip_denoised,
+      return_pred_xstart=True,
+      model_kwargs=model_kwargs,
+    )
+
+    alpha_bar = self._extract(self.alphas_cumprod, t, x.shape)
+    alpha_bar_prev = self._extract(self.alphas_cumprod_prev, t, x.shape)
+
+    # Noise prediction implied by the (possibly clipped) reconstruction of x_0
+    eps = (x - tf.sqrt(alpha_bar) * pred_xstart) / tf.sqrt(1. - alpha_bar)
+
+    # DDIM noise scale; eta interpolates between deterministic and ancestral
+    sigma = (eta * tf.sqrt((1. - alpha_bar_prev) / (1. - alpha_bar))
+             * tf.sqrt(1. - alpha_bar / alpha_bar_prev))
+
+    # Deterministic part of the DDIM update
+    mean_pred = (tf.sqrt(alpha_bar_prev) * pred_xstart
+                 + tf.sqrt(1. - alpha_bar_prev - sigma ** 2) * eps)
+
+    noise = tf.random_normal(shape=x.shape)
+    # no noise when t == 0
+    nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1))
+    return mean_pred + nonzero_mask * sigma * noise
+
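The eta knob only scales the stochastic part of the update, and the resulting noise level sigma_t can be checked offline. A numpy sketch, assuming a linear schedule:

    import numpy as np

    betas = np.linspace(1e-4, 0.02, 1000)
    ab = np.cumprod(1. - betas)
    ab_prev = np.append(1., ab[:-1])

    def ddim_sigma(eta, t):
      return eta * np.sqrt((1. - ab_prev[t]) / (1. - ab[t])) * np.sqrt(1. - ab[t] / ab_prev[t])

    print(ddim_sigma(0.0, 500))  # 0.0: fully deterministic update
    print(ddim_sigma(1.0, 500))  # equals the DDPM posterior std at t=500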
+ """ + i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) + assert isinstance(shape, (tuple, list)) + img_0 = noise_fn(shape=shape, dtype=tf.float32) + _, img_final = tf.while_loop( + cond=lambda i_, _: tf.greater_equal(i_, 0), + body=lambda i_, img_: [ + i_ - 1, + self.ddim_sample( + denoise_fn=denoise_fn, + x=img_, + t=tf.fill([shape[0]], i_), + eta=eta, + model_kwargs=model_kwargs + ) ], loop_vars=[i_0, img_0], shape_invariants=[i_0.shape, img_0.shape], @@ -207,13 +351,50 @@ def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal): assert img_final.shape == shape return img_final - def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_normal, repeat_noise_steps=-1): + def ddim_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_normal, eta=0.0, + skip_steps=10, model_kwargs=None): + """ + Generate samples using DDIM, returning intermediate images. + skip_steps allow for skipping steps to increase efficiency. + """ + steps = self.num_timesteps // skip_steps + seq = range(0, self.num_timesteps, skip_steps) + seq_next = [-1] + list(seq[:-1]) + + i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) + assert isinstance(shape, (tuple, list)) + img_0 = noise_fn(shape=shape, dtype=tf.float32) + + times = tf.Variable([i_0]) + imgs = tf.Variable([img_0]) + + for i, t in enumerate(seq): + t_tensor = tf.constant(t, dtype=tf.int32) + t_batch = tf.fill([shape[0]], t_tensor) + + # DDIM sampling step + img_next = self.ddim_sample( + denoise_fn=denoise_fn, + x=imgs[-1], + t=t_batch, + eta=eta, + model_kwargs=model_kwargs + ) + + # Store results + times = tf.concat([times, [t_tensor]], 0) + imgs = tf.concat([imgs, [img_next]], 0) + + assert imgs[-1].shape == shape + return times, imgs + + def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_normal, repeat_noise_steps=-1, model_kwargs=None): """ Generate samples, returning intermediate images Useful for visualizing how denoised images evolve over time Args: repeat_noise_steps (int): Number of denoising timesteps in which the same noise - is used across the batch. If >= 0, the initial noise is the same for all batch elemements. + is used across the batch. If >= 0, the initial noise is the same for all batch elements. 
""" i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) assert isinstance(shape, (tuple, list)) @@ -229,7 +410,8 @@ def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_norm x=imgs_[-1], t=tf.fill([shape[0]], times_[-1]), noise_fn=noise_fn, - repeat_noise=True)]], 0) + repeat_noise=True, + model_kwargs=model_kwargs)]], 0) ], loop_vars=[times, imgs], shape_invariants=[tf.TensorShape([None, *i_0.shape]), @@ -245,7 +427,8 @@ def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_norm x=imgs_[-1], t=tf.fill([shape[0]], times_[-1]), noise_fn=noise_fn, - repeat_noise=False)]], 0) + repeat_noise=False, + model_kwargs=model_kwargs)]], 0) ], loop_vars=[times, imgs], shape_invariants=[tf.TensorShape([None, *i_0.shape]), @@ -255,6 +438,59 @@ def p_sample_loop_trajectory(self, denoise_fn, *, shape, noise_fn=tf.random_norm assert imgs[-1].shape == shape return times, imgs + def guided_p_sample(self, denoise_fn, *, x, t, cond_fn, noise_fn, clip_denoised=True, + repeat_noise=False, model_kwargs=None): + """ + Sample with classifier guidance (conditional sampling) + """ + if model_kwargs is None: + model_kwargs = {} + + model_mean, model_variance, model_log_variance = self.p_mean_variance( + denoise_fn, x=x, t=t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + + # Apply conditioning function to adjust mean + if cond_fn is not None: + gradient = cond_fn(x, t, **model_kwargs) + model_mean = model_mean + model_variance * gradient + + # Add noise + noise = noise_like(x.shape, noise_fn, repeat_noise) + nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1)) + sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise + + return sample + + def guided_p_sample_loop(self, denoise_fn, *, shape, cond_fn, noise_fn=tf.random_normal, + clip_denoised=True, model_kwargs=None): + """ + Generate samples with classifier guidance + """ + i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32) + assert isinstance(shape, (tuple, list)) + img_0 = noise_fn(shape=shape, dtype=tf.float32) + _, img_final = tf.while_loop( + cond=lambda i_, _: tf.greater_equal(i_, 0), + body=lambda i_, img_: [ + i_ - 1, + self.guided_p_sample( + denoise_fn=denoise_fn, + x=img_, + t=tf.fill([shape[0]], i_), + cond_fn=cond_fn, + noise_fn=noise_fn, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs + ) + ], + loop_vars=[i_0, img_0], + shape_invariants=[i_0.shape, img_0.shape], + back_prop=False + ) + assert img_final.shape == shape + return img_final + def interpolate(self, denoise_fn, *, shape, noise_fn=tf.random_normal): """ Interpolate between images. @@ -297,3 +533,28 @@ def interpolate(self, denoise_fn, *, shape, noise_fn=tf.random_normal): assert x_interp.shape == shape return x1, x2, lam, x_interp, t + + def slerp(self, denoise_fn, *, shape, noise_fn=tf.random_normal): + """ + Spherical linear interpolation between images. + Better preserves the norm of the latents during interpolation. + """ + assert isinstance(shape, (tuple, list)) + + # Placeholders for real samples to interpolate + x1 = tf.placeholder(tf.float32, shape) + x2 = tf.placeholder(tf.float32, shape) + # lam == 0.5 averages diffused images. 
   def interpolate(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
     """
     Interpolate between images.
@@ -297,3 +533,28 @@ def interpolate(self, denoise_fn, *, shape, noise_fn=tf.random_normal):

     assert x_interp.shape == shape
     return x1, x2, lam, x_interp, t
+
+  def slerp(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
+    """
+    Spherical linear interpolation between images.
+    Better preserves the norm of the latents during interpolation.
+    """
+    assert isinstance(shape, (tuple, list))
+
+    # Placeholders for real samples to interpolate
+    x1 = tf.placeholder(tf.float32, shape)
+    x2 = tf.placeholder(tf.float32, shape)
+    # lam == 0.5 averages diffused images.
+    lam = tf.placeholder(tf.float32, shape=())
+    t = tf.placeholder(tf.int32, shape=())
+
+    # Add noise via forward diffusion
+    t_batched = tf.stack([t] * x1.shape[0])
+    xt1 = self.q_sample(x1, t=t_batched)
+    xt2 = self.q_sample(x2, t=t_batched)
+
+    # Normalize to unit sphere
+    norm1 = tf.sqrt(tf.reduce_sum(xt1**2, axis=[1, 2, 3], keepdims=True))
+    norm2 = tf.sqrt(tf.reduce_sum(xt2**2, axis=[1, 2, 3], keepdims=True))
+
+    xt1_
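The patch is truncated at this point. For reference only, a hedged sketch of the standard slerp update the remainder would presumably implement (this is not the author's code; it assumes the latents are not parallel or anti-parallel, where sin(omega) vanishes):

    def slerp_latents(xt1, xt2, lam, norm1, norm2):
      # Angle between the normalized latents, then sine-weighted interpolation.
      dot = tf.reduce_sum((xt1 / norm1) * (xt2 / norm2), axis=[1, 2, 3], keepdims=True)
      omega = tf.acos(tf.clip_by_value(dot, -1., 1.))
      so = tf.sin(omega)
      return (tf.sin((1. - lam) * omega) * xt1 + tf.sin(lam * omega) * xt2) / so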