Update orpheus-tts/engine_class.py
orpheus-tts/engine_class.py
@@ -5,7 +5,7 @@ from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
 from transformers import AutoTokenizer
 import threading
 import queue
-from
+from decoder import tokens_decoder_sync

 class OrpheusModel:
     def __init__(self, model_name, dtype=torch.bfloat16, tokenizer=None, **engine_kwargs):
@@ -63,7 +63,7 @@ class OrpheusModel:
         if (model_name in unsupported_models):
             raise ValueError(f"Model {model_name} is not supported. Only medium-3b is supported, small, micro and nano models will be released very soon")
         elif model_name in model_map:
-            return
+            return model_map[model_name]["repo_id"]
         else:
            return model_name
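For context: the first hunk completes a truncated import statement so tokens_decoder_sync is actually brought in from decoder.py, and the second hunk replaces a bare return that made every known model alias resolve to None. Below is a minimal, self-contained sketch of the fixed lookup logic as a standalone function; the map entries, the unsupported_models list, and the function name are illustrative placeholders rather than the repository's actual values.

# Minimal sketch of the fixed alias resolution, assuming model_map values
# carry a "repo_id" key as shown in the diff. Entries here are hypothetical.
model_map = {
    "medium-3b": {"repo_id": "example-org/orpheus-medium-3b"},  # placeholder entry
}
unsupported_models = ["small", "micro", "nano"]  # placeholder list

def map_model_params(model_name):
    if model_name in unsupported_models:
        raise ValueError(
            f"Model {model_name} is not supported. Only medium-3b is supported, "
            "small, micro and nano models will be released very soon"
        )
    elif model_name in model_map:
        # Before the fix this was a bare `return`, so known aliases yielded None
        # instead of a usable repo id.
        return model_map[model_name]["repo_id"]
    else:
        # Anything else is passed through as an already-qualified repo id or path.
        return model_name

# A known alias now resolves to its repo id; unknown names pass through unchanged.
print(map_model_params("medium-3b"))           # -> "example-org/orpheus-medium-3b"
print(map_model_params("some-org/custom-ft"))  # -> "some-org/custom-ft"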