FEAT: Support torchao #2062
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
With huggingface/transformers#33361 being merged (which marks torchao as trainable), once the next transformers version is released (> 4.44.2), the GPU tests on this PR should pass (I tested locally). This PR should not be merged before that.
SunMarc left a comment
Thanks for making torchao compatible @BenjaminBossan ! LGTM ! Just a few nits.
cc @msaroufim
src/peft/tuners/lora/torchao.py (outdated)

# TODO
rep = super().__repr__()
return rep.replace("lora.Linear", f"lora.{self.__class__.__name__}")
TODO left
raise ValueError(f"{type(self).__name__} only supports int8 weights for now.")

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
    from torchao import quantize_
quantize_ is only available from torchao 0.4.0. Maybe we should modify is_torchao_available a bit to take that into account?
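For illustration, a version-aware check could look roughly like the sketch below (this is an assumption about the eventual fix, not necessarily the code that was merged):

from __future__ import annotations

import importlib.metadata
import importlib.util

from packaging import version


def is_torchao_available() -> bool:
    # quantize_ only exists from torchao 0.4.0 onward, so gate on that version.
    if importlib.util.find_spec("torchao") is None:
        return False
    torchao_version = version.parse(importlib.metadata.version("torchao"))
    return torchao_version >= version.parse("0.4.0")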
- min torchao version
- remove TODO
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Supports torch AO quantization. Currently supported:
- int8_weight_only
- int8_dynamic_activation_int8_weight

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
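For context, a rough usage sketch of what this enables (the model name and LoRA hyperparameters are illustrative, not taken from the PR): quantize a transformers model with torchao, then attach a LoRA adapter as usual.

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, TorchAoConfig

# Quantize the base weights with torchao's int8_weight_only scheme via transformers.
quant_config = TorchAoConfig("int8_weight_only")
base_model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Attach a LoRA adapter; training then proceeds as with any other PEFT model.
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16)
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()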
Hi @BenjaminBossan, is there any plan to support NF4?
AFAICT, torchao NF4 is not supported in transformers (which may change in the future). Therefore, I don't have plans to support it in PEFT. However, if you already have a working implementation, feel free to create a (draft) PR with your layers and I can take a look.
I think maybe it is "not 100% self-contained". Back to the topic: I implemented a custom layer for NF4:

from __future__ import annotations

import warnings
from typing import Optional, cast

import torch
from torch import nn
from torchao.core.config import AOBaseConfig
from torchao.dtypes import NF4Tensor, to_nf4

# PEFT-internal helpers (import paths assumed)
from peft.import_utils import is_torchao_available
from peft.tuners.lora.layer import Linear
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge

# NF4Config is the custom AOBaseConfig subclass defined in my registration snippet further down the thread.


class TorchAOLoraNF4Linear(Linear):
    def __init__(self, target: nn.Module, adapter_name: str, nf4_config: NF4Config, **kwargs):
        super().__init__(target, adapter_name, **kwargs)
        self.config = nf4_config

    def _get_base_layer_and_weight_with_checking(self) -> tuple[nn.Linear, NF4Tensor]:
        base_layer = self.get_base_layer()
        assert isinstance(base_layer, nn.Linear)
        nf4_weight = base_layer.weight
        assert isinstance(nf4_weight, NF4Tensor)
        return base_layer, nf4_weight

    def _accumulate_adapter_weights(
        self,
        base_weight: torch.Tensor,
        adapter_names: list[str],
        *,
        merge: bool = True,
        safe_merge: bool = False,
    ) -> torch.Tensor:
        for active_adapter in adapter_names:
            if merge:
                base_weight += self.get_delta_weight(active_adapter)
                if safe_merge and not torch.isfinite(base_weight).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )
            else:
                # unmerge
                if active_adapter not in self.lora_A.keys():
                    continue
                base_weight -= self.get_delta_weight(active_adapter)
        return base_weight

    def _make_nf4_weight_param(self, weight_tensor: torch.Tensor) -> nn.Parameter:
        nf4_tensor = to_nf4(weight_tensor, self.config.block_size, self.config.scaler_block_size)
        return nn.Parameter(nf4_tensor)

    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        if not (adapter_names := check_adapters_to_merge(self, adapter_names)):
            return
        # I actually wonder why TorchaoLoraLinear does this inside a loop -- why not update everything
        # and merge once for all adapters?
        base_layer, nf4_weight = self._get_base_layer_and_weight_with_checking()
        weight = self._accumulate_adapter_weights(
            nf4_weight.get_original_weight(), adapter_names, merge=True, safe_merge=safe_merge
        )
        base_layer.weight = self._make_nf4_weight_param(weight)
        self.merged_adapters.extend(adapter_names)

    def unmerge(self) -> None:
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        base_layer, nf4_weight = self._get_base_layer_and_weight_with_checking()
        weight = self._accumulate_adapter_weights(nf4_weight.get_original_weight(), self.merged_adapters, merge=False)
        base_layer.weight = self._make_nf4_weight_param(weight)
        self.merged_adapters.clear()


def dispatch_torchao_linear(
    target: nn.Module | BaseTunerLayer, adapter_name: str, aobase_config: AOBaseConfig | None = None, **kwargs
):
    if isinstance(target, BaseTunerLayer):
        target = target.get_base_layer()
    assert isinstance(target, nn.Module)
    # torchao only supports Linear for now, AFAIK.
    # If quantized weights ever support conv modules, let users define a dispatcher.
    if not isinstance(target, nn.Linear):
        return None
    if not is_torchao_available():
        return None
    if isinstance(target.weight, NF4Tensor):
        if not isinstance(aobase_config, NF4Config):
            raise ValueError("The weight is an NF4Tensor, so an NF4Config is required.")
        nf4config = cast(NF4Config, aobase_config)
        return TorchAOLoraNF4Linear(target, adapter_name, nf4config, **kwargs)
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization import LinearActivationQuantizedTensor

    # TorchAOLoraAQLinear is my layer for affine-quantized weights (definition not shown here).
    if isinstance(target.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor)):
        return TorchAOLoraAQLinear(target, adapter_name, aobase_config=aobase_config, **kwargs)
    return None

As you can see, I want to do something like checking the module (and its weight type) before adding the layer. Of course I could write a custom Model and Config, but I feel the register process is too complex and error-prone.
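For context, the module-based mapping discussed in this thread (PEFT's experimental custom-module dispatch for LoRA) looks roughly like the sketch below, as far as I understand it. The exact API (`_register_custom_module`) and the class names here are assumptions on my part; the relevant limitation is that dispatch is keyed on the module class, not on the weight type.

from torch import nn

from peft import LoraConfig
from peft.tuners import lora


# Hypothetical pairing: a custom Linear subclass mapped to a custom LoRA layer class.
class MyQuantLinear(nn.Linear):
    pass


class MyLoraQuantLinear(lora.Linear):
    pass


config = LoraConfig(target_modules=["linear"])
# Experimental: dispatch is keyed on the module type, so a weight-type check
# (e.g. isinstance(module.weight, NF4Tensor)) cannot be expressed this way.
config._register_custom_module({MyQuantLinear: MyLoraQuantLinear})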
I looked this up but didn't see any config for NF4. Could you please paste a snippet that illustrates how to load a transformers model with torchao NF4?
I'm not quite sure what you mean by "not 100% self-contained", but yes, we would need to go with a custom layer.
We can think about that. When I worked on this, I wanted to keep it simple, as I was not sure if anyone would use it at all.
Once we have a small example, we can do some testing; it's tough to say in the abstract. From there, we can proceed with a draft PR with your implementation and I can guide you through the missing steps.
LMK what exactly you're missing there.
Here is the snippet:

import types
from dataclasses import dataclass

import torch
from torch import nn
from torchao.core.config import AOBaseConfig
from torchao.dtypes import NF4Tensor, to_nf4
from torchao.quantization import register_quantize_module_handler, Float8WeightOnlyConfig, ModuleFqnToConfig
from torchao.utils import get_model_size_in_bytes
from transformers import Qwen2ForCausalLM, TorchAoConfig


@dataclass
class NF4Config(AOBaseConfig):
    block_size: int = 64
    scaler_block_size: int = 256


def linear_module_repr(module: nn.Linear):
    return f"in_features={module.weight.shape[1]}, out_features={module.weight.shape[0]}, weight={module.weight}, dtype={module.weight.dtype}"


@register_quantize_module_handler(NF4Config)
def _nf4_weight_only_transform(
    module: torch.nn.Module,
    config: NF4Config,
) -> torch.nn.Module:
    new_weight = to_nf4(module.weight, config.block_size, config.scaler_block_size)
    module.weight = nn.Parameter(new_weight, requires_grad=False)  # Freeze
    module.extra_repr = types.MethodType(linear_module_repr, module)
    return module


config = TorchAoConfig(NF4Config())
model = Qwen2ForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
)
quantized_model = Qwen2ForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    quantization_config=config,
)
print(get_model_size_in_bytes(model))  # 2520669824
print(get_model_size_in_bytes(quantized_model))  # 1273966688
No, those are your words, I just copied them :v The main problem with torchao is that it doesn't use the "swap module" approach, but quantizes the weights directly instead, and it also allows arbitrary custom quantized tensors. So I think it would be good to not just check the module, like the current custom map does, but to have a custom dispatch function that checks the weight. Or just have some way to integrate that.
Yeah, I used to, but after reviewing all the source code carefully, I think I have a way to use it confidently. Anyway, my code now just works (QLoRA with NF4).
I see, thanks. It's not really straightforward to use; I hope this will be simplified in the future.
At a first glance, this doesn't look bad, I think we could work based on this implementation.
Haha, okay, but my quote is from a different context, namely adding completely new PEFT methods. Here, the job is much easier, just adding a new layer for an existing PEFT method.
Yes, for that we need to make changes directly in PEFT, the dynamic dispatch can't handle that.
No worries, but if you have some time in the future, I'd be happy to see a PR. Don't worry about making it perfect on the first try, we can iterate on it.
Add support for torchao.
The current status is:
- int8_weight_only works fully
- int8_dynamic_activation_int8_weight only works partly (as dequantize is not supported, merging and DoRA won't work; see the merge sketch below)
- int4_weight_only not supported, as some ops needed for the forward call are missing
- nf4 not supported on the transformers side
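To illustrate the first two bullets, a minimal continuation of a typical PEFT setup (peft_model is assumed to wrap a torchao-quantized base model, as in the usage sketch earlier on this page):

# Merging dequantizes the base weight, adds the LoRA delta, and re-quantizes it.
# This is expected to work for int8_weight_only.
merged_model = peft_model.merge_and_unload()

# For int8_dynamic_activation_int8_weight, the same call cannot work, since
# dequantize is not supported for that scheme (see the status list above).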