+
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions test/prototype/test_spinquant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest
import torch
from torchao._models.llama.model import Transformer
from torchao.prototype.spinquant import apply_spinquant


def _init_model(name="7B", device="cpu", precision=torch.bfloat16):
    """Create a Llama ``Transformer`` from a config name and return it in eval mode.

    Args:
        name: Configuration name forwarded to ``Transformer.from_name``.
        device: Device the model is moved to.
        precision: dtype the model is cast to.

    Returns:
        The result of calling ``.eval()`` on the constructed model.
    """
    transformer = Transformer.from_name(name)
    # Move/cast as a separate statement; the return value of ``.to`` is not used.
    transformer.to(dtype=precision, device=device)
    return transformer.eval()


# Devices the tests are parametrized over: CPU always, CUDA only when available.
_AVAILABLE_DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


@pytest.mark.parametrize("device", _AVAILABLE_DEVICES)
def test_spinquant_no_quantization(device):
    """Check that SpinQuant without quantization leaves the model output unchanged.

    The model is run once before and once after ``apply_spinquant`` on the same
    inputs, and the two outputs are compared within a tolerance.
    """
    model = _init_model(device=device)

    batch_size, seq_len = 1, 16
    is_training = False

    input_ids = torch.randint(0, 1024, (batch_size, seq_len)).to(device)
    if is_training:
        input_pos = None
    else:
        input_pos = torch.arange(seq_len).to(device)

    with torch.device(device):
        model.setup_caches(
            max_batch_size=batch_size,
            max_seq_length=seq_len,
            training=is_training,
        )

    with torch.no_grad():
        baseline = model(input_ids, input_pos)
        apply_spinquant(model)
        rotated = model(input_ids, input_pos)

    # Without quantization the outputs should match, because the rotations cancel out.
    # TODO: these atol/rtol values may be excessively large, but the check fails
    # for smaller values.
    torch.testing.assert_close(baseline, rotated, atol=5e-2, rtol=1e-2)


# TODO: test GPTQ compatibility?
5 changes: 4 additions & 1 deletion torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tokenizer import get_tokenizer
import time
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.prototype.spinquant import apply_spinquant

def run_evaluation(
checkpoint_path: Path,
Expand Down Expand Up @@ -69,6 +70,8 @@ def run_evaluation(
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

if quantization:
if "spinquant" in quantization:
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
Expand Down Expand Up @@ -229,7 +232,7 @@ def run_evaluation(
help=(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, "
"int4wo-<groupsize>-gptq, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, "
"uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
"autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>-<grad_acc_steps>-<c>, "
"float8wo, float8dq, float8saq"
),
Expand Down
1 change: 1 addition & 0 deletions torchao/prototype/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`quant_llm`](quant_llm) - FP16 x Floatx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
- [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).
- [`spinquant`](spinquant) - re-implementation of [SpinQuant](https://arxiv.org/abs/2405.16406)

#### Roadmap

Expand Down
11 changes: 11 additions & 0 deletions torchao/prototype/spinquant/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# SpinQuant

Re-implementation of SpinQuant based on the official code implementation (https://github.com/facebookresearch/SpinQuant).

## Usage

Using this implementation with CUDA requires installing the Fast Hadamard Transform CUDA package, which can be done as follows:

```shell
pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
```
1 change: 1 addition & 0 deletions torchao/prototype/spinquant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .spinquant import apply_spinquant
Loading
Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载