
v0.14.0

jainapurva released this 13 Oct 21:34

Highlights

We are excited to announce the 0.14.0 release of torchao! This release adds prototype support for MoE training on Blackwell GPUs and for NVFP4 QAT!

(Prototype) MoE training on Blackwell GPUs

We’ve added a quantized building block for speeding up MoE training on Blackwell GPUs: torchao’s `_scaled_grouped_mm`! It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. This results in significant speedups (see benchmarks below)!

import torch
from torch.nn import functional as F
from torchao.prototype.moe_training import (
    _scaled_grouped_mm as torchao_scaled_grouped_mm
)
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.utils import generate_jagged_offs

num_groups, total_M, N, K = 8, 131072, 8192, 5120

# A = input activations, B = expert weights
A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Token group offsets computed by router in actual MoE layer
offs = generate_jagged_offs(num_groups, total_M, device="cuda")

# Forward and backward example
out = torchao_scaled_grouped_mm(
    A,
    B.transpose(-2, -1),
    offs=offs,
    scaling_type=MoEScalingType.MXFP8,
)
labels = torch.ones_like(out)
loss = F.mse_loss(out, labels)
loss.backward()
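To make the role of `offs` concrete, here is a minimal, unquantized pure-Python reference for the 2D x 3D grouped GEMM, assuming `offs` holds the cumulative end row of each token group (so group g owns rows `offs[g-1]:offs[g]` of `A`) and that `B` has already been transposed to `(num_groups, K, N)`. The names `matmul` and `reference_grouped_mm` are hypothetical helpers for illustration, not torchao or PyTorch APIs; torchao's op additionally quantizes the inputs (here to MXFP8) before running a scaled grouped GEMM.

```python
from typing import List

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Dense (rows x K) @ (K x N) product using plain Python lists."""
    k_dim, n_dim = len(w), len(w[0])
    return [
        [sum(row[k] * w[k][n] for k in range(k_dim)) for n in range(n_dim)]
        for row in x
    ]


def reference_grouped_mm(a: Matrix, b: List[Matrix], offs: List[int]) -> Matrix:
    """Hypothetical unquantized reference for a 2D x 3D grouped GEMM.

    a:    (total_M, K) activations, rows already sorted by expert group
    b:    (num_groups, K, N) expert weights, already transposed
    offs: (num_groups,) cumulative end row of each group in `a`
    """
    assert len(offs) == len(b), "expected one offset per expert"
    assert offs[-1] == len(a), "assumed: the last offset covers every row"
    out: Matrix = []
    start = 0
    for expert, end in enumerate(offs):
        # Rows [start, end) were routed to this expert and use only its weights.
        out.extend(matmul(a[start:end], b[expert]))
        start = end
    return out


if __name__ == "__main__":
    a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, K = 2
    b = [
        [[2.0, 0.0], [0.0, 2.0]],  # expert 0: scales by 2
        [[0.0, 1.0], [1.0, 0.0]],  # expert 1: swaps the two columns
    ]
    offs = [1, 3]  # token 0 -> expert 0, tokens 1..2 -> expert 1
    print(reference_grouped_mm(a, b, offs))  # [[2.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
```

Because each expert's tokens form one contiguous slice, every expert runs a single dense GEMM on exactly the tokens routed to it, with no padding to a common token count.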

Microbenchmarks (see README for commands to reproduce benchmarks):

  • Forward + backward pass vs torch._grouped_mm:
    • ~1.4-1.8x faster for Llama4 17bx16e shapes
    • ~1.2-1.4x faster for DeepSeekV3 671b shapes
  • Full MoE layer forward + backward pass:
    • ~1.4x faster (Llama4 17bx16e shapes, batch_size=8, seq_len=16384)
    • ~1.2x faster (DeepSeekV3 671b shapes, batch_size=8, seq_len=16384)

It’s also already integrated into TorchTitan for end-to-end training with DeepSeekV3 and Llama4! Just use the command-line flag `--model.converters="quantize.grouped_mm.mx"`, which converts all `torch._grouped_mm` ops to torchao `_scaled_grouped_mm` ops under the hood.

TorchTitan end-to-end training benchmarks (see README for commands to reproduce benchmarks):

  • ~1.4x end-to-end speedup for a 2-layer Llama4 16e model with dp2ep parallelism (without TP) on 4 B200 GPUs connected via NVLink

(Prototype) NVFP4 QAT (#3050)

We added quantization-aware training (QAT) support for NVFP4 as a prototype feature! This feature is currently only available on Blackwell GPUs:

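from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4InferenceConfig

# NVFP4 activation and weight quantization
# (`model` and `train` stand in for your own model and training loop)
base_config = NVFP4InferenceConfig()
quantize_(model, QATConfig(base_config, step="prepare"))  # insert fake quantizers
train(model)
quantize_(model, QATConfig(base_config, step="convert"))  # remove fake quantizers, apply base_config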
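During the prepare step, QAT trains against fake-quantized values: tensors are quantized and immediately dequantized so the model sees NVFP4 rounding error while still training in high precision. The pure-Python sketch below illustrates that round trip under simplifying assumptions: 16-element blocks of FP4 E2M1 values sharing one scale, with the block scale kept in full precision (real NVFP4 stores it as FP8 E4M3, optionally combined with a per-tensor scale) and ties rounded toward the smaller magnitude. The helper names are hypothetical and are not torchao APIs.

```python
import math
from typing import List

# Non-negative magnitudes representable by FP4 E2M1 (sign handled separately).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0
BLOCK_SIZE = 16  # NVFP4 shares one scale per 16 elements


def round_to_e2m1(v: float) -> float:
    """Round a scaled value to the nearest E2M1 magnitude, keeping its sign."""
    mag = min(E2M1_GRID, key=lambda g: abs(abs(v) - g))
    return math.copysign(mag, v)


def fake_quantize_nvfp4(x: List[float]) -> List[float]:
    """Quantize then dequantize `x` block by block (simplified NVFP4)."""
    out: List[float] = []
    for i in range(0, len(x), BLOCK_SIZE):
        block = x[i : i + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        # Map the block's largest magnitude onto the largest FP4 value (6.0).
        scale = amax / E2M1_MAX if amax > 0 else 1.0
        out.extend(round_to_e2m1(v / scale) * scale for v in block)
    return out


if __name__ == "__main__":
    x = [6.0, -3.0, 0.2, 1.2, 2.6] + [0.0] * 11  # one block, scale = 1.0
    print(fake_quantize_nvfp4(x)[:5])  # [6.0, -3.0, 0.0, 1.0, 3.0]
```

In QAT the rounding step is usually treated as the identity in the backward pass (a straight-through estimator), so the weights can adapt to the rounding error that the convert step later makes permanent.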

Our NVFP4 QAT support is also integrated into axolotl:

axolotl train examples/llama-3/3b-qat-fsdp2-nvfp4.yaml

Initial experiments showed that QAT recovered up to 41% of the mmlu_pro accuracy degradation caused by NVFP4 quantization for Qwen3-8B, and up to 33% of the bbh accuracy degradation for Gemma3-12B-it:

# Qwen3-8B
bbh: 0.779 (baseline) -> 0.7254 (quant) -> 0.7368 (qat), recovered 21.269%
mmlu_pro: 0.4969 (baseline) -> 0.4521 (quant) -> 0.4707 (qat), recovered 41.518%

# Gemma3-12B-it
bbh: 0.7527 (baseline) -> 0.7068 (quant) -> 0.7222 (qat), recovered 33.551%
mmlu_pro: 0.4074 (baseline) -> 0.3621 (quant) -> 0.3702 (qat), recovered 17.881%
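The "recovered" percentages above are consistent with measuring the share of the accuracy drop from quantization that QAT wins back, i.e. (qat - quant) / (baseline - quant). A minimal sketch under that assumed definition (the helper name `recovered_pct` is ours, not part of torchao or any evaluation harness):

```python
def recovered_pct(baseline: float, quant: float, qat: float) -> float:
    """Percentage of the quantization accuracy drop that QAT recovers."""
    drop = baseline - quant
    if drop <= 0:
        raise ValueError("quantization did not reduce accuracy")
    return (qat - quant) / drop * 100


# Qwen3-8B mmlu_pro numbers from the results above
print(f"{recovered_pct(0.4969, 0.4521, 0.4707):.3f}%")  # 41.518%
```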

BC Breaking

Rename Int4WeightPreshuffledFakeQuantizeConfig to Int4WeightFakeQuantizeConfig (#3005)

from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig

# Before
from torchao.quantization.qat import Int4WeightPreshuffledFakeQuantizeConfig
fq_config = Int4WeightPreshuffledFakeQuantizeConfig()
quantize_(m, QATConfig(weight_config=fq_config))

# After
from torchao.quantization.qat import Int4WeightFakeQuantizeConfig
fq_config = Int4WeightFakeQuantizeConfig()
quantize_(m, QATConfig(weight_config=fq_config))

Deprecations

Bump Int4WeightOnlyConfig version from 1 to 2 (#2949)

We updated the implementation of the int4 Tensor, so the default config version is bumped from 1 to 2. Loading a checkpoint that was quantized with version 1, as below, now emits deprecation warnings:

from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
)

/data/users/jerryzh/ao/torchao/core/config.py:250: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_default_version=2, please check the deprecation warning
  warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/uintx/tensor_core_tiled_layout.py:241: UserWarning: Models quantized with version 1 of Int4WeightOnlyConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2948 for more details
  warnings.warn(

Suggestion: upgrade torchao to 0.14 or later and regenerate the checkpoint:

from torchao.quantization import Int4WeightOnlyConfig, quantize_
quantize_(model, Int4WeightOnlyConfig(group_size=128))

Alternatively, download the checkpoint again (please let us know if the checkpoint has not been updated).

Please see #2948 for more details around the deprecation.

Deprecate config functions like int4_weight_only (#2994)

Config functions such as int4_weight_only have been superseded by AOBaseConfig objects such as Int4WeightOnlyConfig(...), which have been in use for several releases.

from torchao.quantization import (
    Int4WeightOnlyConfig,
    int4_weight_only,
    quantize_,
)

# Before
quantize_(m, int4_weight_only())

# After
quantize_(m, Int4WeightOnlyConfig())

# Full list of deprecated functions
float8_dynamic_activation_float8_weight
float8_static_activation_float8_weight
float8_weight_only
fpx_weight_only
gemlite_uintx_weight_only
int4_dynamic_activation_int4_weight
int4_weight_only
int8_dynamic_activation_int4_weight
int8_dynamic_activation_int8_weight
int8_weight_only
uintx_weight_only
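The same before/after migration applies to every function in this list. When many call sites need updating, a lookup table plus a textual rewrite can help. This is a hypothetical sketch: the config class names assume torchao's `...Config` naming seen above (e.g. `Int4WeightOnlyConfig`) and the same constructor arguments as the old functions, so verify them against `torchao.quantization` for your installed version before relying on the output.

```python
import re

# Deprecated function name -> replacement config class name
# (assumed mapping; verify against torchao.quantization).
DEPRECATED_TO_CONFIG = {
    "float8_dynamic_activation_float8_weight": "Float8DynamicActivationFloat8WeightConfig",
    "float8_static_activation_float8_weight": "Float8StaticActivationFloat8WeightConfig",
    "float8_weight_only": "Float8WeightOnlyConfig",
    "fpx_weight_only": "FPXWeightOnlyConfig",
    "gemlite_uintx_weight_only": "GemliteUIntXWeightOnlyConfig",
    "int4_dynamic_activation_int4_weight": "Int4DynamicActivationInt4WeightConfig",
    "int4_weight_only": "Int4WeightOnlyConfig",
    "int8_dynamic_activation_int4_weight": "Int8DynamicActivationInt4WeightConfig",
    "int8_dynamic_activation_int8_weight": "Int8DynamicActivationInt8WeightConfig",
    "int8_weight_only": "Int8WeightOnlyConfig",
    "uintx_weight_only": "UIntXWeightOnlyConfig",
}

# Match a deprecated name followed by "("; \b keeps "uintx_weight_only"
# from matching inside "gemlite_uintx_weight_only".
_CALL = re.compile(r"\b(" + "|".join(map(re.escape, DEPRECATED_TO_CONFIG)) + r")\(")


def migrate_call(source: str) -> str:
    """Rewrite calls like int4_weight_only(...) into Int4WeightOnlyConfig(...)."""
    return _CALL.sub(lambda m: DEPRECATED_TO_CONFIG[m.group(1)] + "(", source)


print(migrate_call("quantize_(m, int4_weight_only(group_size=128))"))
```

Only call sites are rewritten; import lines that still name the deprecated functions need updating separately.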

New Features

  • CPU
    • Introduce Int4OpaqueTensor to replace Int4CPULayout in AQT (#2798)
    • [float8] Add scaled_embedding_bag kernel (#2686)
  • Safetensors support for TorchAO configs:
    • Add Int4Tensor support for safetensors (#3056)
    • Add Int4TilePackedTo4dTensor subclass to safetensors (#3064)
    • Add IntxUnpackedToInt8Tensor to safetensors (#3065)
  • Add Int4PlainInt32Tensor (#2845)
  • Add from_int4_tensor in Int4PreshuffledTensor (#2978)
  • Add hqq support for Int4TilePackedTo4dTensor (#2912)
  • Add torchao_convert to PARQ's QuantOptimizer (#2947)
  • Add scale-only version of the HQQ algorithm for IntxWeightOnlyConfig/Int8DynamicActivationIntxWeightConfig (#3110)

Improvements

  • QAT
    • Support QAT int4 v1 path for BC (#2888)
    • Improve QAT fp8-int4 numerics (#2937)
    • Improve QAT int4 weight-only numerics (#2986)
    • QAT configs (#3001)
    • Pass QAT learned qparams in convert (#3022)
  • (prototype) MXTensor
    • Port metadata from the linear node onto the reference custom op for int4 (#2860)
    • [mxfp8 moe training] remove mxfp8_gemms.py (#3033)
    • Make MXTensor printing nicer (#3068)
    • MXTensor: support clone (#3070)
    • MXTensor: delete to_copy override (#3072)
    • MXTensor: add serialization support (#3078)
    • MXTensor: add index select support (#3079)
    • MXTensor: switch to AOBaseTensor dispatch (#3080)
    • Enroll MXTensor in vLLM integration tests (#3081)
  • (prototype) NVFP4
    • Support NVFP4 dynamic per tensor scale (#3049)
    • NVFP4Tensor: improve printing (#3086)
    • Improve QAT nvfp4 numerics (#3050)
    • Enable 3d weights for NVFP4Tensor (#3109)
    • Enable select for NVFP4Tensor (#3117)
    • Make scale shape 2d and match qdata shape in NVFP4Tensor (#3108)
  • Float8
    • Fix Float8Tensor quantize op kernel preference dispatch (#2883)
    • Remove unused attributes in Float8Tensor (#2935)
    • Skip expanding scales for rowwise fp8 quantize (#2950)
    • Add missing Float8Tensor op support (unsqueeze, 3D slice) for 3D weights (#3035)
    • Float8Tensor per row quantization pass bias to fbgemm kernel (#2884)
    • [sparse] Add in missing op support for FP8 Sparse (#3014)
    • [Inductor][float8] Support qlinear for float8 in inductor (#2565)
  • (prototype) SpinQuant
    • SpinQuant rotate bias (#2913)
    • Added SpinQuant rotation unit test (#2925)
  • Remove unused cpp variable, breaking style checks (#2909)
  • [pt2e] Avoid getting model device once per node (#2695)
  • [pt2e] Make prepare and convert faster by caching (#2983)
  • [safetensors enablement] refactoring for huggingface integration (#2936)
  • Move packing format used by int4 to int4_packing_format.py (#2946)
  • Move packing format to intx folder (#2910)
  • [Intel GPU] Enable llama generate.py + add unit test for quantization (#2929)
  • [Intel GPU] Support PLAIN_INT32 for AWQ on Intel GPU (#3019)
  • Make SmoothQuant more General (#2728)
  • Support Int4OpaqueTensor for HQQ (#3028)
  • Add Sparsify overhead benchmark (#3021)
  • Disable the use of argument use_cache for lm_eval by default (#3073)
  • Adding support for gemma3 TransformerEvalWrapper (#3074)
  • Add awq support for Int4TilePackedTo4dTensor (#3071)
  • Add module fqn regex support for ModuleFqnToConfig (#3084)

Bug Fixes

  • Exclude libcudart.so.13 from auditwheel repair to fix CUDA 13.0 wheel build (#2892)
  • Torchao init: do not load .so files for known incompatible torch version (#2908)
  • Fix torchao version check on torch version (#2918)
  • Change missing ops printout back to debug (#2921)
  • Another fix for torch version (#2922)
  • Exclude libcuda.so from auditwheel repair (#2927)
  • Better check for mxfp8 cuda kernel presence (#2933)
  • Fix xnnpack export (#2941)
  • Add nvcc flags to explicitly build mxfp8 dim1 cast kernel for sm100a (#2979)
  • Fix torchao_convert, remove StretchedAffineQuantizedTensor (#3015)
  • Avoid normalization layers in HF's quantization_config (#3030)
  • Minor fix on TAO op to support lowering (#3031)
  • Support select.int for Float8Tensor (#3053)
  • Fix torchAO shape check on fp8 tensors (#3057)
  • Fix: avoid removing from tuple in _get_to_kwargs (#3018)
  • Prevent union-finding cycles for shared qspecs (#3011)
  • [moe training] update llama4 bench script to handle torchtitan new log format (#3107)
  • Generalize torch compatibility check (#3042)
  • Replace param_group quantizer instance with QuantOptimizer attribute (#3104)

Documentation

  • [Intel GPU][doc] Change x86 quantizer to xpu quantizer in doc (#2916)
  • Update README.md with link to version compatibility matrix (#2920)
  • Update README.md for mx_formats build from source (#2934)
  • Docs: fix link in quantization overview documentation (#2962)
  • HF integration doc page (#2899)

New Contributors

Full Changelog: v0.13.0...v0.14.0-rc1
