wan2-2-fp8da-aoti-faster / setup_device.py
tddandroid's picture
Create setup_device.py
f4d9e3e verified
raw
history blame contribute delete
479 Bytes
import torch
def get_device() -> torch.device:
    """Select the compute device: the current CUDA device if available, else CPU.

    Prints a one-line message to stdout describing which device was chosen.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` is true,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # torch.device("cuda") (no index) resolves to the *current* CUDA device,
        # so report that device's name instead of hard-coding index 0.
        print(f"Using GPU: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    else:
        device = torch.device("cpu")
        print("CUDA not available. Using CPU.")
    return device
# Usage example: pick a device and create a small tensor on it.
if __name__ == "__main__":
    device = get_device()
    # Allocate the tensor directly on the selected device rather than creating
    # it on the CPU and copying it over with .to(device).
    x = torch.rand(3, 3, device=device)
    print(f"Tensor device: {x.device}")