File size: 479 Bytes
f4d9e3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

def get_device(*, verbose: bool = True) -> torch.device:
    """Return the preferred torch device: CUDA when available, else CPU.

    Args:
        verbose: When True (the default), print which device was selected.
            Pass False to select the device silently, e.g. from library code.

    Returns:
        ``torch.device("cuda")`` if ``torch.cuda.is_available()`` is True,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if verbose:
            # An index-less "cuda" device resolves to the *current* CUDA
            # device, so report that device's name instead of hard-coding 0.
            print(f"Using GPU: {torch.cuda.get_device_name(device)}")
    else:
        device = torch.device("cpu")
        if verbose:
            print("CUDA not available. Using CPU.")
    return device

# Usage example: select a device and allocate a tensor on it.
if __name__ == "__main__":
    device = get_device()

    # Allocate directly on the selected device; building the tensor on the
    # CPU and then calling .to(device) would cost an extra host-to-device copy.
    x = torch.rand(3, 3, device=device)
    print(f"Tensor device: {x.device}")