Inductive Moment Matching

Official Implementation of Inductive Moment Matching

Luma AI, Stanford University

Also check out our accompanying position paper that explains the motivation and ways of designing new generative paradigms.

Checklist

  • Add model weights and model definitions.
  • Add inference scripts.
  • Add evaluation scripts.
  • Add training scripts.

Dependencies

To install all packages in this codebase along with their dependencies, run

conda env create -f env.yml
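Then activate the environment. The name below is an assumption; use whatever name is set in the name: field of env.yml:

# Activate the environment (name assumed; check env.yml)
conda activate imm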

For multi-node jobs, we use Slurm via submitit, which can be installed via

pip install submitit

Pre-trained models

We provide pretrained checkpoints through our repo on Hugging Face.
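As a minimal sketch, you can fetch a checkpoint with the huggingface_hub CLI. The repo id below is a placeholder; substitute the id from our Hugging Face page. The filename is one example checkpoint used later in this README.

# Sketch only: replace HF_REPO_ID with the actual Hugging Face repo id
pip install -U huggingface_hub
huggingface-cli download HF_REPO_ID imagenet256_s_a1.pkl --local-dir YOUR_PATH/checkpoints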

Datasets

Datasets are stored as uncompressed ZIP archives containing uncompressed PNG or NPY files, along with a metadata file dataset.json for labels. When using latent diffusion, it is necessary to create two different versions of a given dataset: the original RGB version, used for evaluation, and a VAE-encoded latent version, used for training.
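Once you have created a dataset with the recipes below, you can sanity-check its layout with a quick listing. This is only an illustrative check; the path is a placeholder and it assumes dataset.json sits at the archive root:

# List the first few entries of a dataset archive
unzip -l YOUR_PATH/datasets/DATASET.zip | head
# Peek at the label metadata (assumes dataset.json at the archive root)
unzip -p YOUR_PATH/datasets/DATASET.zip dataset.json | head -c 500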

To set up CIFAR-10:

Download the [CIFAR-10 python version](https://www.cs.toronto.edu/~kriz/cifar.html) and convert it via:

python dataset_tool.py --source=YOUR_DOWNLOADED_TARGZ_FILE \
    --dest=YOUR_PATH/datasets/cifar10-32x32.zip 

To set up ImageNet-256:

  1. Download the ILSVRC2012 data archive from Kaggle and extract it somewhere, e.g., downloads/imagenet.

  2. Crop and resize the images to create the original RGB dataset:

# Convert raw ImageNet data to a ZIP archive at 256x256 resolution
python dataset_tool.py convert --source=downloads/imagenet/ILSVRC/Data/CLS-LOC/train \
    --dest=YOUR_PATH/datasets/img256.zip --resolution=256x256 --transform=center-crop-dhariwal

  3. Run the images through a pre-trained VAE encoder to create the corresponding latent dataset:

# Convert the pixel data to VAE latents
python dataset_tool.py encode --source=YOUR_PATH/datasets/img256.zip \
    --dest=YOUR_PATH/datasets/img256-sd.zip

  4. Calculate reference statistics for the original RGB dataset, to be used with calculate_metrics.py:

  • You can find precalculated reference stats for ImageNet-256x256 here.

# Compute dataset reference statistics for calculating metrics
python calculate_metrics.py ref --data=YOUR_PATH/datasets/img256.zip \
    --dest=YOUR_PATH/dataset-refs/img256.pkl

Training

The default configs are in configs/. Before training, set the logger and dataset.path fields to your own values. You can then train on a single node with

bash run_train.sh NUM_GPUS_PER_NODE CONFIG_NAME REPLACEMENT_ARGS

where NUM_GPUS_PER_NODE is the number of GPUs per node and CONFIG_NAME is either cifar10.yaml or im256.yaml. To override any config argument, for example changing loss.a on CIFAR-10 from 1 to 2, run

bash run_train.sh 8 cifar10.yaml loss.a=2
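In the same way, you can point dataset.path at the dataset you prepared above; the path below is a placeholder:

# Override the dataset path on the command line (placeholder path)
bash run_train.sh 8 cifar10.yaml dataset.path=YOUR_PATH/datasets/cifar10-32x32.zip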

For multi-node training, we use submitit and run

python launch.py --ngpus=NUM_GPUS_PER_NODE --nodes=NUM_NODES --config-name=CONFIG_NAME

The output folder will be created under ./outputs/.
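For example, a hypothetical 4-node ImageNet-256 run with 8 GPUs per node might look like this (node and GPU counts are illustrative):

# Launch a multi-node job via submitit (example values)
python launch.py --ngpus=8 --nodes=4 --config-name=im256.yaml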

Sampling

The checkpoints can be tested via

python generate_images.py --config-name=CONFIG_NAME eval.resume=CKPT_PATH REPLACEMENT_ARGS

where CONFIG_NAME is im256_generate_images.yaml or cifar10_generate_images.yaml and CKPT_PATH is the path to your checkpoint. When loading imagenet256_s_a1.pkl, REPLACEMENT_ARGS needs to be network.temb_type=identity. Otherwise, REPLACEMENT_ARGS is empty.
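For example, to sample from imagenet256_s_a1.pkl (the checkpoint path below is a placeholder):

# Sample from the ImageNet-256 checkpoint (placeholder path)
python generate_images.py --config-name=im256_generate_images.yaml \
    eval.resume=YOUR_PATH/checkpoints/imagenet256_s_a1.pkl network.temb_type=identity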

Evaluation

The eval scripts calculate FID scores. Make sure you follow the dataset guidelines above and have the reference statistics files ready.

bash run_eval.sh NUM_GPUS_PER_NODE CONFIG_NAME eval.resume=YOUR_MODEL_PKL_PATH

For example, to evaluate on CIFAR-10 using 8 GPUs on a single node, with your model saved at outputs/cifar/network-snapshot-latest.pkl, run

bash run_eval.sh 8 cifar10.yaml eval.resume=outputs/cifar/network-snapshot-latest.pkl

YOUR_MODEL_PKL_PATH can also be a directory containing all checkpoints of a run, labeled by training iteration (i.e., your training directory). The script sorts the checkpoints from latest to earliest and evaluates them in that order.
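For example, to evaluate every checkpoint in the training directory from the run above:

# Evaluate all checkpoints in a training directory, latest first
bash run_eval.sh 8 cifar10.yaml eval.resume=outputs/cifar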

Acknowledgements

Some of the utility functions are based on EDM, and those parts of the code are covered by its license.

Citation

@inproceedings{zhou2025inductive,
  title={Inductive Moment Matching},
  author={Linqi Zhou and Stefano Ermon and Jiaming Song},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025}
}
