seanpedrickcase commited on
Commit
cb7a4c9
·
1 Parent(s): 38198b1

Removed explicit references to cuda in functions where spaces GPU are loaded

Browse files
Files changed (2) hide show
  1. funcs/embeddings.py +12 -12
  2. funcs/representation_model.py +22 -7
funcs/embeddings.py CHANGED
@@ -35,18 +35,18 @@ def make_or_load_embeddings(docs: list, file_list: list, embeddings_out: np.ndar
35
  """
36
 
37
  # Check for torch cuda
38
- from torch import cuda, backends, version
39
-
40
- print("Is CUDA enabled? ", cuda.is_available())
41
- print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
42
- if cuda.is_available():
43
- torch_device = "gpu"
44
- print("Cuda version installed is: ", version.cuda)
45
- high_quality_mode = "Yes"
46
- os.system("nvidia-smi")
47
- else:
48
- torch_device = "cpu"
49
- high_quality_mode = "No"
50
 
51
  if high_quality_mode_opt == "Yes":
52
  # Define a list of possible local locations to search for the model
 
35
  """
36
 
37
  # Check for torch cuda
38
+ # from torch import cuda, backends, version
39
+
40
+ # print("Is CUDA enabled? ", cuda.is_available())
41
+ # print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
42
+ # if cuda.is_available():
43
+ # torch_device = "gpu"
44
+ # print("Cuda version installed is: ", version.cuda)
45
+ # high_quality_mode = "Yes"
46
+ # os.system("nvidia-smi")
47
+ # else:
48
+ # torch_device = "cpu"
49
+ # high_quality_mode = "No"
50
 
51
  if high_quality_mode_opt == "Yes":
52
  # Define a list of possible local locations to search for the model
funcs/representation_model.py CHANGED
@@ -19,16 +19,30 @@ random_seed = 42
19
  RUNNING_ON_AWS = get_or_create_env_var('RUNNING_ON_AWS', '0')
20
  print(f'The value of RUNNING_ON_AWS is {RUNNING_ON_AWS}')
21
 
22
- from torch import cuda, backends, version, get_num_threads
23
-
24
- print("Is CUDA enabled? ", cuda.is_available())
25
- print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
26
- if cuda.is_available():
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  torch_device = "gpu"
28
  print("Cuda version installed is: ", version.cuda)
29
  high_quality_mode = "Yes"
30
  os.system("nvidia-smi")
31
- else:
 
32
  torch_device = "cpu"
33
  high_quality_mode = "No"
34
 
@@ -42,6 +56,7 @@ else: # torch_device = "cpu"
42
  n_gpu_layers = 0
43
 
44
  #print("Running on device:", torch_device)
 
45
  n_threads = get_num_threads()
46
  print("CPU n_threads:", n_threads)
47
 
@@ -56,7 +71,7 @@ seed: int = random_seed
56
  reset: bool = True
57
  stream: bool = False
58
  n_threads: int = n_threads
59
- n_batch:int = 256
60
  n_ctx:int = 8192 #4096. # Set to 8192 just to avoid any exceeded context window issues
61
  sample:bool = True
62
  trust_remote_code:bool =True
 
19
  RUNNING_ON_AWS = get_or_create_env_var('RUNNING_ON_AWS', '0')
20
  print(f'The value of RUNNING_ON_AWS is {RUNNING_ON_AWS}')
21
 
22
+ USE_GPU = get_or_create_env_var('USE_GPU', '0')
23
+ print(f'The value of USE_GPU is {USE_GPU}')
24
+
25
+ # from torch import cuda, backends, version, get_num_threads
26
+
27
+ # print("Is CUDA enabled? ", cuda.is_available())
28
+ # print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
29
+ # if cuda.is_available():
30
+ # torch_device = "gpu"
31
+ # print("Cuda version installed is: ", version.cuda)
32
+ # high_quality_mode = "Yes"
33
+ # os.system("nvidia-smi")
34
+ # else:
35
+ # torch_device = "cpu"
36
+ # high_quality_mode = "No"
37
+
38
+ if USE_GPU == "1":
39
+ print("Using GPU for representation functions")
40
  torch_device = "gpu"
41
  print("Cuda version installed is: ", version.cuda)
42
  high_quality_mode = "Yes"
43
  os.system("nvidia-smi")
44
+ else:
45
+ print("Using CPU for representation functions")
46
  torch_device = "cpu"
47
  high_quality_mode = "No"
48
 
 
56
  n_gpu_layers = 0
57
 
58
  #print("Running on device:", torch_device)
59
+ from torch import get_num_threads
60
  n_threads = get_num_threads()
61
  print("CPU n_threads:", n_threads)
62
 
 
71
  reset: bool = True
72
  stream: bool = False
73
  n_threads: int = n_threads
74
+ n_batch:int = 512
75
  n_ctx:int = 8192 #4096. # Set to 8192 just to avoid any exceeded context window issues
76
  sample:bool = True
77
  trust_remote_code:bool =True