Store the PEFT version in adapter configs.

PeftConfigMixin gets a `peft_version` field that __post_init__ fills in when it is None. For dev versions, the
commit hash recorded in the installed distribution's direct_url.json is appended ("<version>@<hash>"), falling back
to "<version>@UNKNOWN" (with a warning if the lookup raised). Tests cover storing/reloading the version and both
dev-version fallbacks.

NOTE(review): config files written before this change presumably have no `peft_version` key, so loading them would
auto-fill the currently installed version rather than the one that created the adapter -- confirm this is intended.

diff --git a/src/peft/config.py b/src/peft/config.py
index 094ee70940..60a5c20c74 100644
--- a/src/peft/config.py
+++ b/src/peft/config.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+import importlib.metadata
 import inspect
 import json
 import os
@@ -18,9 +21,12 @@
 from dataclasses import asdict, dataclass, field
 from typing import Optional, Union
 
+import packaging.version
 from huggingface_hub import hf_hub_download
 from transformers.utils import PushToHubMixin, http_user_agent
 
+from peft import __version__
+
 from .utils import CONFIG_NAME, PeftType, TaskType
 
 
@@ -43,6 +49,30 @@ def _check_and_remove_unused_kwargs(cls, kwargs):
     return kwargs, unexpected_kwargs
 
 
+def _is_dev_version(version: str) -> bool:
+    # check if the given version is a PEP 440 dev release, e.g. "0.15.0.dev7"
+    return packaging.version.Version(version).dev is not None
+
+
+def _get_commit_hash(pkg_name: str) -> str | None:
+    # If PEFT was installed from a specific commit hash, try to get it. This works e.g. when installing PEFT with `pip
+    # install git+https://github.com/huggingface/peft.git@<HASH>`. It does not work for other install methods, like
+    # editable installs. Returns None whenever the commit hash cannot be determined this way.
+    try:
+        dist = importlib.metadata.distribution(pkg_name)
+    except importlib.metadata.PackageNotFoundError:
+        return None
+
+    # See: https://packaging.python.org/en/latest/specifications/direct-url/
+    for path in dist.files or []:
+        if path.name == "direct_url.json":
+            direct_url = json.loads((dist.locate_file(path)).read_text())
+            vcs_info = direct_url.get("vcs_info")
+            if vcs_info and "commit_id" in vcs_info:
+                return vcs_info["commit_id"]
+    return None
+
+
 @dataclass
 class PeftConfigMixin(PushToHubMixin):
     r"""
@@ -60,6 +90,7 @@ class PeftConfigMixin(PushToHubMixin):
     auto_mapping: Optional[dict] = field(
         default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
     )
+    peft_version: Optional[str] = field(default=None, metadata={"help": "PEFT version, leave empty to auto-fill."})
 
     def __post_init__(self):
         # check for invalid task type
@@ -67,6 +98,30 @@ def __post_init__(self):
             raise ValueError(
                 f"Invalid task type: '{self.task_type}'. Must be one of the following task types: {', '.join(TaskType)}."
             )
+        if self.peft_version is None:
+            self.peft_version = self._get_peft_version()
+
+    @staticmethod
+    def _get_peft_version() -> str:
+        # Return the installed PEFT version. Dev versions are ambiguous on their own, so for them also try to append
+        # the commit hash as "<version>@<hash>", falling back to "<version>@UNKNOWN" if it cannot be determined.
+        version = __version__
+        if not _is_dev_version(version):
+            return version
+
+        try:
+            git_hash = _get_commit_hash("peft")
+            if git_hash is None:
+                git_hash = "UNKNOWN"
+        except Exception:
+            # Broad exception: We never want to break user code just because the git_hash could not be determined
+            warnings.warn(
+                "A dev version of PEFT is used but there was an error while trying to determine the commit hash. "
+                "Please open an issue: https://github.com/huggingface/peft/issues"
+            )
+            git_hash = "UNKNOWN"
+        version = version + f"@{git_hash}"
+        return version
 
     def to_dict(self) -> dict:
         r"""
