Improve QAT fp8-int4 numerics #2937
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2937. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 7f0830c with merge base 4700fe8. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@dataclass
class Int4WeightFBGEMMFakeQuantizeConfig(FakeQuantizeConfigBase):
nit: maybe not mention FBGEMM, but Preshuffled instead, to align with the tensor name
ok, should we use the following naming scheme?
Int4PreshuffledTensor -> Int4WeightPreshuffledFakeQuantizeConfig
Int4Tensor -> Int4WeightFakeQuantizeConfig (future PR)
yeah, if the fake quant path uses the quant ops specific to Int4PreshuffledTensor (in from_hp), then I think we can just call it Int4WeightPreshuffledFakeQuantizeConfig
can you explain what exactly is different in the default path vs this path?
Force-pushed from e21e5bf to 274ed29
Sure, there are a few main differences. The new path mimics the fbgemm numerics by first quantizing the weights to fp8 and then computing the int4 scale from the max abs values of the fp8 weights, which is different from what torchao's quant primitives do.
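To make that ordering concrete, here is a minimal pure-PyTorch sketch of the idea (fp8 row quantization first, then int4 scales derived from the max abs of the fp8-rounded values). The function name, the e4m3 dtype choice, and the symmetric-int4 scale are illustrative assumptions, not the actual torchao/FBGEMM implementation:

```python
import torch

def fp8_then_int4_scales_sketch(w: torch.Tensor, group_size: int = 128):
    """Illustrative sketch only: quantize to fp8 first, then derive int4 scales."""
    # (1) row-wise fp8 quantize-dequantize: scale each row into the fp8 range,
    #     round by casting, then rescale back
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    row_amax = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    fp8_scale = row_amax / fp8_max
    w_fp8 = (w / fp8_scale).to(torch.float8_e4m3fn).to(w.dtype) * fp8_scale

    # (2) int4 group scales computed from the max abs of the fp8-rounded values,
    #     not from the original high-precision weight (assumes group_size divides K)
    wg = w_fp8.reshape(w.shape[0], -1, group_size)
    int4_scale = wg.abs().amax(dim=-1) / 8  # symmetric-int4 assumption for illustration
    return w_fp8, int4_scale
```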
Force-pushed from ecd3a49 to 74ed2db

"""
self._test_quantize_api_against_ptq(
    Int8DynamicActivationInt4WeightConfig(group_size=32),
    target_prepare_sqnr=30,
trying to understand why this is not inf, is this because fake quant does not do dtype conversion?
not sure, but this is enough to recover significant accuracy degradation in most cases. I did verify the qparams dtypes are also matching but haven't investigated further, may continue this later separately
looks like this test is not related to the changes, since it's Int8DynamicActivationInt4WeightConfig
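For reference, the prepare vs convert SQNR discussed in this thread is just a signal-to-quantization-noise ratio between the fake-quantized (prepare) model output and the truly quantized (convert) model output. A minimal sketch of how such a number can be computed (the helper name is made up; torchao's own utility may differ):

```python
import torch

def compute_sqnr(reference: torch.Tensor, candidate: torch.Tensor) -> float:
    """SQNR in dB: 20 * log10(||reference|| / ||reference - candidate||)."""
    noise = (reference - candidate).float().norm()
    if noise == 0:
        return float("inf")  # identical outputs give infinite SQNR
    return (20 * torch.log10(reference.float().norm() / noise)).item()

# usage: out_prepare from the QAT-prepared (fake-quantized) model and
# out_convert from the converted (quantized) model, on the same input
# sqnr = compute_sqnr(out_convert, out_prepare)
```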
(q2, _) = quantize_fp8_row(x2)
(q2, scale2) = int4_row_quantize(q2, group_size)
is this what happens in quantize_int4_preshuffle? why do we first quantize to fp8 then to int4 instead of just quantizing to int4?
yeah it happens here, not sure of the reason: https://github.com/pytorch/FBGEMM/blob/3ca2859adc0ae24b1214ccacedff24ea5fce9be5/fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py#L138
if cast_to_float8_dtype:
    tensor_clamped = tensor_clamped.to(float8_dtype)
    return tensor_clamped
return _RoundToFloat8.apply(tensor_clamped, float8_dtype)
so fake quant also does casting? I thought we don't do casting during fake quant
I think fp8 numerics is close enough to bf16/fp32 that we can do this during QAT. This helps mimic the convert numerics a bit more closely
fp8-int4 prepare vs convert SQNR drops from 22.375 to 19.25 without this
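For context, the `_RoundToFloat8.apply` pattern referenced above is a straight-through estimator. A generic sketch of that pattern, assuming a custom autograd function (this is an illustration, not the exact torchao code):

```python
import torch

class RoundToFloat8STE(torch.autograd.Function):
    """Round to a float8 dtype in the forward pass, pass gradients straight through."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
        # round to float8 precision, then return in the original dtype so the
        # rest of the bf16/fp32 QAT graph is unaffected
        return x.to(float8_dtype).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: ignore the rounding in the backward pass
        return grad_output, None
```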
(q2, _) = quantize_fp8_row(x2)
(q2, scale2) = int4_row_quantize(q2, group_size)

# (3) Reference implementation for QAT without the dequantize
for this, should we just initialize Int4WeightPreshuffledFakeQuantizer and call it here?
I tried doing that but it's tricky because we don't want the dequantize steps here (but they're needed in the fake quantizer)
LGTM!
def forward(self, w: torch.Tensor) -> torch.Tensor:
    return self._forward(w)[0]

def _forward(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
didn't see this function being used anywhere else except in L121, why do we need this?
oh sorry, we don't need this. This was leftover when I tried to write a test using this class
**Summary:** This commit improved the prepare vs convert SQNR of fp8-int4 QAT from 12 to 22. This is achieved by mimicking the numerics of the target FBGEMM fp8-int4 kernel more closely. In particular, FBGEMM first quantizes the weights to fp8, and then uses max abs values to compute the scale, which is significantly different from what torchao's quant primitives do.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_fbgemm_fp8_primitives
python test/quantization/test_qat.py -k test_fbgemm_int4_primitives
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
Force-pushed from 74ed2db to 7f0830c
**Summary:** Similar to #2937, this commit improves the prepare vs convert SQNR of int4 weight-only QAT from 12 to 45. This is achieved by mimicking the numerics of the target FBGEMM bf16-int4 kernel more closely. In particular, the FBGEMM kernel:
1. Performs asymmetric [0, 15] quant first, then recenters to 8
2. Uses a smaller scale eps of 1e-6 instead of bf16's eps (0.0078125)
3. Quantizes the weights using min val instead of zero points

**Unit tests:**
```
python test/quantization/test_qat.py -k test_quantize_api_int4
python test/quantization/test_qat.py -k test_fbgemm_int4_weight_only_primitives
```

**End-to-end tests:** Fine-tuning Llama3.1-8B with and without this PR in unsloth:
- fine-tune for 1 epoch on yahma/alpaca-cleaned with LoRA
- batch size 8, learning rate 2e-4, no gradient accumulation

Wikitext:
- the QAT int4 quantized model (with this PR) achieved 33% lower perplexity than the int4 baseline
- the QAT int4 quantized model without this PR was worse

```
==> unsloth_model_lora_baseline_output/lm_eval_float.log <==
| | |none | 0|word_perplexity|↓ |7.5551|± | N/A|

==> unsloth_model_lora_baseline_output/lm_eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |8.7655|± | N/A|

# QAT without this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |8.3548|± | N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
| | |none | 0|word_perplexity|↓ |10.0683|± | N/A|
```
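As a rough illustration of the three kernel behaviors listed above (asymmetric [0, 15] quantization recentered to 8, a 1e-6 scale eps, and min-val-based quantization), a hedged pure-PyTorch sketch could look like the following. It is an approximation for readability, with made-up function names, not the FBGEMM kernel itself:

```python
import torch

def int4_weight_only_reference_sketch(w: torch.Tensor, group_size: int = 128):
    """Approximate the described FBGEMM bf16-int4 scheme (illustrative only)."""
    # assumes group_size divides the last dimension of w
    wg = w.reshape(w.shape[0], -1, group_size).float()
    min_val = wg.amin(dim=-1, keepdim=True)
    max_val = wg.amax(dim=-1, keepdim=True)

    # (2) small scale eps of 1e-6 instead of bf16's eps (0.0078125)
    scale = ((max_val - min_val) / 15).clamp(min=1e-6)

    # (1) + (3) asymmetric [0, 15] quantization against min_val, then recenter to 8
    q = torch.round((wg - min_val) / scale).clamp(0, 15)
    q_centered = q - 8  # stored as signed int4 in [-8, 7]

    # dequantize for a fake-quant / reference comparison
    w_dq = (q_centered + 8) * scale + min_val
    return q_centered.to(torch.int8), scale, w_dq.reshape(w.shape).to(w.dtype)
```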
**Summary:** This commit improved the prepare vs convert SQNR of fp8-int4 QAT from 12 to 22. This is achieved by mimicking the numerics of the target FBGEMM fp8-int4 kernel more closely. In particular, FBGEMM first quantizes the weights to fp8, and then uses max abs values to compute the scale, which is significantly different from what torchao's quant primitives do.

**Unit tests:**
- python test/quantization/test_qat.py -k test_fbgemm_fp8_primitives
- python test/quantization/test_qat.py -k test_fbgemm_int4_primitives
- python test/quantization/test_qat.py -k test_quantize_api_fp8_int4

**End-to-end tests:**
Fine-tuning Llama3.1-8B with and without this PR in unsloth on the yahma/alpaca-cleaned dataset.

Wikitext:

Fibonacci test:
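For completeness, the prepare/convert flow these tests exercise generally looks like the sketch below. The exact API and config names (`QATConfig`, `Float8DynamicActivationInt4WeightConfig`, the `step` argument) are assumptions based on recent torchao versions and may not match this PR's branch exactly:

```python
import torch
from torchao.quantization import quantize_
# Assumed names; check your torchao version for the exact config and QAT API.
from torchao.quantization import Float8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import QATConfig

# small example model; fp8-int4 kernels generally require a CUDA device
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()
base_config = Float8DynamicActivationInt4WeightConfig()  # assumed fp8-int4 base config

# prepare: insert fake quantization so the model can be fine-tuned with QAT
quantize_(model, QATConfig(base_config, step="prepare"))
# ... fine-tune the model here ...

# convert: replace fake quantization with real quantized weights
quantize_(model, QATConfig(base_config, step="convert"))
```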