+
Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/genmo/mochi_preview/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=param_dtype,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
auto_wrap_policy=auto_wrap_policy,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
Expand Down Expand Up @@ -115,7 +115,7 @@ def get_model(self, *, local_rank, device_id, world_size):
model = setup_fsdp_sync(
model,
device_id=device_id,
param_dtype=torch.float32,
param_dtype=torch.bfloat16,
auto_wrap_policy=partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
Expand Down Expand Up @@ -236,7 +236,7 @@ def get_model(
model = setup_fsdp_sync(
model,
device_id=device_id,
param_dtype=torch.float32,
param_dtype=torch.bfloat16,
auto_wrap_policy=partial(
lambda_auto_wrap_policy,
lambda_fn=lambda m: m in model.blocks,
Expand Down Expand Up @@ -360,7 +360,7 @@ def get_conditioning_for_prompts(tokenizer, encoder, device, prompts: List[str])
# Sometimes returns a tensor, other times a tuple; not sure why.
# See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
assert y_feat[-1].dtype == torch.float32
assert y_feat[-1].dtype == torch.bfloat16

return dict(y_mask=y_mask, y_feat=y_feat)

Expand Down Expand Up @@ -436,7 +436,7 @@ def sample_model(device, dit, conditioning, **args):
z = torch.randn(
(B, IN_CHANNELS, latent_t, latent_h, latent_w),
device=device,
dtype=torch.float32,
dtype=torch.bfloat16,
)

num_latents = latent_t * latent_h * latent_w
Expand Down Expand Up @@ -477,7 +477,7 @@ def model_fn(*, z, sigma, cfg_scale):
sigma=torch.full([B] if cond_text else [B * 2], sigma, device=z.device),
cfg_scale=cfg_schedule[i],
)
assert pred.dtype == torch.float32
assert pred.dtype == torch.bfloat16
z = z + dsigma * pred

z = z[:B] if cond_batched else z
Expand All @@ -496,6 +496,7 @@ def move_to_device(model: nn.Module, target_device, *, enabled=True):
else:
print(f"moving model from {og_device} -> {target_device}")

model.to(torch.bfloat16)
model.to(target_device)
yield
if og_device != target_device:
Expand Down
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载