SHiRA Adapters #2584
Conversation
Issue: #2585
BenjaminBossan
left a comment
Thanks for this PR to integrate SHiRA into PEFT. It's good to see a paper that not only deals with academic improvements but also considers practical applications.
I went through the code and I think this generally already looks quite fine. Still, I found a couple of issues, please check.
Regarding testing: I saw that you added test_shira.py. This is fine for SHiRA-specific tests, but for the general PEFT integration, we need to update the existing PEFT tests to include SHiRA. As a first step, let's add one or a few examples using SHiRA to the custom model tests here:
peft/tests/test_custom_models.py
Line 68 in 759bb70
| ("Vanilla MLP 1 LoRA", "MLP", LoraConfig, {"target_modules": "lin0"}), |
After that, please run pytest tests/test_custom_models.py -k "shira" to see if these tests pass.
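For illustration, the SHiRA entries could mirror the LoRA line above; the names and config fields below are assumptions, not the exact test cases that were eventually added:

```python
# Hypothetical additions to the test-case list in tests/test_custom_models.py,
# following the pattern of the LoRA entry shown above.
("Vanilla MLP 1 SHiRA", "MLP", ShiraConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 SHiRA", "MLP", ShiraConfig, {"target_modules": ["lin0", "lin1"]}),
```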
Ideally, before we merge, we should also:
- add docs
- add an example
- add a benchmark
Those steps are not strictly necessary, as they can be added later, but often it's a good idea to include them, as they can sometimes reveal new areas for improvement. But it's fine to focus on the proper implementation for now.
Finally, please always run `make style` before pushing your changes to ensure that the linter is happy.
It seems like only the random mask is implemented, what about the SHiRA-WM, SHiRA-Grad and SHiRA-SNIP?
Hi @BenjaminBossan, we have now included SHiRA in test_custom_models.py and all the pytest runs pass. We have also responded to and/or incorporated your requested changes in our new commits; hopefully this addresses your concerns. We have fixed the lint issues except for UP045, which already exists in many other modules throughout the PEFT repo. Please let us know if you would like us to fix those for our modules.
BenjaminBossan
left a comment
Thanks a lot for the updates, this is shaping up quite well.
I still have a couple of comments, but those should not be too hard to address.
Moreover, let's work on the following to finalize the PR:
- Please check your ruff version. It doesn't look like `make style` worked as expected. The CI uses ruff-0.9.10.
- Let's add an entry to the docs, check how the other PEFT methods do it.
- Let's add an example to the `examples/` folder, what exactly you want is up to you. You can take an existing one and modify it for SHiRA.
- Let's extend the test coverage by adding entries to `test_decoder_models.py`, `test_encoder_decoder_models.py`, `test_feature_extraction_models.py` and `test_seq_classifier.py`. This should be straightforward, as you only need to add a config to the list of test cases.
Ah, forgot to reply to this:

IMO, it makes little sense to have a config parameter (…)

```python
@dataclass
class ShiraConfig(PeftConfig):
    ...
    mask_type: Literal["random", <more-options>] = field(default="random", metadata={"help": "..."})

    def __post_init__(self):
        if self.mask_type == "random":
            self.mask_fn = random_mask
        elif ...:
            ...
        else:
            warnings.warn(
                f"Argument {self.mask_type=} not recognized, please supply your own masking function "
                "by calling `config.mask_fn = my_mask_fn`."
            )
            self.mask_fn = None
```

Then, if users want to provide a custom mask function, they can still do:

```python
shira_config = ShiraConfig(...)
shira_config.mask_fn = my_custom_mask_function
```

The same would be needed after loading the model, but that's necessary anyway, even with the current implementation. It is important to document this well, though. Moreover, I think it would be really great if all the mask functions from the paper could be added.
…arate call to provide mask_fn
…use it throws an error (addmm_sparse_cuda not implemented for BFloat16). Minor clean up.
… typo fix in shira_finetuning.py
Added.
We have now merged with the latest main. We also verified …
Hmm, there is a strange error occurring on the Windows CI. It seems that each SHiRA test that involves persistence is affected. At first, I thought it might just be flakiness, but I re-ran the CI today and it happened again. The error occurs when we have code like this in the tests:

```python
with tempfile.TemporaryDirectory() as tmp_dirname:
    # do stuff
```

When the context manager is left, the temp directory is cleaned up, but here the cleanup fails with a PermissionError. We have had some occasional issues in the past with Windows not being able to clean up temp directories, to wit:

Lines 397 to 415 in 35000fd

However, we never had the case that so many tests were failing and that they were all tied to a single PEFT method. This makes me think that there is something going on with SHiRA that aggravates the problem, perhaps related to the …
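For context, this failure mode is easy to hit on Windows whenever something still holds a handle to a file inside the temporary directory, since Windows refuses to delete open files. A minimal illustration (not taken from the PEFT tests, file name made up):

```python
import os
import tempfile

# On Windows, removing a directory fails if any file inside it is still open
# (or memory-mapped); on Linux/macOS the same code cleans up without complaint.
with tempfile.TemporaryDirectory() as tmp_dirname:
    f = open(os.path.join(tmp_dirname, "adapter_model.bin"), "wb")
    f.write(b"\x00")
    # Forgetting to close `f`, or a library keeping a handle/mmap alive,
    # makes the cleanup at the end of this `with` block raise PermissionError
    # on Windows.
```

Python 3.10+ also offers `tempfile.TemporaryDirectory(ignore_cleanup_errors=True)`, but that would only hide a leaked handle rather than explain it, so it is mentioned here only as a diagnostic option.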
This is strange. I do not have a Windows machine that would let me test things out; we usually work with Linux-based systems. Also, I am seeing some Access Denied errors in the Windows CI logs? Maybe that is the culprit? See here.
This is indeed a very strange situation and not something I've come across yet. Today, I re-ran the CI and still the same errors occur. Too bad that we don't have Windows machines to test on, so we need to resort to the GH CI instead. Perhaps we can try to eliminate potential sources of error one by one. The first thing I'd test is to comment out all the additions to the existing test files and to temporarily restrict the CI to a single Python version:

peft/.github/workflows/tests.yml
Line 45 in 35000fd

```diff
- python-version: ["3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.12"]
```

Of course, we will revert these changes before merging.
I have pushed a new commit.
Interestingly, when I remove all SHiRA related lines from the tests, …

UPDATE: Commit ee9d1d5 fails with the same error. Any help would be appreciated. Thanks.
I'm as stumped as you are, this really doesn't make a lot of sense. I wondered if you could test converting the ints to floats before saving and then back to ints when loading. This is not a final solution, just for testing if it could somehow be related to the dtype. |
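A sketch of the suggested experiment, assuming the indices are stored in a safetensors state dict under a key like `shira_indices` (the key name is taken from the later discussion; everything else here is illustrative):

```python
import torch
from safetensors.torch import load_file, save_file

indices = torch.tensor([3, 17, 42, 1024], dtype=torch.int64)

# Cast the integer indices to float32 before saving ...
save_file({"shira_indices": indices.to(torch.float32)}, "adapter_model.safetensors")

# ... and cast them back to an integer dtype after loading.
loaded = load_file("adapter_model.safetensors")
restored = loaded["shira_indices"].to(torch.int64)
assert torch.equal(restored, indices)
```

If the Windows failures disappear with this round-trip in place, that points at the integer dtype rather than the SHiRA logic itself.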
…ng and torch.int when loading
Just pushed a commit that changes the SHiRA indices dtype to torch.float32 while saving and back to torch.int while loading. Let us see if this helps. Can we think of any other solutions that might help us merge this? For example, can we skip certain tests on Windows only? This clearly seems like a permission issue that happens only on Windows. If there was anything wrong with our code, we would see these issues happening on other OSes too.
This seems to help indeed, there is only one Windows error, related to a 429 from the Hub, so it's good.

While generally I agree, I think there is a possibility that something more fundamental is broken here, so as long as we don't fully understand what is going on, merging is risky. I can't imagine that saving integers in the …
Okay, this is good, at least. Actually, it seems that there is something going on with integers and safetensors that is specific to the Windows OS. Someone reported another issue on the safetensors GitHub: please see here. It potentially looks like some kind of overflow problem on Windows.

Seems like the above issue was closed as not planned, so I don't think any fix was provided. I am okay with casting indices to float32 while saving and then casting them back to int when loading as a solution to this (because the file size should also stay the same). Seems like it does not affect any Linux tests.
Okay, so how about only doing that when the platform is Windows? Then we can be sure that Linux and macOS are unaffected. For reasonably small values, this should work fine, but at a certain size, the conversion should fail. I wonder if this point is known in advance.
I agree, precision issues can occur for very large integer values. So let me just limit this to Windows systems. I will use something like …
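A minimal sketch of the platform gate described above, assuming a check along the lines of `sys.platform` (the exact helper used in the actual commit may differ):

```python
import sys

import torch


def maybe_cast_indices_for_saving(indices: torch.Tensor) -> torch.Tensor:
    # Cast to float32 only on Windows, where saving integer index tensors
    # triggered the temp-directory cleanup errors; Linux and macOS keep the
    # original integer dtype.
    if sys.platform.startswith("win"):
        return indices.to(torch.float32)
    return indices
```

As to the point where the conversion breaks down: float32 has a 24-bit significand, so indices above 2**24 (about 16.7 million) are the first ones that can no longer be represented exactly.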
… when saving. Convert to int when loading.
I just pushed cc47b6a. It brings back all tests (all Python versions) and casts the `shira_indices` to float32 before saving only on Windows. Hence, macOS and Linux must be completely unaffected now. Thanks. Let us see if this works. BTW, I am not sure what to do about the Hub errors. I thought those were not happening before.
Thanks @kkb-code, let's go with this approach. Could you please ensure the 120 char line limit is taken care of?

Just ignore those, they are just rate limits.
Yep, thanks @BenjaminBossan. Just pushed another commit 2014c43. This should take care of the 120 char limit. I also fixed a couple of minor typos in the warning texts.
BenjaminBossan
left a comment
Thanks for taking care of the Windows issue and this contribution in general. The failing macOS test is just flakiness and can be ignored. The PR is good to be merged.
Thank you so much @BenjaminBossan for the guidance throughout this process. We are happy to see that our SHiRA adapter was merged into PEFT. Please let us know if we should close Issue #2585 or if you would close it at a later time. Also, I am wondering if there is any ETA on the next PEFT version release that would include the SHiRA update (e.g., v0.16.1)? Can you please let us know? That would allow users to use it directly. Thanks again for the help. It was great working with you on this.
Even though we had a release just recently, we might have another one soon, but we have external dependencies for that, so I can't give a date. FYI, for new features such as SHiRA, we always do minor releases (i.e. it would be 0.17.0); patch releases are for hot fixes only. In the meantime, users can install directly from GH, e.g. …

Thanks, same from my side.
Implements: Sparse High Rank Adapters (SHiRA)
Paper: https://arxiv.org/abs/2406.13175

We would like to add code for Sparse High Rank Adapters (SHiRA), which was published at NeurIPS 2024 (see the paper link above).
This is an alternate type of adapter and has been found to have significant advantages over low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA for a variety of vision and language tasks. It also offers simpler and higher quality multi-adapter fusion by significantly reducing concept loss, a common problem faced by low rank adapters. Please see the paper for more details.
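To make the idea concrete, here is a minimal conceptual sketch of a sparse high rank adapter on a single linear layer, using the random mask variant: a fixed set of scattered weight positions is selected and only the delta values at those positions are trained. This is an illustration of the concept under simplifying assumptions (single layer, float32 everywhere), not the PEFT implementation:

```python
import torch
import torch.nn.functional as F


class SparseHighRankAdapterSketch(torch.nn.Module):
    """Conceptual sketch: train a sparse, full-shape delta on top of a frozen weight."""

    def __init__(self, base: torch.nn.Linear, sparsity: float = 0.01):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        num_trainable = max(1, int(sparsity * self.base.weight.numel()))
        # Random mask over flattened weight positions (the "random" mask type).
        flat_indices = torch.randperm(self.base.weight.numel())[:num_trainable]
        self.register_buffer("shira_indices", flat_indices)
        self.shira_values = torch.nn.Parameter(torch.zeros(num_trainable))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scatter the trainable values into a weight-shaped delta. Because the
        # selected positions are unstructured, the delta is sparse but not
        # constrained to be low rank.
        delta = torch.zeros(
            self.base.weight.numel(),
            dtype=self.shira_values.dtype,
            device=self.shira_values.device,
        )
        delta[self.shira_indices] = self.shira_values
        weight = self.base.weight + delta.view_as(self.base.weight)
        return F.linear(x, weight, self.base.bias)
```

For example, wrapping `torch.nn.Linear(64, 64)` with the default sparsity leaves the 4096 base weights frozen and adds only about 41 trainable values, yet the resulting delta can reach full rank, which is the property that distinguishes SHiRA from low rank adapters.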