from abc import ABC, abstractmethod
from io import BytesIO
from typing import Literal

import numpy as np
from pydantic import BaseModel, ConfigDict


class AbortController(ABC):
    @abstractmethod
    def is_alive(self) -> bool:
        raise NotImplementedError


class NeverAbortedController(AbortController):
    def is_alive(self) -> bool:
        return True


def is_none_or_alive(abort_controller: AbortController | None) -> bool:
    return abort_controller is None or abort_controller.is_alive()


class ModelNameResponse(BaseModel):
    model_name: str


class TokenizedMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: list[list[int]]
    """[audio_channels+1, time_steps]"""

    def time_steps(self) -> int:
        return len(self.content[0])

    def append(self, chunk: list[list[int]]):
        assert len(chunk) == len(self.content), "Incompatible chunk length"
        assert all(len(c) == len(chunk[0]) for c in chunk), "Incompatible chunk shape"
        for content_channel, chunk_channel in zip(self.content, chunk):
            content_channel.extend(chunk_channel)
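
# Usage sketch (illustrative only): a message with one text channel and two
# audio channels, extended in place by a same-shaped chunk.
#
#     msg = TokenizedMessage(role="user", content=[[1], [10], [20]])
#     msg.append([[2, 3], [11, 12], [21, 22]])
#     assert msg.time_steps() == 3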


class TokenizedConversation(BaseModel):
    messages: list[TokenizedMessage]

    def time_steps(self) -> int:
        return sum(msg.time_steps() for msg in self.messages)

    def latest_messages(self, max_time_steps: int) -> "list[TokenizedMessage]":
        sum_time_steps = 0
        selected_messages: list[TokenizedMessage] = []

        for msg in reversed(self.messages):
            cur_time_steps = msg.time_steps()
            if sum_time_steps + cur_time_steps > max_time_steps:
                break
            sum_time_steps += cur_time_steps
            selected_messages.append(msg)

        return list(reversed(selected_messages))
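
# Usage sketch (illustrative only): keep only the most recent messages that
# fit within a time-step budget, preserving chronological order.
#
#     conv = TokenizedConversation(messages=[...])  # oldest -> newest
#     recent = conv.latest_messages(max_time_steps=2048)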


class ChatAudioBytes(BaseModel):
    model_config = ConfigDict(ser_json_bytes="base64", val_json_bytes="base64")
    sample_rate: int
    audio_data: bytes
    """
    shape = (channels, samples) or (samples,);
    dtype = int16 or float32
    """

    @classmethod
    def from_audio(cls, audio: tuple[int, np.ndarray]) -> "ChatAudioBytes":
        buf = BytesIO()
        np.save(buf, audio[1])
        return ChatAudioBytes(sample_rate=audio[0], audio_data=buf.getvalue())

    def to_audio(self) -> tuple[int, np.ndarray]:
        buf = BytesIO(self.audio_data)
        audio_np = np.load(buf)
        return self.sample_rate, audio_np
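
# Usage sketch (illustrative only): round-trip a mono int16 waveform through
# the serialized numpy buffer.
#
#     wave = np.zeros(16_000, dtype=np.int16)
#     packed = ChatAudioBytes.from_audio((16_000, wave))
#     rate, restored = packed.to_audio()
#     assert rate == 16_000 and restored.dtype == np.int16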


class ChatResponseItem(BaseModel):
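    """A single incremental item of a chat response; fields not produced for a given chunk remain None."""
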
    tokenized_input: TokenizedMessage | None = None
    token_chunk: list[list[int]] | None = None
    """[audio_channels+1, time_steps]"""
    text_chunk: str | None = None
    audio_chunk: ChatAudioBytes | None = None
    end_of_stream: bool | None = None
    """Represent Special token <|eostm|>"""
    end_of_transcription: bool | None = None
    """Represent Special token <|eot|> (not <|endoftext|>)"""
    stop_reason: str | None = None
    """The reason why the generation is stopped, e.g., max_new_tokens, max_length, stop_token, aborted"""


class AssistantStyle(BaseModel):
    preset_character: str | None = None
    custom_character_prompt: str | None = None

    preset_voice: str | None = None
    custom_voice: ChatAudioBytes | None = None


class SamplerConfig(BaseModel):
    """
    Sampling configuration for text/audio generation.

    - Fields that are left unset have their effect disabled.
    - If the entire config is unset (e.g., `global_sampler_config=None`), all fields are determined automatically.
    - To disable sampling, use `temperature=0.0`/`top_k=1`/`top_p=0.0` rather than `do_sample=False`.
    """

    temperature: float | None = None
    top_k: int | None = None
    top_p: float | None = None

    def normalized(self) -> tuple[float, int, float]:
        """
        Returns:
            A tuple (temperature, top_k, top_p) with normalized values.
        """
        if (
            (self.temperature is not None and self.temperature <= 0.0)
            or (self.top_k is not None and self.top_k <= 1)
            or (self.top_p is not None and self.top_p <= 0.0)
        ):
            return (1.0, 1, 1.0)

        def default_clip[T: int | float](
            value: T | None, default_value: T, min_value: T, max_value: T
        ) -> T:
            if value is None:
                return default_value
            return max(min(value, max_value), min_value)

        temperature = default_clip(self.temperature, 1.0, 0.01, 2.0)
        top_k = default_clip(self.top_k, 1_000_000, 1, 1_000_000)
        top_p = default_clip(self.top_p, 1.0, 0.01, 1.0)

        return (temperature, top_k, top_p)
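
# Usage sketch (illustrative only): out-of-range values are clipped, unset
# fields fall back to defaults, and "disable sampling" settings collapse to
# greedy decoding.
#
#     assert SamplerConfig(temperature=5.0).normalized() == (2.0, 1_000_000, 1.0)
#     assert SamplerConfig(temperature=0.0).normalized() == (1.0, 1, 1.0)  # greedy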


class ChatRequestBody(BaseModel):
    conversation: TokenizedConversation | None = None
    input_text: str | None = None
    input_audio: ChatAudioBytes | None = None
    assistant_style: AssistantStyle | None = None
    global_sampler_config: SamplerConfig | None = None
    local_sampler_config: SamplerConfig | None = None


class PresetOptions(BaseModel):
    options: list[str]