From 8e5533452c3c183bd76d6f3848fd67d473742318 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 18 Feb 2024 21:07:05 +0900 Subject: [PATCH 01/10] Enable DirectML acceleration and improve device handling This commit introduces two improvements: 1. DirectML acceleration: - Added support for running optimum commands on DirectML hardware (Windows only) using the --device dml flag. - Automatically sets the device to torch_directml.device() when the flag is specified. 2. Improved device handling: - Ensures the model is directly initialized in the device only when applicable. --- optimum/commands/optimum_cli.py | 4 ++++ optimum/exporters/tasks.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/optimum/commands/optimum_cli.py b/optimum/commands/optimum_cli.py index 4bae9bb5f8..6bbfa9cee5 100644 --- a/optimum/commands/optimum_cli.py +++ b/optimum/commands/optimum_cli.py @@ -158,6 +158,10 @@ def main(): parser.print_help() exit(1) + if args.device == "dml": + import torch_directml + args.device = torch_directml.device() + # Run service = args.func(args) service.run() diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index e8e8af2bce..ae1b3fb20b 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1860,7 +1860,7 @@ def get_model_from_task( device = torch.device("cpu") # TODO : fix EulerDiscreteScheduler loading to enable for SD models - if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers": + if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers" and not device.type: with device: # Initialize directly in the requested device, to save allocation time. Especially useful for large # models to initialize on cuda device. 
From 26d2ed9a5aeee700b72f99c2c507fa48f0886971 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 18 Feb 2024 22:47:45 +0900 Subject: [PATCH 02/10] Refine device handling for PyTorch 2.0+ and device type check This commit refines the device handling in optimum/exporters/tasks.py for the following improvements: - More precise device check: Instead of checking for not device.type, the condition is updated to device.type != "privateuseone". This ensures the initialization happens on the requested device only if it's not a private use device (e.g., DirectML). - Restored fast path: the earlier not device.type check is false for every device (the type is always a non-empty string), so direct on-device initialization was skipped entirely; with the new check it runs again for all devices other than "privateuseone". The code comments themselves are unchanged. --- optimum/exporters/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index ae1b3fb20b..85932eda50 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1860,7 +1860,7 @@ def get_model_from_task( device = torch.device("cpu") # TODO : fix EulerDiscreteScheduler loading to enable for SD models - if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers" and not device.type: + if version.parse(torch.__version__) >= version.parse("2.0") and library_name != "diffusers" and device.type != "privateuseone": with device: # Initialize directly in the requested device, to save allocation time. Especially useful for large # models to initialize on cuda device. From 4502bc4110ab993e354a5808e3434b32f2161b48 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 25 Feb 2024 10:17:28 +0900 Subject: [PATCH 03/10] Support privateuseone device for PyTorch model export - Extends device compatibility to "privateuseone" in export_pytorch for exporting models usable on specific hardware. This commit allows exporting PyTorch models compatible with the "privateuseone" device, potentially enabling inference on specialized hardware platforms. 
--- optimum/exporters/onnx/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 21fc927942..75e7b3d527 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -554,7 +554,7 @@ def remap(value): return value - if device.type == "cuda" and torch.cuda.is_available(): + if device.type == "cuda" and torch.cuda.is_available() or device.type == "privateuseone": model.to(device) dummy_inputs = tree_map(remap, dummy_inputs) From 99c92b62c82f27caf61804cb827896ffe983bc69 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 25 Feb 2024 15:31:31 +0900 Subject: [PATCH 04/10] Enable DML device support for PyTorch models in Optimum This commit adds support for running PyTorch models on the DML device within the Optimum framework. - Dynamic DML device handling: Introduces dynamic import of torch_directml for improved maintainability. - Consistent device selection: Ensures consistent device selection across optimum/exporters/onnx/convert.py, optimum/exporters/tasks.py, and optimum/onnxruntime/io_binding/io_binding_helper.py. This change allows users to leverage DML capabilities for efficient PyTorch model inference with Optimum. 
--- optimum/exporters/onnx/convert.py | 7 ++++++- optimum/exporters/tasks.py | 6 +++++- optimum/onnxruntime/io_binding/io_binding_helper.py | 9 +++++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 75e7b3d527..5801955a75 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -16,6 +16,7 @@ import copy import gc +import importlib import multiprocessing as mp import os import traceback @@ -546,7 +547,11 @@ def export_pytorch( # Check that inputs match, and order them properly dummy_inputs = config.generate_dummy_inputs(framework="pt", **input_shapes) - device = torch.device(device) + if device == "dml" and importlib.util.find_spec("torch_directml"): + torch_directml = importlib.import_module("torch_directml") + device = torch_directml.device() + else: + device = torch.device(device) def remap(value): if isinstance(value, torch.Tensor): diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 85932eda50..cd346ef9e4 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -1855,7 +1855,11 @@ def get_model_from_task( kwargs["torch_dtype"] = torch_dtype if isinstance(device, str): - device = torch.device(device) + if device == "dml" and importlib.util.find_spec("torch_directml"): + torch_directml = importlib.import_module("torch_directml") + device = torch_directml.device() + else: + device = torch.device(device) elif device is None: device = torch.device("cpu") diff --git a/optimum/onnxruntime/io_binding/io_binding_helper.py b/optimum/onnxruntime/io_binding/io_binding_helper.py index 31da537918..6f226e9a32 100644 --- a/optimum/onnxruntime/io_binding/io_binding_helper.py +++ b/optimum/onnxruntime/io_binding/io_binding_helper.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import importlib import logging import traceback from typing import TYPE_CHECKING @@ -145,8 +146,12 @@ def to_pytorch_via_dlpack(ort_value: OrtValue) -> torch.Tensor: @staticmethod def get_device_index(device): if isinstance(device, str): - # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - device = torch.device(device) + if device == "dml" and importlib.util.find_spec("torch_directml"): + torch_directml = importlib.import_module("torch_directml") + device = torch_directml.device() + else: + # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 + device = torch.device(device) elif isinstance(device, int): return device return 0 if device.index is None else device.index From 08253df79e430908d6b1e9da01fe401fd81e6e86 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 25 Feb 2024 15:34:41 +0900 Subject: [PATCH 05/10] Remove redundant DML device handling in optimum_cli.py This commit removes unnecessary code for handling the DML device in optimum/commands/optimum_cli.py. - Redundant import: The code previously imported torch_directml conditionally, which is no longer needed as DML device support is handled in other parts of the codebase. This change simplifies the code and avoids potential conflicts. 
--- optimum/commands/optimum_cli.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/optimum/commands/optimum_cli.py b/optimum/commands/optimum_cli.py index 6bbfa9cee5..4bae9bb5f8 100644 --- a/optimum/commands/optimum_cli.py +++ b/optimum/commands/optimum_cli.py @@ -158,10 +158,6 @@ def main(): parser.print_help() exit(1) - if args.device == "dml": - import torch_directml - args.device = torch_directml.device() - # Run service = args.func(args) service.run() From 107879ef4655bc9faa4694ff5895e0d7c51d6846 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Wed, 28 Feb 2024 22:39:05 +0900 Subject: [PATCH 06/10] Add DML-specific dependencies to `setup.py` This commit updates `setup.py` to include the following changes: - Introduces a new conditional section "exporters-directml" with dependencies required for exporting models for DML inference. - This section mirrors the existing "exporters" and "exporters-gpu" sections, adding `onnxruntime-directml` as a dependency. This update ensures users have the necessary libraries for working with DML devices when installing Optimum with DML support. --- setup.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/setup.py b/setup.py index 25a09de985..d403d62e51 100644 --- a/setup.py +++ b/setup.py @@ -61,8 +61,17 @@ "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. ], + "onnxruntime-directml": [ + "onnx", + "onnxruntime-directml>=1.11.0", + "datasets>=1.2.1", + "evaluate", + "protobuf>=3.20.1", + "accelerate", # ORTTrainer requires it. 
+ ], "exporters": ["onnx", "onnxruntime", "timm"], "exporters-gpu": ["onnx", "onnxruntime-gpu", "timm"], + "exporters-directml": ["onnx", "onnxruntime-directml", "timm"], "exporters-tf": [ "tensorflow>=2.4,<=2.12.1", "tf2onnx", From ae49e738e584be7854e03e4a14b16c9ee453f8a7 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 19 Jan 2025 15:44:56 +0900 Subject: [PATCH 07/10] Enhance package availability check to support multiple distribution names --- optimum/utils/import_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 6c25c72475..1700868efb 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -37,15 +37,21 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} -def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: +def _is_package_available(pkg_name: str, return_version: bool = False, dist_names: list[str] | None = None) -> Union[Tuple[bool, str], bool]: # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version package_exists = importlib.util.find_spec(pkg_name) is not None package_version = "N/A" if package_exists: - try: - package_version = importlib.metadata.version(pkg_name) - package_exists = True - except importlib.metadata.PackageNotFoundError: + if dist_names is None: + dist_names = [pkg_name] + for dist_name in dist_names: + try: + package_version = importlib.metadata.version(dist_name) + package_exists = True + break + except importlib.metadata.PackageNotFoundError: + pass + else: package_exists = False if return_version: return package_exists, package_version @@ -66,7 +72,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _torch_available, _torch_version = _is_package_available("torch", return_version=True) # 
importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.) -_onnxruntime_available = _is_package_available("onnxruntime", return_version=False) +_onnxruntime_available = _is_package_available("onnxruntime", return_version=False, dist_names=["onnxruntime", "onnxruntime-gpu", "onnxruntime-directml"]) # TODO : Remove torch_version = version.parse(importlib.metadata.version("torch")) if _torch_available else None From e84dadbed0bf1fe2e71ecf1b4f446b3ba386a922 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 19 Jan 2025 16:47:30 +0900 Subject: [PATCH 08/10] Add torch-directml to exporters-directml requirements --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 1a72501dea..c7b4a22f82 100644 --- a/setup.py +++ b/setup.py @@ -83,6 +83,7 @@ "transformers>=4.36,<4.48.0", ], "exporters-directml": [ + "torch-directml", "onnx", "onnxruntime-directml", "timm", From 66ed0ffa406c18831792a5f019090f54a18254c8 Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 2 Feb 2025 17:43:23 +0900 Subject: [PATCH 09/10] Add support for onnxruntime-directml in import utilities --- optimum/utils/import_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 8da1df5fac..c3c0cc6aae 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -113,6 +113,8 @@ def _is_package_available( "onnxruntime-migraphx", "ort-migraphx-nightly", "ort-rocm-nightly", + # For DirectML + "onnxruntime-directml", ], ) _tf_available, _tf_version = _is_package_available( From 5363014123e946cb2e827dcc98b9101c8ac39fdd Mon Sep 17 00:00:00 2001 From: Kaz Nishimura Date: Sun, 29 Jun 2025 16:40:49 +0900 Subject: [PATCH 10/10] Update setup.py dependencies for onnxruntime-directml and transformers --- setup.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index a6b98f262b..b84c841b5b 100644 
--- a/setup.py +++ b/setup.py @@ -70,13 +70,12 @@ "onnxruntime-training>=1.11.0", ], "onnxruntime-directml": [ + "torch-directml", "onnx", - "onnxruntime-directml>=1.11.0", "datasets>=1.2.1", - "evaluate", + "onnxruntime-directml>=1.11.0", "protobuf>=3.20.1", - "accelerate", # ORTTrainer requires it. - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.53.0", ], "exporters": [ "onnx", @@ -94,8 +93,8 @@ "torch-directml", "onnx", "onnxruntime-directml", - "timm", - "transformers>=4.36,<4.48.0", + "protobuf>=3.20.1", + "transformers>=4.36,<4.53.0", ], "exporters-tf": [ "onnx",