Spaces:
Paused
Paused
| import platform | |
| import sys | |
| import torch | |
| def get_cuda_version(): | |
| if torch.cuda.is_available(): | |
| cuda_version = torch.version.cuda | |
| return f"cu{cuda_version.replace('.', '')[:2]}" # 例如:cu121 | |
| return "cpu" | |
| def get_torch_version(): | |
| return f"torch{torch.__version__.split('+')[0]}"[:-2] # 例如:torch2.2 | |
| def get_python_version(): | |
| version = sys.version_info | |
| return f"cp{version.major}{version.minor}" # 例如:cp310 | |
| def get_abi_flag(): | |
| return "abiTRUE" if torch._C._GLIBCXX_USE_CXX11_ABI else "abiFALSE" | |
| def get_platform(): | |
| system = platform.system().lower() | |
| machine = platform.machine().lower() | |
| if system == "linux" and machine == "x86_64": | |
| return "linux_x86_64" | |
| elif system == "windows" and machine == "amd64": | |
| return "win_amd64" | |
| elif system == "darwin" and machine == "x86_64": | |
| return "macosx_x86_64" | |
| else: | |
| raise ValueError(f"Unsupported platform: {system}_{machine}") | |
| def generate_flash_attn_filename(flash_attn_version="2.7.2.post1"): | |
| cuda_version = get_cuda_version() | |
| torch_version = get_torch_version() | |
| python_version = get_python_version() | |
| abi_flag = get_abi_flag() | |
| platform_tag = get_platform() | |
| filename = ( | |
| f"flash_attn-{flash_attn_version}+{cuda_version}{torch_version}cxx11{abi_flag}-" | |
| f"{python_version}-{python_version}-{platform_tag}.whl" | |
| ) | |
| return filename | |
| if __name__ == "__main__": | |
| try: | |
| filename = generate_flash_attn_filename() | |
| print(f"{filename}") | |
| except Exception as e: | |
| print("Error generating filename:", e) |