from abc import ABC
from io import BytesIO
from typing import Literal

import numpy as np
from pydantic import BaseModel, ConfigDict


class AbortController(ABC):
    """Cooperative cancellation hook: implementations report whether the originating request is still alive."""

    def is_alive(self) -> bool:
        raise NotImplementedError


class NeverAbortedController(AbortController):
    def is_alive(self) -> bool:
        return True


def is_none_or_alive(abort_controller: AbortController | None) -> bool:
    return abort_controller is None or abort_controller.is_alive()
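
# Usage sketch (illustrative, not part of this module): a generation loop can
# poll the controller between steps and stop early once the client disconnects.
#
#   controller = NeverAbortedController()
#   while is_none_or_alive(controller):
#       ...  # produce the next chunk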

class ModelNameResponse(BaseModel):
    model_name: str


class TokenizedMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: list[list[int]]
    """[audio_channels+1, time_steps]"""

    def time_steps(self) -> int:
        return len(self.content[0])

    def append(self, chunk: list[list[int]]):
        assert len(chunk) == len(self.content), "Incompatible chunk length"
        assert all(len(c) == len(chunk[0]) for c in chunk), "Incompatible chunk shape"
        for content_channel, chunk_channel in zip(self.content, chunk):
            content_channel.extend(chunk_channel)
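
# Illustrative example (hypothetical token values): `content` is channel-major,
# so a message with one audio channel has two rows (text row + audio row) that
# always share the same length; `append` extends every row by one chunk.
#
#   msg = TokenizedMessage(role="user", content=[[1, 2], [10, 20]])
#   msg.append([[3], [30]])     # one new time step per channel
#   assert msg.time_steps() == 3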

class TokenizedConversation(BaseModel):
    messages: list[TokenizedMessage]

    def time_steps(self) -> int:
        return sum(msg.time_steps() for msg in self.messages)

    def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]":
        """Return the longest suffix of messages whose total time steps fit in `max_time_steps`, oldest first."""
        sum_time_steps = 0
        selected_messages: list[TokenizedMessage] = []
        for msg in reversed(self.messages):
            cur_time_steps = msg.time_steps()
            if sum_time_steps + cur_time_steps > max_time_steps:
                break
            sum_time_steps += cur_time_steps
            selected_messages.append(msg)
        return list(reversed(selected_messages))
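
# Illustrative example: with two 3-step messages `a` and `b` and
# max_time_steps=4, only the most recent message fits the window; selection
# walks newest-to-oldest, then order is restored to oldest-first.
#
#   conv = TokenizedConversation(messages=[a, b])
#   assert conv.latest_messages(max_time_steps=4) == [b]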

class ChatAudioBytes(BaseModel):
    # Raw bytes are base64-encoded when (de)serializing to/from JSON.
    model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64")

    sample_rate: int
    audio_data: bytes
    """
    shape = (channels, samples) or (samples,);
    dtype = int16 or float32
    """

    @classmethod
    def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes":
        # Serialize the sample array in NumPy's `.npy` container format.
        buf = BytesIO()
        np.save(buf, audio[1])
        return ChatAudioBytes(sample_rate=audio[0], audio_data=buf.getvalue())

    def to_audio(self) -> tuple[int, np.ndarray]:
        buf = BytesIO(self.audio_data)
        audio_np = np.load(buf)
        return self.sample_rate, audio_np
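
# Round-trip sketch (hypothetical sample rate and shape): the `.npy` encoding
# preserves dtype and shape exactly, so to_audio() inverts from_audio().
#
#   sr, samples = 24_000, np.zeros((1, 480), dtype=np.int16)
#   blob = ChatAudioBytes.from_audio((sr, samples))
#   sr2, samples2 = blob.to_audio()
#   assert sr2 == sr and np.array_equal(samples, samples2)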

class ChatResponseItem(BaseModel):
    tokenized_input: TokenizedMessage | None = None
    token_chunk: list[list[int]] | None = None
    """[audio_channels+1, time_steps]"""
    text_chunk: str | None = None
    audio_chunk: ChatAudioBytes | None = None
    end_of_stream: bool | None = None
    """Represents the special token <|eostm|>."""
    end_of_transcription: bool | None = None
    """Represents the special token <|eot|> (not <|endoftext|>)."""
    stop_reason: str | None = None
    """The reason generation stopped, e.g., max_new_tokens, max_length, stop_token, aborted."""

class AssistantStyle(BaseModel):
    preset_character: str | None = None
    custom_character_prompt: str | None = None
    preset_voice: str | None = None
    custom_voice: ChatAudioBytes | None = None


class SamplerConfig(BaseModel):
    """
    Sampling configuration for text/audio generation.
    - If some fields are not set, their effects are disabled.
    - If the entire config is not set (e.g., `global_sampler_config=None`), all fields are automatically determined.
    - Use `temperature=0.0`/`top_k=1`/`top_p=0.0` instead of `do_sample=False` to disable sampling.
    """

    temperature: float | None = None
    top_k: int | None = None
    top_p: float | None = None

    def normalized(self) -> tuple[float, int, float]:
        """
        Returns:
            A tuple (temperature, top_k, top_p) with normalized values.
        """
        # Any explicitly greedy setting (temperature <= 0, top_k <= 1, top_p <= 0)
        # collapses the whole config to greedy decoding.
        if (
            (self.temperature is not None and self.temperature <= 0.0)
            or (self.top_k is not None and self.top_k <= 1)
            or (self.top_p is not None and self.top_p <= 0.0)
        ):
            return (1.0, 1, 1.0)

        def default_clip[T: int | float](
            value: T | None, default_value: T, min_value: T, max_value: T
        ) -> T:
            # Fall back to the default when unset; otherwise clip into [min, max].
            if value is None:
                return default_value
            return max(min(value, max_value), min_value)

        temperature = default_clip(self.temperature, 1.0, 0.01, 2.0)
        top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000)
        top_p = default_clip(self.top_p, 1.0, 0.01, 1.0)
        return (temperature, top_k, top_p)
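
# Behavior examples for `normalized()`, following the defaults and clipping above:
#
#   SamplerConfig().normalized()                  # -> (1.0, 1_000_000, 1.0)
#   SamplerConfig(temperature=0.0).normalized()   # -> (1.0, 1, 1.0), greedy
#   SamplerConfig(temperature=5.0).normalized()   # -> (2.0, 1_000_000, 1.0), clipped to max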

class ChatRequestBody(BaseModel):
    conversation: TokenizedConversation | None = None
    input_text: str | None = None
    input_audio: ChatAudioBytes | None = None
    assistant_style: AssistantStyle | None = None
    global_sampler_config: SamplerConfig | None = None
    local_sampler_config: SamplerConfig | None = None
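
# Request construction sketch (hypothetical values). All fields are optional;
# per the SamplerConfig docstring, leaving a config as None lets the server
# determine the sampling parameters automatically.
#
#   body = ChatRequestBody(
#       input_text="Hello!",
#       global_sampler_config=SamplerConfig(temperature=0.7, top_p=0.9),
#   )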

class PresetOptions(BaseModel):
    options: list[str]