Spaces:
Build error
Build error
Fix threads argument
Browse files
app.py
CHANGED
|
@@ -29,7 +29,7 @@ import ffmpeg
|
|
| 29 |
import gradio as gr
|
| 30 |
|
| 31 |
from src.download import ExceededMaximumDuration, download_url
|
| 32 |
-
from src.utils import slugify, write_srt, write_vtt
|
| 33 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
| 34 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
| 35 |
from src.whisper.whisperFactory import create_whisper_container
|
|
@@ -596,9 +596,14 @@ if __name__ == '__main__':
|
|
| 596 |
help="the Whisper implementation to use")
|
| 597 |
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
| 598 |
help="the compute type to use for inference")
|
|
|
|
|
|
|
| 599 |
|
| 600 |
args = parser.parse_args().__dict__
|
| 601 |
|
| 602 |
updated_config = default_app_config.update(**args)
|
| 603 |
|
|
|
|
|
|
|
|
|
|
| 604 |
create_ui(app_config=updated_config)
|
|
|
|
| 29 |
import gradio as gr
|
| 30 |
|
| 31 |
from src.download import ExceededMaximumDuration, download_url
|
| 32 |
+
from src.utils import optional_int, slugify, write_srt, write_vtt
|
| 33 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
| 34 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
| 35 |
from src.whisper.whisperFactory import create_whisper_container
|
|
|
|
| 596 |
help="the Whisper implementation to use")
|
| 597 |
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
| 598 |
help="the compute type to use for inference")
|
| 599 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
| 600 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
| 601 |
|
| 602 |
args = parser.parse_args().__dict__
|
| 603 |
|
| 604 |
updated_config = default_app_config.update(**args)
|
| 605 |
|
| 606 |
+
if (threads := args.pop("threads")) > 0:
|
| 607 |
+
torch.set_num_threads(threads)
|
| 608 |
+
|
| 609 |
create_ui(app_config=updated_config)
|
cli.py
CHANGED
|
@@ -113,6 +113,9 @@ def cli():
|
|
| 113 |
device: str = args.pop("device")
|
| 114 |
os.makedirs(output_dir, exist_ok=True)
|
| 115 |
|
|
|
|
|
|
|
|
|
|
| 116 |
whisper_implementation = args.pop("whisper_implementation")
|
| 117 |
print(f"Using {whisper_implementation} for Whisper")
|
| 118 |
|
|
|
|
| 113 |
device: str = args.pop("device")
|
| 114 |
os.makedirs(output_dir, exist_ok=True)
|
| 115 |
|
| 116 |
+
if (threads := args.pop("threads")) > 0:
|
| 117 |
+
torch.set_num_threads(threads)
|
| 118 |
+
|
| 119 |
whisper_implementation = args.pop("whisper_implementation")
|
| 120 |
print(f"Using {whisper_implementation} for Whisper")
|
| 121 |
|