diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py
index c981ff1802..58f3e4ed75 100644
--- a/optimum/exporters/onnx/convert.py
+++ b/optimum/exporters/onnx/convert.py
@@ -16,6 +16,7 @@
 import copy
 import gc
+import importlib
 import multiprocessing as mp
 import os
 import traceback
@@ -536,7 +537,11 @@ def export_pytorch(
         # Check that inputs match, and order them properly
         dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes)
-        device = torch.device(device)
+        if device == "dml" and importlib.util.find_spec("torch_directml"):
+            torch_directml = importlib.import_module("torch_directml")
+            device = torch_directml.device()
+        else:
+            device = torch.device(device)
 
         def remap(value):
             if isinstance(value, torch.Tensor):
                 value = value.to(device)
@@ -544,7 +549,7 @@
             return value
 
-        if device.type == "cuda" and torch.cuda.is_available():
+        if device.type == "cuda" and torch.cuda.is_available() or device.type == "privateuseone":
             model.to(device)
             dummy_inputs = tree_map(remap, dummy_inputs)
 
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 27f712608e..616847ecf8 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -2198,12 +2198,16 @@ def get_model_from_task(
                 kwargs["torch_dtype"] = torch_dtype
 
             if isinstance(device, str):
-                device = torch.device(device)
+                if device == "dml" and importlib.util.find_spec("torch_directml"):
+                    torch_directml = importlib.import_module("torch_directml")
+                    device = torch_directml.device()
+                else:
+                    device = torch.device(device)
             elif device is None:
                 device = torch.device("cpu")
 
             # TODO : fix EulerDiscreteScheduler loading to enable for SD models
-            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers":
+            if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers" and device.type != "privateuseone":
                 with device:
                     # Initialize directly in the requested device, to save allocation time. Especially useful for large
                     # models to initialize on cuda device.
diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py
index 846ae278b5..4c7fb2c3b6 100644
--- a/optimum/utils/import_utils.py
+++ b/optimum/utils/import_utils.py
@@ -114,6 +114,8 @@ def _is_package_available(
         "onnxruntime-migraphx",
         "ort-migraphx-nightly",
         "ort-rocm-nightly",
+        # For DirectML
+        "onnxruntime-directml",
     ],
 )
 _tf_available, _tf_version = _is_package_available(
diff --git a/setup.py b/setup.py
index f446b9fdc1..b84c841b5b 100644
--- a/setup.py
+++ b/setup.py
@@ -69,6 +69,14 @@
         "transformers>=4.36,<4.53.0",
         "onnxruntime-training>=1.11.0",
     ],
+    "onnxruntime-directml": [
+        "torch-directml",
+        "onnx",
+        "datasets>=1.2.1",
+        "onnxruntime-directml>=1.11.0",
+        "protobuf>=3.20.1",
+        "transformers>=4.36,<4.53.0",
+    ],
     "exporters": [
         "onnx",
         "onnxruntime",
@@ -81,6 +89,13 @@
         "protobuf>=3.20.1",
         "transformers>=4.36,<4.53.0",
     ],
+    "exporters-directml": [
+        "torch-directml",
+        "onnx",
+        "onnxruntime-directml",
+        "protobuf>=3.20.1",
+        "transformers>=4.36,<4.53.0",
+    ],
     "exporters-tf": [
         "onnx",
         "h5py",
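
Usage sketch (not part of the diff): a minimal example of how the new `device="dml"` path could be exercised once the patch lands, assuming `torch-directml` and `onnxruntime-directml` are installed (e.g. via the new `optimum[exporters-directml]` extra). The model name, task, and output directory below are placeholders; `main_export` and the `DmlExecutionProvider` name are the existing optimum and onnxruntime-directml APIs.

```python
# Sketch: export with the patched DirectML device resolution.
# Before this change, device="dml" failed inside torch.device("dml");
# with it, "dml" is mapped to torch_directml.device(), whose torch
# device type is "privateuseone".
from optimum.exporters.onnx import main_export

main_export(
    "distilbert-base-uncased",        # placeholder model
    output="distilbert_onnx",         # placeholder output dir
    task="text-classification",
    device="dml",
)

# The exported model can then be validated on the DirectML execution
# provider that ships with onnxruntime-directml.
import onnxruntime as ort

session = ort.InferenceSession(
    "distilbert_onnx/model.onnx",
    providers=["DmlExecutionProvider"],
)
print(session.get_providers())
```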