Spaces:
Runtime error
Runtime error
| import os | |
| import sys | |
def fix_pytorch_int8():
    """Patch the installed ``torch/nn/parameter.py`` so int8 parameters never require grad.

    Finds the first non-empty ``sys.path`` entry that contains
    ``torch/nn/parameter.py``. Before every
    ``return torch.Tensor._make_subclass(cls, data, requires_grad)`` line in
    that file, it inserts a guard that sets ``requires_grad = False`` when
    ``data.dtype == torch.int8``. The guard copies the indentation of the line
    it precedes, so the patch works however deeply that ``return`` is nested
    in the installed torch version.

    The patch is idempotent. If the file already contains the guard, it is not
    modified. The file is rewritten only when a guard was actually inserted.
    If the target ``return`` line is not found, a warning is printed to stderr
    and the file is left untouched.

    Raises:
        FileNotFoundError: if no ``sys.path`` entry contains
            ``torch/nn/parameter.py``.
    """
    import re

    fix_path = None
    for path in sys.path:
        # '' (current directory) entries are skipped, as in the original lookup.
        if not path:
            continue
        candidate = os.path.join(path, 'torch', 'nn', 'parameter.py')
        # Test for the exact file rather than any folder whose name merely
        # contains 'torch' (torchvision, torch-*.dist-info, ...). Stop at the
        # first hit, which is normally the copy that ``import torch`` loads.
        if os.path.isfile(candidate):
            fix_path = candidate
            break
    if fix_path is None:
        raise FileNotFoundError('torch/nn/parameter.py not found on sys.path')

    with open(fix_path, 'r', encoding='utf-8') as f:
        text = f.read()

    if 'if data.dtype == torch.int8' not in text:
        # Capture the leading whitespace so the inserted guard lines up with
        # the ``return`` it protects, whatever the torch version's nesting.
        pattern = re.compile(
            r'^([ \t]*)return torch\.Tensor\._make_subclass\(cls, data, requires_grad\)[ \t]*$',
            re.MULTILINE,
        )

        def _add_guard(match):
            indent = match.group(1)
            return (
                f'{indent}if data.dtype == torch.int8:\n'
                f'{indent}    requires_grad = False\n'
                f'{match.group(0)}'
            )

        text, count = pattern.subn(_add_guard, text)
        if not count:
            # Best-effort patch: report instead of falsely claiming success.
            print(f'Could not patch {fix_path}: target line not found', file=sys.stderr)
            return None
        with open(fix_path, 'w', encoding='utf-8') as f:
            f.write(text)
    print('Fixed torch/nn/parameter.py')
    return None