🐛 Bug
Hello, I am trying to train LLMs on a language-modelling task with differential privacy using Opacus. My code works with gpt2, but with bert-base-cased it throws RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
To Reproduce
The code I use is the following; the model is an AutoModelForLanguageModelling from the transformers library (a sketch of the model/collator setup is given after the training code below):
# Imports used by this method
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager


def train(self, model, lr, train_dataset, eval_dataset, num_epochs):
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=self.config.train_batch_size,
        collate_fn=self.data_collator,
    )

    model = model.to(self.device)
    # Set the model to train mode (HuggingFace models load in eval mode)
    model = model.train()

    # Define optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    DELTA = 1 / len(train_dataloader)
    privacy_engine = PrivacyEngine()
    model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dataloader,
        target_delta=DELTA,
        target_epsilon=7.5,
        epochs=num_epochs,
        max_grad_norm=0.1,
    )

    for epoch in range(1, num_epochs + 1):
        losses = []
        with BatchMemoryManager(
            data_loader=train_dataloader,
            max_physical_batch_size=4,
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            for step, batch in enumerate(tqdm(memory_safe_data_loader)):
                optimizer.zero_grad()
                inputs = {k: batch[k].to(self.device) for k in batch if k != "labels"}
                outputs = model(**inputs)  # output = loss, logits, hidden_states, attentions
                loss = outputs[0].mean()
                loss.backward()
                losses.append(loss.item())
                optimizer.step()

                if step > 0 and step % 5000 == 0:
                    train_loss = np.mean(losses)
                    eps = privacy_engine.get_epsilon(DELTA)
                    print(
                        f"Epoch: {epoch} | "
                        f"Step: {step} | "
                        f"Train loss: {train_loss:.3f} | "
                        f"ɛ: {eps:.2f}"
                    )
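The model and data collator themselves are created elsewhere in the project (the Finetuner class in the traceback below). For context, here is a minimal, simplified sketch of an equivalent setup; the class names (AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling) are the standard transformers ones for a masked-LM objective and are illustrative assumptions, not copied from the failing script:

# Illustrative sketch of the model / data-collator setup (assumed, not the
# exact project code). Standard transformers classes for a masked-LM objective.
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

model_name = "bert-base-cased"  # the failing model; gpt2 trains fine
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Collator that prepares masked inputs for the MLM objective
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)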
The full error:
Traceback (most recent call last):
File "/home/lmagnana/nlp-attacks/examples/special_finetunings/n2c2_ner_mlm_finetuning.py", line 72, in <module>
models, metrics = finetuner.run(dataset, test_size, epochs, pathlib.Path(output_dir), output_name=ouput_name)
File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/Finetuner.py", line 248, in run
model = self.train(model, self.config.learning_rate, ds["train"], ds["test"], epochs)
File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/PrivacyPreservingLanguageModelling.py", line 141, in train
loss.backward()
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 72, in __call__
return self.hook(module, *args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/grad_sample_module.py", line 340, in capture_backprops_hook
grad_samples = grad_sampler_fn(module, activations, backprops)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 108, in ft_compute_per_sample_gradient
per_sample_grads = layer.ft_compute_sample_grad(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
return _flat_vmap(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
return f(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 363, in wrapper
return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1285, in grad_impl
results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
return f(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1249, in grad_and_value_impl
output = func(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 85, in compute_loss_stateless_model
output = flayer(params, batched_activations)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 50, in fmodel
return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
return nn.utils.stateless._functional_call(
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
return module(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 787, in forward
hidden_states = self.decoder(hidden_states)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
result = fn(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 137, in _fn
result = fn(**bound.arguments)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_refs/__init__.py", line 1091, in add
output = prims.add(a, b)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_ops.py", line 594, in __call__
return self_._op(*args, **kwargs)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims/__init__.py", line 359, in _prim_elementwise_meta
utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/__init__.py", line 740, in check_same_device
raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!
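For what it is worth, the frame at modeling_bert.py line 787 (hidden_states = self.decoder(hidden_states)) is inside BERT's MLM prediction head, whose decoder Linear shares its weight with the word-embedding matrix. A quick check of that weight tying (an illustrative snippet and my own guess at where the problem might come from, not a confirmed diagnosis):

# Illustrative check: the decoder Linear in the MLM head shares its weight
# with the input embedding matrix in bert-base-cased (weight tying).
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("bert-base-cased")
decoder = model.cls.predictions.decoder                 # the Linear seen in the traceback
word_embeddings = model.bert.embeddings.word_embeddings
print(decoder.weight is word_embeddings.weight)         # prints True -> tied weights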
Expected behavior
The code should work with a bert-base-cased model just as it does with gpt2.
Environment
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.5.82
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.3.1
[pip3] opacus==1.5.3
[pip3] triton==2.3.1
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.5.82 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] torch 2.3.1 pypi_0 pypi
[conda] triton 2.3.1 pypi_0 pypi
Thanks in advance for your replies.