Update orpheus-tts/engine_class.py
Browse files- orpheus-tts/engine_class.py +13 -5
orpheus-tts/engine_class.py
CHANGED
|
@@ -112,23 +112,31 @@ class OrpheusModel:
|
|
| 112 |
|
| 113 |
|
| 114 |
|
| 115 |
-
def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [
|
| 116 |
prompt_string = self._format_prompt(prompt, voice)
|
| 117 |
-
print(prompt)
|
|
|
|
|
|
|
| 118 |
sampling_params = SamplingParams(
|
| 119 |
temperature=temperature,
|
| 120 |
top_p=top_p,
|
| 121 |
max_tokens=max_tokens, # Adjust max_tokens as needed.
|
| 122 |
-
stop_token_ids = stop_token_ids,
|
| 123 |
-
repetition_penalty=repetition_penalty,
|
| 124 |
)
|
| 125 |
|
| 126 |
token_queue = queue.Queue()
|
|
|
|
| 127 |
|
| 128 |
async def async_producer():
|
| 129 |
async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
|
| 130 |
# Place each token text into the queue.
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
token_queue.put(None) # Sentinel to indicate completion.
|
| 133 |
|
| 134 |
def run_async():
|
|
|
|
| 112 |
|
| 113 |
|
| 114 |
|
| 115 |
+
def generate_tokens_sync(self, prompt, voice=None, request_id="req-001", temperature=0.6, top_p=0.8, max_tokens=1200, stop_token_ids = [128258], repetition_penalty=1.3):
|
| 116 |
prompt_string = self._format_prompt(prompt, voice)
|
| 117 |
+
print(f"DEBUG: Original prompt: {prompt}")
|
| 118 |
+
print(f"DEBUG: Formatted prompt: {prompt_string}")
|
| 119 |
+
|
| 120 |
sampling_params = SamplingParams(
|
| 121 |
temperature=temperature,
|
| 122 |
top_p=top_p,
|
| 123 |
max_tokens=max_tokens, # Adjust max_tokens as needed.
|
| 124 |
+
stop_token_ids = stop_token_ids,
|
| 125 |
+
repetition_penalty=repetition_penalty,
|
| 126 |
)
|
| 127 |
|
| 128 |
token_queue = queue.Queue()
|
| 129 |
+
token_count = 0
|
| 130 |
|
| 131 |
async def async_producer():
|
| 132 |
async for result in self.engine.generate(prompt=prompt_string, sampling_params=sampling_params, request_id=request_id):
|
| 133 |
# Place each token text into the queue.
|
| 134 |
+
token_text = result.outputs[0].text
|
| 135 |
+
print(f"DEBUG: Generated token {token_count}: {repr(token_text)}")
|
| 136 |
+
token_queue.put(token_text)
|
| 137 |
+
nonlocal token_count
|
| 138 |
+
token_count += 1
|
| 139 |
+
print(f"DEBUG: Generation completed. Total tokens: {token_count}")
|
| 140 |
token_queue.put(None) # Sentinel to indicate completion.
|
| 141 |
|
| 142 |
def run_async():
|