这是indexloc提供的服务,不要输入任何密码
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/peft/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import importlib.metadata
import inspect
import json
import os
import warnings
from dataclasses import asdict, dataclass, field
from typing import Optional, Union

import packaging.version
from huggingface_hub import hf_hub_download
from transformers.utils import PushToHubMixin, http_user_agent

from peft import __version__

from .utils import CONFIG_NAME, PeftType, TaskType


Expand All @@ -43,6 +49,30 @@ def _check_and_remove_unused_kwargs(cls, kwargs):
return kwargs, unexpected_kwargs


def _is_dev_version(version: str) -> bool:
    """Return whether ``version`` is a PEP 440 development release (i.e. it carries a ``.devN`` segment).

    ``packaging.version.Version`` raises ``packaging.version.InvalidVersion`` if ``version`` is not PEP 440 compliant.
    """
    parsed = packaging.version.Version(version)
    return parsed.is_devrelease


def _get_commit_hash(pkg_name: str) -> str | None:
# If PEFT was installed from a specific commit hash, try to get it. This works e.g. when installing PEFT with `pip
# install git+https://github.com/huggingface/peft.git@<HASH>`. This works not for other means, like editable
# installs.
try:
dist = importlib.metadata.distribution(pkg_name)
except importlib.metadata.PackageNotFoundError:
return None

# See: https://packaging.python.org/en/latest/specifications/direct-url/
for path in dist.files or []:
if path.name == "direct_url.json":
direct_url = json.loads((dist.locate_file(path)).read_text())
vcs_info = direct_url.get("vcs_info")
if vcs_info and "commit_id" in vcs_info:
return vcs_info["commit_id"]
return None


@dataclass
class PeftConfigMixin(PushToHubMixin):
r"""
Expand All @@ -60,13 +90,38 @@ class PeftConfigMixin(PushToHubMixin):
auto_mapping: Optional[dict] = field(
default=None, metadata={"help": "An auto mapping dict to help retrieve the base model class if needed."}
)
peft_version: Optional[str] = field(default=None, metadata={"help": "PEFT version, leave empty to auto-fill."})

def __post_init__(self):
    # Validate the task type: when given, it has to be one of the TaskType members.
    task_type = self.task_type
    if task_type is not None and task_type not in list(TaskType):
        valid_task_types = ', '.join(TaskType)
        raise ValueError(
            f"Invalid task type: '{task_type}'. Must be one of the following task types: {valid_task_types}."
        )
    # Auto-fill the PEFT version only when none was supplied; an explicitly passed value is kept untouched.
    if self.peft_version is None:
        self.peft_version = self._get_peft_version()

@staticmethod
def _get_peft_version() -> str:
    """Return the current PEFT version string to be recorded in the config.

    Release versions are returned unchanged. A dev version (e.g. ``0.15.0.dev7``) is ambiguous about the actual code
    state, so the commit hash PEFT was installed from is appended as ``<version>@<hash>``. ``UNKNOWN`` is used in
    place of the hash when it is unavailable or could not be determined (the latter also emits a warning).

    Determining the version is best-effort: it should not prevent a config from being created.
    """
    version = __version__
    try:
        is_dev = _is_dev_version(version)
    except packaging.version.InvalidVersion:
        # A version string that is not PEP 440 compliant (e.g. from a custom build) cannot be classified as dev or
        # release; record it as-is instead of failing to create the config.
        return version
    if not is_dev:
        return version

    try:
        git_hash = _get_commit_hash("peft")
    except Exception as exc:
        # Broad exception on purpose: we never want to break user code just because the commit hash could not be
        # determined. Narrower types (e.g. OSError, json.JSONDecodeError) were considered during review, but the set
        # of errors a package-metadata lookup can raise is not fully known, so every failure degrades to a warning.
        # The error details are included so that the requested issue report is actionable.
        warnings.warn(
            "A dev version of PEFT is used but there was an error while trying to determine the commit hash "
            f"({type(exc).__name__}: {exc}). Please open an issue: https://github.com/huggingface/peft/issues"
        )
        git_hash = None

    if git_hash is None:
        git_hash = "UNKNOWN"
    return f"{version}@{git_hash}"

def to_dict(self) -> dict:
r"""
Expand Down
93 changes: 93 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,3 +475,96 @@ def test_lora_config_layers_to_transform_validation(self):
)
assert config.layers_to_transform is None
assert config.layers_pattern is None

