Spaces:
Runtime error
Runtime error
Ensure model is downloaded before spawning sub-processes
Browse files
- src/vadParallel.py +4 -0
- src/whisperContainer.py +15 -0
src/vadParallel.py
CHANGED
|
@@ -96,6 +96,10 @@ class ParallelTranscription(AbstractTranscription):
|
|
| 96 |
timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
|
| 97 |
merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# Split into a list for each device
|
| 100 |
# TODO: Split by time instead of by number of chunks
|
| 101 |
merged_split = list(self._split(merged, len(gpu_devices)))
|
|
|
|
| 96 |
timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
|
| 97 |
merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
|
| 98 |
|
| 99 |
+
# Make sure the whisper model is downloaded before spawning the sub-processes
|
| 100 |
+
if (len(gpu_devices) > 1):
|
| 101 |
+
whisperCallable.model_container.ensure_downloaded()
|
| 102 |
+
|
| 103 |
# Split into a list for each device
|
| 104 |
# TODO: Split by time instead of by number of chunks
|
| 105 |
merged_split = list(self._split(merged, len(gpu_devices)))
|
src/whisperContainer.py
CHANGED
|
@@ -23,6 +23,21 @@ class WhisperContainer:
|
|
| 23 |
self.model = self.cache.get(model_key, self._create_model)
|
| 24 |
return self.model
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def _create_model(self):
    """Announce and load the whisper model configured on this container, returning the loaded model."""
    name = self.model_name
    print("Loading whisper model " + name)
    # Device and download location are taken from the container's own settings
    loaded = whisper.load_model(name, device=self.device, download_root=self.download_root)
    return loaded
|
|
|
|
| 23 |
self.model = self.cache.get(model_key, self._create_model)
|
| 24 |
return self.model
|
| 25 |
|
| 26 |
+
def ensure_downloaded(self):
    """
    Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
    passing the container to a subprocess.

    Returns True when the pre-download step finished without raising (this includes the case where the model
    name is not listed in whisper._MODELS, so nothing was fetched), and False when it raised.
    """
    # Local import so this method does not rely on the file's import header
    import os

    # Warning: Using private API here
    try:
        root_dir = self.download_root

        if root_dir is None:
            # whisper.load_model substitutes a default cache directory when download_root is None, but the
            # private _download helper is handed the root as-is. Resolve the same default here so the
            # pre-download does not fail (and get swallowed below) when no download root was configured.
            # NOTE(review): mirrors the default used by recent whisper releases - confirm against the pinned version
            cache_home = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
            root_dir = os.path.join(cache_home, "whisper")

        if self.model_name in whisper._MODELS:
            whisper._download(whisper._MODELS[self.model_name], root_dir, False)
        return True
    except Exception as e:
        # Given that the API is private, it could change at any time. We don't want to crash the program
        print("Error pre-downloading model: " + str(e))
        return False
|
| 40 |
+
|
| 41 |
def _create_model(self):
    """Announce and load the whisper model configured on this container, returning the loaded model."""
    name = self.model_name
    print("Loading whisper model " + name)
    # Device and download location are taken from the container's own settings
    loaded = whisper.load_model(name, device=self.device, download_root=self.download_root)
    return loaded
|