diff --git a/examples/corda_finetuning/README.md b/examples/corda_finetuning/README.md
index 9f5d0a8ff3..d39d2cc09e 100644
--- a/examples/corda_finetuning/README.md
+++ b/examples/corda_finetuning/README.md
@@ -150,7 +150,10 @@ corda_config = CordaConfig(
 #### Knowledge-preserved adaptation mode
 
 ```bash
-CUDA_VISIBLE_DEVICES=0 python -u preprocess.py --model_id="meta-llama/Llama-2-7b-hf" \
+export CUDA_VISIBLE_DEVICES=0 # force to use device 0 of CUDA GPU
+export ZE_AFFINITY_MASK=0 # force to use device 0 of Intel XPU
+
+python -u preprocess.py --model_id="meta-llama/Llama-2-7b-hf" \
     --r 128 --seed 233 \
     --save_model --save_path {path_to_residual_model} \
     --calib_dataset "nqopen"
@@ -165,7 +168,10 @@ Arguments:
 #### Instruction-previewed adaptation mode
 
 ```bash
-CUDA_VISIBLE_DEVICES=0 python -u preprocess.py --model_id="meta-llama/Llama-2-7b-hf" \
+export CUDA_VISIBLE_DEVICES=0 # force to use device 0 of CUDA GPU
+export ZE_AFFINITY_MASK=0 # force to use device 0 of Intel XPU
+
+python -u preprocess.py --model_id="meta-llama/Llama-2-7b-hf" \
     --r 128 --seed 233 \
     --save_model --save_path {path_to_residual_model} \
     --first_eigen --calib_dataset "MetaMATH"
@@ -248,4 +254,4 @@ Note that this conversion is not supported if `rslora` is used in combination wi
   booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
   year={2024},
 }
-```
\ No newline at end of file
+```
diff --git a/examples/corda_finetuning/preprocess.py b/examples/corda_finetuning/preprocess.py
index 15bb18cb6b..765242f15e 100644
--- a/examples/corda_finetuning/preprocess.py
+++ b/examples/corda_finetuning/preprocess.py
@@ -38,8 +38,11 @@ def main(args):
     # Setting random seed of numpy and torch
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    torch.cuda.manual_seed_all(args.seed)
-    torch.backends.cudnn.deterministic = True
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+    elif torch.xpu.is_available():
+        torch.xpu.manual_seed_all(args.seed)
+    torch.use_deterministic_algorithms(True)
 
     # Load model
     model_id = args.model_id