@pytest.mark.parametrize("version", ["0.10", "0.17.0", "1"])
@pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
def test_peft_version_is_stored(self, version, config_class, mandatory_kwargs, monkeypatch, tmp_path):
    # The PEFT version has to be written to adapter_config.json on save and read back (not re-computed) on load.
    from peft import config

    monkeypatch.setattr(config, "__version__", version)
    peft_config = config_class(**mandatory_kwargs)
    assert peft_config.peft_version == version

    config_path = tmp_path / "adapter_config.json"
    peft_config.save_pretrained(tmp_path)
    saved_dict = json.loads(config_path.read_text())
    assert saved_dict["peft_version"] == version

    # Pretend a different PEFT is installed now: loading must still report the version that was saved.
    monkeypatch.setattr(config, "__version__", "0.1.another-version")

    # both loading paths: the config class and the raw JSON file
    assert PeftConfig.from_pretrained(tmp_path).peft_version == version
    assert PeftConfig.from_json_file(str(config_path))["peft_version"] == version

@pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
def test_peft_version_is_dev_version(self, config_class, mandatory_kwargs, monkeypatch, tmp_path):
    # A dev version string alone does not pin down which code is installed, so the commit hash is expected to be
    # appended to it, and the combined string must survive a save/load round trip.
    from peft import config

    dev_version = "0.15.0.dev7"
    commit_hash = "abcdef012345"
    monkeypatch.setattr(config, "__version__", dev_version)
    monkeypatch.setattr(config, "_get_commit_hash", lambda pkg_name: commit_hash)

    expected_version = f"{dev_version}@{commit_hash}"

    peft_config = config_class(**mandatory_kwargs)
    assert peft_config.peft_version == expected_version

    peft_config.save_pretrained(tmp_path)
    assert PeftConfig.from_pretrained(tmp_path).peft_version == expected_version

@pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
def test_peft_version_is_dev_version_but_commit_hash_cannot_be_determined(
    self, config_class, mandatory_kwargs, monkeypatch, tmp_path
):
    # There can be cases where PEFT is using a dev version but the commit hash cannot be determined (e.g. editable
    # installs, for which `_get_commit_hash` returns None). In this case, the dev version is stored with the
    # "@UNKNOWN" placeholder in place of the commit hash, and that value must survive a save/load round trip.
    from peft import config

    version = "0.15.0.dev7"
    expected_version = version + "@UNKNOWN"
    monkeypatch.setattr(config, "__version__", version)

    def fake_commit_hash(pkg_name):
        return None

    monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash)

    peft_config = config_class(**mandatory_kwargs)
    assert peft_config.peft_version == expected_version

    peft_config.save_pretrained(tmp_path)
    config_loaded = PeftConfig.from_pretrained(tmp_path)
    assert config_loaded.peft_version == expected_version

@pytest.mark.parametrize("config_class, mandatory_kwargs", ALL_CONFIG_CLASSES)
def test_peft_version_warn_when_commit_hash_errors(self, config_class, mandatory_kwargs, monkeypatch, tmp_path):
    # We try to get the PEFT commit hash if a dev version is installed. But in case there is any kind of error
    # there, we don't want user code to break. Instead, the config should still be created, a warning should be
    # emitted, and the dev version should be recorded with the "@UNKNOWN" placeholder instead of a commit hash.
    from peft import config

    version = "0.15.0.dev7"
    monkeypatch.setattr(config, "__version__", version)

    def fake_commit_hash_raises(pkg_name):
        raise Exception("Error for testing purpose")

    monkeypatch.setattr(config, "_get_commit_hash", fake_commit_hash_raises)

    # `match` is a regex search, so checking the message prefix is enough
    msg = "A dev version of PEFT is used but there was an error while trying to determine the commit hash"
    with pytest.warns(UserWarning, match=msg):
        peft_config = config_class(**mandatory_kwargs)
    assert peft_config.peft_version == version + "@UNKNOWN"
Loading