Spaces:
Runtime error
Runtime error
| import os | |
| import tempfile | |
| import unittest | |
| import torch | |
| from diffusers.loaders.lora_base import LoraBaseMixin | |
class UtilityMethodDeprecationTests(unittest.TestCase):
    """Check that calling LoRA utility helpers via ``LoraBaseMixin`` emits a ``FutureWarning``."""

    def test_fetch_state_dict_cls_method_raises_warning(self):
        """``LoraBaseMixin._fetch_state_dict`` must warn with a ``FutureWarning``.

        The warning message is expected to mention the ``_fetch_state_dict()``
        method being used from the mixin.
        """
        state_dict = torch.nn.Linear(3, 3).state_dict()
        with self.assertWarns(FutureWarning) as cm:
            _ = LoraBaseMixin._fetch_state_dict(
                state_dict,
                weight_name=None,
                use_safetensors=False,
                local_files_only=True,
                cache_dir=None,
                force_download=False,
                proxies=None,
                token=None,
                revision=None,
                subfolder=None,
                user_agent=None,
                allow_pickle=None,
            )
        # `cm.warning` is the first recorded warning that matched FutureWarning.
        # `cm.warnings[0]` is just the first warning of *any* category raised in
        # the block, which may be unrelated and would make this check flaky.
        self.assertIn("Using the `_fetch_state_dict()` method from", str(cm.warning))

    def test_best_guess_weight_name_cls_method_raises_warning(self):
        """``LoraBaseMixin._best_guess_weight_name`` must warn with a ``FutureWarning``.

        The warning message is expected to mention the
        ``_best_guess_weight_name()`` method being used from the mixin.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            state_dict = torch.nn.Linear(3, 3).state_dict()
            # Presumably gives the helper a weight file to find in `tmpdir`;
            # verify against the helper's lookup rules if this file name changes.
            torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))

            with self.assertWarns(FutureWarning) as cm:
                _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
            # Inspect the matched FutureWarning, not whatever warning happened to be recorded first.
            self.assertIn("Using the `_best_guess_weight_name()` method from", str(cm.warning))