
RuntimeError: Tensor on device cuda:0 is not on the expected device meta! #737

@LucasMagnana

Description

🐛 Bug

Hello, I am trying to train LLMs on a language modelling task with differential privacy using Opacus. My code works with gpt2, but it throws RuntimeError: Tensor on device cuda:0 is not on the expected device meta! when using bert-base-cased.

To Reproduce

The code I use is the following; the model is an AutoModelForLanguageModelling from the transformers library:

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

def train(self, model, lr, train_dataset, eval_dataset, num_epochs):
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=self.config.train_batch_size,
        collate_fn=self.data_collator,
    )
    model = model.to(self.device)
    # Set the model to train mode (HuggingFace models load in eval mode)
    model = model.train()
    # Define optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Target delta set to the inverse of the number of batches
    DELTA = 1 / len(train_dataloader)

    privacy_engine = PrivacyEngine()

    model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_dataloader,
        target_delta=DELTA,
        target_epsilon=7.5,
        epochs=num_epochs,
        max_grad_norm=0.1,
    )

    for epoch in range(1, num_epochs + 1):
        losses = []

        with BatchMemoryManager(
            data_loader=train_dataloader,
            max_physical_batch_size=4,
            optimizer=optimizer,
        ) as memory_safe_data_loader:
            for step, batch in enumerate(tqdm(memory_safe_data_loader)):
                optimizer.zero_grad()

                inputs = {k: batch[k].to(self.device) for k in batch if k != "labels"}

                outputs = model(**inputs)  # outputs = loss, logits, hidden_states, attentions

                loss = outputs[0].mean()
                loss.backward()
                losses.append(loss.item())

                optimizer.step()

                if step > 0 and step % 5000 == 0:
                    train_loss = np.mean(losses)
                    eps = privacy_engine.get_epsilon(DELTA)

                    print(
                        f"Epoch: {epoch} | "
                        f"Step: {step} | "
                        f"Train loss: {train_loss:.3f} | "
                        f"ɛ: {eps:.2f}"
                    )
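
For context, here is roughly how the model and data collator are built on my side (a simplified sketch; the class names below are illustrative stand-ins for the actual loading code in Finetuner.py, and a masked-LM head is assumed for bert-base-cased):

from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling

model_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# MLM collator producing input_ids / attention_mask / labels batches,
# matching the batch[k] access pattern in train() above
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True)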

The full error:

Traceback (most recent call last):
  File "/home/lmagnana/nlp-attacks/examples/special_finetunings/n2c2_ner_mlm_finetuning.py", line 72, in <module>
    models, metrics = finetuner.run(dataset, test_size, epochs, pathlib.Path(output_dir), output_name=ouput_name)
  File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/Finetuner.py", line 248, in run
    model = self.train(model, self.config.learning_rate, ds["train"], ds["test"], epochs) 
  File "/home/lmagnana/nlp-attacks/nlp_attacks/finetuners/PrivacyPreservingLanguageModelling.py", line 141, in train
    loss.backward()
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 72, in __call__
    return self.hook(module, *args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/grad_sample_module.py", line 340, in capture_backprops_hook
    grad_samples = grad_sampler_fn(module, activations, backprops)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 108, in ft_compute_per_sample_gradient
    per_sample_grads = layer.ft_compute_sample_grad(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 281, in vmap_impl
    return _flat_vmap(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 403, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/apis.py", line 363, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1285, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 1249, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 85, in compute_loss_stateless_model
    output = flayer(params, batched_activations)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/opacus/grad_sample/functorch.py", line 50, in fmodel
    return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
    return module(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 787, in forward
    hidden_states = self.decoder(hidden_states)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/wrappers.py", line 137, in _fn
    result = fn(**bound.arguments)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_refs/__init__.py", line 1091, in add
    output = prims.add(a, b)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_ops.py", line 594, in __call__
    return self_._op(*args, **kwargs)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims/__init__.py", line 359, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/lmagnana/anaconda3/envs/3.9/lib/python3.9/site-packages/torch/_prims_common/__init__.py", line 740, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device cuda:0 is not on the expected device meta!

Expected behavior

The code should work with both gpt2 and bert-base-cased models.
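
As a diagnostic (not a confirmed fix), running Opacus's module validator on the BERT model before calling make_private_with_epsilon may surface layers that per-sample gradient computation cannot handle, e.g. the tied MLM decoder weights:

from opacus.validators import ModuleValidator

# List modules Opacus considers incompatible with per-sample gradients
errors = ModuleValidator.validate(model, strict=False)
print(errors)

# Replace known-incompatible modules (e.g. BatchNorm) with DP-friendly
# equivalents; whether this addresses the meta-device error is untested
model = ModuleValidator.fix(model)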

Environment

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.5.82
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] torch==2.3.1
[pip3] opacus==1.5.3
[pip3] triton==2.3.1
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.5.82                  pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] triton                    2.3.1                    pypi_0    pypi

Thanks in advance for your replies.
