Conversation

andrewor14 (Contributor) commented Sep 4, 2025

Summary: This commit improves the prepare vs convert SQNR of fp8-int4 QAT from 12 to 22. This is achieved by mimicking the numerics of the target FBGEMM fp8-int4 kernel more closely. In particular, FBGEMM first quantizes the weights to fp8 and then uses max abs values to compute the scale, which is significantly different from what torchao's quant primitives do.
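For illustration, here is a minimal sketch of the FBGEMM-style flow the new fake-quant path mimics (illustrative only, not the actual kernel or torchao code; the /8 divisor, the group handling, and the function name are assumptions):

```python
import torch

def fbgemm_style_fp8_int4_sketch(w: torch.Tensor, group_size: int = 128):
    # Illustrative sketch: rowwise fp8 quant using an abs-max scale, then
    # groupwise int4 quant of the fp8-rounded weights, also from abs-max values.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # (1) rowwise fp8 scale from the max abs value, with a tiny eps floor
    row_scale = (w.abs().amax(dim=1, keepdim=True) / fp8_max).clamp(min=1e-12)
    w_fp8 = (w / row_scale).to(torch.float8_e4m3fn).to(w.dtype) * row_scale
    # (2) groupwise int4 scale from the max abs value of the fp8-rounded weights
    wg = w_fp8.reshape(-1, group_size)
    group_scale = (wg.abs().amax(dim=1, keepdim=True) / 8).clamp(min=1e-6)
    q = torch.round(wg / group_scale).clamp(-8, 7)
    return q.reshape(w.shape), row_scale, group_scale
```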

Unit tests:

python test/quantization/test_qat.py -k test_fbgemm_fp8_primitives
python test/quantization/test_qat.py -k test_fbgemm_int4_primitives
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4

End-to-end tests:

Fine-tuning Llama3.1-8B with and without this PR in unsloth:

  • fine-tune for 1 epoch on yahma/alpaca-cleaned
  • batch size 16, learning rate 4e-5, no gradient accumulation

Wikitext:

  • QAT int4 quantized model (with this PR) achieved 19.2% lower perplexity than the int4 baseline
  • QAT int4 quantized model (with this PR) outperformed even the bf16 baseline
  • QAT int4 quantized model without this PR did not converge
# Baseline
==> unsloth_model_full_baseline_output/lm_eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.8396|±  |   N/A|

# Baseline (quantized)
==> unsloth_model_full_baseline_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |11.6904|±  |   N/A|

# QAT without this PR (quantized)
==> unsloth_model_full_qat_fp8-int4_output/eval_wikitext_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |1598419.1369|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_full_qat_fp8-int4_output/eval_wikitext_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |9.2336|±  |   N/A|

Fibonacci test:

  • QAT int4 quantized model (with this PR) produced the correct answer (13)
  • QAT int4 quantized model without this PR produced gibberish
### Instruction:
Continue the fibonnaci sequence.

### Input:
1, 1, 2, 3, 5, 8

# Baseline
==> unsloth_model_full_baseline_output/lm_eval_float.log <==
### Response:
The next number in the Fibonacci sequence is 13.<|end_of_text|>

# Baseline (quantized)
==> unsloth_model_full_baseline_output/lm_eval_quantized.log <==
### Response:
The next number in the Fibonacci sequence is 13.<|end_of_text|>

# QAT without this PR (quantized)
==> unsloth_model_full_qat_fp8-int4_output/eval_wikitext_quantized.log <==
### Response:
 The. The. I. A. The A. The. A. The. The. I. A. The. I. A. I. The. I. The. The. The. I. The. The. The. The. The. The. The. The. The. The. I. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The. The

# QAT with this PR (quantized)
==> unsloth_model_full_qat_fp8-int4_output/eval_wikitext_quantized.log <==
### Response:
13<|end_of_text|>

pytorch-bot commented Sep 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2937

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 7f0830c with merge base 4700fe8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Sep 4, 2025
@andrewor14 added the topic: improvement label on Sep 4, 2025


@dataclass
class Int4WeightFBGEMMFakeQuantizeConfig(FakeQuantizeConfigBase):
Contributor:

nit: maybe not mention FBGEMM, but Preshuffled instead, to align with the tensor name

andrewor14 (Author) commented Sep 4, 2025:

ok, should we use the following naming scheme?

Int4PreshuffledTensor -> Int4WeightPreshuffledFakeQuantizeConfig
Int4Tensor -> Int4WeightFakeQuantizeConfig (future PR)

Contributor:

yeah, if the fake quant path uses the quant ops specific to Int4PreshuffledTensor (in from_hp), then I think we can just call it Int4WeightPreshuffledFakeQuantizeConfig

@vkuzo (Contributor) commented Sep 4, 2025:

can you explain what exactly is different in the default path vs this path?

@andrewor14 force-pushed the fp8-int4-qat-numerics branch from e21e5bf to 274ed29 on September 4, 2025 18:30

andrewor14 (Author) commented:

> can you explain what exactly is different in the default path vs this path?

Sure, there are a few main differences. The new path mimics the FBGEMM numerics by (see the sketch after the list below):

  1. Quantizing the weights first to fp8 and then to int4
  2. Using max abs value to compute the scale (the old path uses max - min)
  3. Using eps=1e-12 for fp8 quant and 1e-6 for int4 quant (old path uses torch.finfo(dtype).eps)
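As a quick illustration of points 2 and 3, a toy comparison of the two scale recipes (a sketch only, not torchao code; the /15 and /8 divisors and the bf16 eps are assumptions inferred from the description above):

```python
import torch

w = torch.randn(1, 128, dtype=torch.bfloat16) * 0.01

# Old path: scale from (max - min), with torch.finfo(dtype).eps as the floor
old_scale = ((w.amax() - w.amin()) / 15).clamp(min=torch.finfo(torch.bfloat16).eps)

# New path (FBGEMM-like): scale from the max abs value, with a 1e-6 floor
new_scale = (w.abs().amax() / 8).clamp(min=1e-6)

# For small-magnitude weights the old path can hit the much larger eps floor,
# so the two recipes produce noticeably different scales.
print(old_scale.item(), new_scale.item())
```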

@andrewor14 force-pushed the fp8-int4-qat-numerics branch 7 times, most recently from ecd3a49 to 74ed2db, on September 4, 2025 22:23
"""
self._test_quantize_api_against_ptq(
Int8DynamicActivationInt4WeightConfig(group_size=32),
target_prepare_sqnr=30,
Copy link
Contributor

@jerryzh168 jerryzh168 Sep 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trying to understand why this is not inf, is this because fake quant does not do dtype conversion?

andrewor14 (Author):

Not sure, but this is enough to recover from significant accuracy degradation in most cases. I did verify that the qparams dtypes also match, but I haven't investigated further; I may continue this separately later.

Contributor:

looks like this test is not related to the changes, since it's Int8DynamicActivationInt4WeightConfig

Comment on lines +2120 to +2121
(q2, _) = quantize_fp8_row(x2)
(q2, scale2) = int4_row_quantize(q2, group_size)
Contributor:

Is this what happens in quantize_int4_preshuffle? Why do we first quantize to fp8 and then to int4 instead of just quantizing to int4?

if cast_to_float8_dtype:
    tensor_clamped = tensor_clamped.to(float8_dtype)
    return tensor_clamped
return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
Contributor:

So fake quant also does casting? I thought we don't do casting during fake quant.

andrewor14 (Author):

I think fp8 numerics are close enough to bf16/fp32 that we can do this during QAT. This helps mimic the convert numerics a bit more closely.

andrewor14 (Author):

The fp8-int4 prepare vs convert SQNR drops from 22.375 to 19.25 without this.
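For context, a straight-through round-to-fp8 op along these lines is typically written like the sketch below (hypothetical class name, not the PR's actual _RoundToFloat8 code): the forward pass rounds through the fp8 dtype but returns a high-precision tensor, and the backward pass lets gradients flow through unchanged, whereas the cast_to_float8_dtype branch above returns a real fp8 tensor.

```python
import torch

class _RoundToFloat8Sketch(torch.autograd.Function):
    """Straight-through estimator: round through fp8 in forward, identity in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
        # Round to the fp8 grid but keep the original high-precision dtype
        return x.to(float8_dtype).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Pass gradients straight through; no gradient for the dtype argument
        return grad_output, None
```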

(q2, _) = quantize_fp8_row(x2)
(q2, scale2) = int4_row_quantize(q2, group_size)

# (3) Reference implementation for QAT without the dequantize
Contributor:

For this, should we just initialize Int4WeightPreshuffledFakeQuantizer and call it here?

andrewor14 (Author):

I tried doing that, but it's tricky because we don't want the dequantize steps here (though they're needed in the fake quantizer).

jerryzh168 (Contributor) left a comment:

LGTM!

def forward(self, w: torch.Tensor) -> torch.Tensor:
    return self._forward(w)[0]

def _forward(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Contributor:

I didn't see this function being used anywhere else except on L121; why do we need this?

andrewor14 (Author):

Oh sorry, we don't need this. It was left over from when I tried to write a test using this class.

@andrewor14 force-pushed the fp8-int4-qat-numerics branch from 74ed2db to 7f0830c on September 6, 2025 00:13
@andrewor14 merged commit a2206e9 into main on Sep 8, 2025 (18 checks passed)
andrewor14 added a commit that referenced this pull request Sep 15, 2025
**Summary:** Similar to #2937, this commit improves the prepare
vs convert SQNR of int4 weight-only QAT from 12 to 45. This is
achieved by mimicking the numerics of the target FBGEMM bf16-int4
kernel more closely. In particular, the FBGEMM kernel:

1. Performs asymmetric [0, 15] quant first then recenters to 8
2. Uses smaller scale eps of 1e-6 instead of bf16's eps (0.0078125)
3. Quantizes the weights using min val instead of zero points

**Unit tests:**

```
python test/quantization/test_qat.py -k test_quantize_api_int4
python test/quantization/test_qat.py -k test_fbgemm_int4_weight_only_primitives
```

**End-to-end tests:**

Fine-tuning Llama3.1-8B with and without this PR in unsloth:

- fine-tune for 1 epoch on yahma/alpaca-cleaned with LoRA
- batch size 8, learning rate 2e-4, no gradient accumulation

Wikitext:

- QAT int4 quantized model (with this PR) achieved 33% lower
  perplexity than the int4 baseline
- QAT int4 quantized model without this PR was worse

```
==> unsloth_model_lora_baseline_output/lm_eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |7.5551|±  |   N/A|

==> unsloth_model_lora_baseline_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |8.7655|±  |   N/A|

# QAT without this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |8.3548|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.0683|±  |   N/A|
```
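For illustration, the [0, 15] quant-then-recenter scheme described in the commit message above could look roughly like this sketch (hypothetical helper, not the FBGEMM kernel; the group size and reshaping are assumptions):

```python
import torch

def int4_recenter_quant_sketch(w: torch.Tensor, group_size: int = 128, eps: float = 1e-6):
    wg = w.reshape(-1, group_size).float()
    min_val = wg.amin(dim=1, keepdim=True)
    max_val = wg.amax(dim=1, keepdim=True)
    # (1) asymmetric [0, 15] quantization, (2) with a small 1e-6 floor on the scale
    scale = ((max_val - min_val) / 15.0).clamp(min=eps)
    q = torch.round((wg - min_val) / scale).clamp(0, 15)
    # (1, cont.) recenter to a signed [-8, 7] representation
    q = (q - 8).to(torch.int8)
    # (3) keep min_val (rather than an integer zero point); dequant is
    #     w_hat = (q + 8) * scale + min_val
    return q.reshape(w.shape), scale, min_val
```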