diff --git a/tests/test_config.py b/tests/test_config.py
index eddeb46244..8252a9681f 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -475,3 +475,96 @@ def test_lora_config_layers_to_transform_validation(self):
         )
         assert config.layers_to_transform is None
         assert config.layers_pattern is None
+
+    @pytest.mark.parametrize("version", ["0.10", "0.17.0", "1"])
+    @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
+    def test_peft_version_is_stored(self, version, config_class, mandatory_kwargs, monkeypatch, tmp_path):
+        # Check that the PEFT version is automatically stored in/restored from the config file.
+        from peft import config
+
+        monkeypatch.setattr(config, "__version__", version)
+
+        peft_config = config_class(**mandatory_kwargs)
+        assert peft_config.peft_version == version
+
+        peft_config.save_pretrained(tmp_path)
+        with open(tmp_path / "adapter_config.json") as f:
+            config_dict = json.load(f)
+        assert config_dict["peft_version"] == version
+
+        # ensure that the version from the config is being loaded, not just the current version
+        monkeypatch.setattr(config, "__version__", "0.1.another-version")
+
+        # load from config
+        config_loaded = PeftConfig.from_pretrained(tmp_path)
+        assert config_loaded.peft_version == version
+
+        # load from json
+        config_path = tmp_path / "adapter_config.json"
+        config_json = PeftConfig.from_json_file(str(config_path))
+        assert config_json["peft_version"] == version
+
+    @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
+    def test_peft_version_is_dev_version(self, config_class, mandatory_kwargs, monkeypatch, tmp_path):
+        # When a dev version of PEFT is installed, the actual state of PEFT is ambiguous. Therefore, try to determine
+        # the commit hash too and store it as part of the version string.
+        from peft import config
+
+        version = "0.15.0.dev7"
+        monkeypatch.setattr(config, "__version__", version)
+
+        def fake_commit_hash(pkg_name):
+            return "abcdef012345"
+
+        monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash)
+
+        peft_config = config_class(**mandatory_kwargs)
+        expected_version = f"{version}@{fake_commit_hash('peft')}"
+        assert peft_config.peft_version == expected_version
+
+        peft_config.save_pretrained(tmp_path)
+        config_loaded = PeftConfig.from_pretrained(tmp_path)
+        assert config_loaded.peft_version == expected_version
+
+    @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
+    def test_peft_version_is_dev_version_but_commit_hash_cannot_be_determined(
+        self, config_class, mandatory_kwargs, monkeypatch, tmp_path
+    ):
+        # There can be cases where PEFT is using a dev version but the commit hash cannot be determined. In this case,
+        # the dev version string is stored with "@UNKNOWN" in place of the commit hash.
+        from peft import config
+
+        version = "0.15.0.dev7"
+        monkeypatch.setattr(config, "__version__", version)
+
+        def fake_commit_hash(pkg_name):
+            return None
+
+        monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash)
+
+        peft_config = config_class(**mandatory_kwargs)
+        assert peft_config.peft_version == version + "@UNKNOWN"
+
+        peft_config.save_pretrained(tmp_path)
+        config_loaded = PeftConfig.from_pretrained(tmp_path)
+        assert config_loaded.peft_version == version + "@UNKNOWN"
+
+    @pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
+    def test_peft_version_warn_when_commit_hash_errors(self, config_class, mandatory_kwargs, monkeypatch, tmp_path):
+        # We try to get the PEFT commit hash if a dev version is installed. But in case there is any kind of error
+        # there, we don't want user code to break. Instead, the code should run, recording the version with "@UNKNOWN"
+        # in place of the commit hash. In addition, there should be a warning.
+        from peft import config
+
+        version = "0.15.0.dev7"
+        monkeypatch.setattr(config, "__version__", version)
+
+        def fake_commit_hash_raises(pkg_name):
+            raise Exception("Error for testing purpose")
+
+        monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash_raises)
+
+        msg = "A dev version of PEFT is used but there was an error while trying to determine the commit hash"
+        with pytest.warns(UserWarning, match=msg):
+            peft_config = config_class(**mandatory_kwargs)
+        assert peft_config.peft_version == version + "@UNKNOWN"