Update analyzer modules and tokenizer
app.py CHANGED

@@ -13,7 +13,7 @@ MODEL = None
 
 LANGUAGE_CONFIG = {
     "ar": {
-        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/
+        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
         "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."
     },
     "da": {

@@ -57,7 +57,7 @@ LANGUAGE_CONFIG = {
         "text": "Il mese scorso abbiamo raggiunto un nuovo traguardo: due miliardi di visualizzazioni sul nostro canale YouTube."
     },
     "ja": {
-        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/
+        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ja/ja_prompts1.flac",
         "text": "先月、私たちのYouTubeチャンネルで二十億回の再生回数という新たなマイルストーンに到達しました。"
     },
     "ko": {

@@ -101,8 +101,8 @@ LANGUAGE_CONFIG = {
         "text": "Geçen ay YouTube kanalımızda iki milyar görüntüleme ile yeni bir dönüm noktasına ulaştık."
     },
     "zh": {
-        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/
-        "text": "
+        "audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/zh_f2.flac",
+        "text": "上个月,我们达到了一个新的里程碑. 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"
     },
 }
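Each LANGUAGE_CONFIG entry pairs a reference-audio URL with the transcript spoken in that clip, so the demo can seed voice cloning per language. A minimal lookup sketch, assuming only the dict shown above (the helper name is hypothetical, not part of app.py):

# Hypothetical helper (not in app.py): look up the demo prompt for a language.
def demo_prompt(language_id: str):
    cfg = LANGUAGE_CONFIG[language_id]
    return cfg["audio"], cfg["text"]   # reference-audio URL and its transcript

audio_url, transcript = demo_prompt("zh")   # resolves to the new zh_f2.flac prompt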
src/chatterbox/models/t3/inference/alignment_stream_analyzer.py CHANGED

@@ -155,12 +155,12 @@ class AlignmentStreamAnalyzer:
         token_repetition = (
             # self.complete and
             len(self.generated_tokens) >= 3 and
-            len(set(self.generated_tokens[-
+            len(set(self.generated_tokens[-2:])) == 1
         )
 
         if token_repetition:
             repeated_token = self.generated_tokens[-1]
-            logger.warning(f"🚨 Detected
+            logger.warning(f"🚨 Detected 2x repetition of token {repeated_token}")
 
             # Suppress EoS to prevent early termination
             if cur_text_posn < S - 3 and S > 5:  # Only suppress if text is longer than 5 tokens
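The repetition guard now fires as soon as the last two generated speech tokens are identical. A self-contained sketch of the same predicate:

def is_token_repetition(generated_tokens: list) -> bool:
    # Mirrors the check above: at least 3 tokens generated,
    # and the last two share the same token id.
    return len(generated_tokens) >= 3 and len(set(generated_tokens[-2:])) == 1

assert is_token_repetition([12, 7, 7])        # last two repeat -> flagged
assert not is_token_repetition([12, 7, 9])    # no repetition
assert not is_token_repetition([7, 7])        # too few tokens to judge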
src/chatterbox/models/t3/modules/t3_config.py CHANGED

@@ -25,6 +25,10 @@ class T3Config:
     @property
     def n_channels(self):
         return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"]
+
+    @property
+    def is_multilingual(self):
+        return self.text_tokens_dict_size == 2352
 
     @classmethod
     def english_only(cls):
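The new property infers the model flavor from the text-token vocabulary size: 2352 is taken as the multilingual tokenizer's dict size. A quick sketch of the idea (the 704 English-only size is an assumption for illustration):

# Stub with just the relevant field; the real T3Config carries much more.
class ConfigSketch:
    def __init__(self, text_tokens_dict_size):
        self.text_tokens_dict_size = text_tokens_dict_size

    @property
    def is_multilingual(self):
        return self.text_tokens_dict_size == 2352  # multilingual vocab size per the diff

assert ConfigSketch(2352).is_multilingual
assert not ConfigSketch(704).is_multilingual  # assumed English-only vocab size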
src/chatterbox/models/t3/t3.py CHANGED

@@ -257,14 +257,17 @@ class T3(nn.Module):
         # TODO? synchronize the expensive compile function
         # with self.compile_lock:
         if not self.compiled:
-            alignment_stream_analyzer = AlignmentStreamAnalyzer(
-                self.tfmr,
-                None,
-                text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
-                alignment_layer_idx=9,  # TODO: hparam or something?
-                eos_idx=self.hp.stop_speech_token,
-            )
-            assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
+            # Default to None for English models, only create for multilingual
+            alignment_stream_analyzer = None
+            if self.hp.is_multilingual:
+                alignment_stream_analyzer = AlignmentStreamAnalyzer(
+                    self.tfmr,
+                    None,
+                    text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+                    alignment_layer_idx=9,  # TODO: hparam or something?
+                    eos_idx=self.hp.stop_speech_token,
+                )
+                assert alignment_stream_analyzer.eos_idx == self.hp.stop_speech_token
 
             patched_model = T3HuggingfaceBackend(
                 config=self.cfg,
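Since English models now pass alignment_stream_analyzer=None through to the backend, any consumer has to guard its use. A hedged sketch of the consuming side (class and method names here are illustrative, not taken from T3HuggingfaceBackend):

# Illustrative consumer: every use of the analyzer is None-guarded.
class BackendSketch:
    def __init__(self, alignment_stream_analyzer=None):
        self.alignment_stream_analyzer = alignment_stream_analyzer

    def filter_logits(self, logits):
        if self.alignment_stream_analyzer is not None:
            # multilingual path: analyzer may suppress EoS / repeated tokens
            logits = self.alignment_stream_analyzer.step(logits)
        return logits  # English path: logits pass through untouched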
src/chatterbox/models/tokenizers/tokenizer.py CHANGED

@@ -151,9 +151,7 @@ def korean_normalize(text: str) -> str:
         return initial + medial + final
 
     # Decompose syllables and normalize punctuation
-    result = ''.join(decompose_hangul(char) for char in text)
-    result = re.sub(r'[…~?!,:;()「」『』]', '.', result)  # Korean punctuation
-
+    result = ''.join(decompose_hangul(char) for char in text)
     return result.strip()
 

@@ -201,81 +199,39 @@ class ChineseCangjieConverter:
 
     def _cangjie_encode(self, glyph: str):
         """Encode a single Chinese glyph to Cangjie code."""
-
-
-            return None
-
-        index =
-
-        return code + index_suffix
+        normed_glyph = glyph
+        code = self.word2cj.get(normed_glyph, None)
+        if code is None:  # e.g. Japanese hiragana
+            return None
+        index = self.cj2word[code].index(normed_glyph)
+        index = str(index) if index > 0 else ""
+        return code + str(index)
 
-    def _normalize_numbers(self, text):
-        """Convert Arabic numerals (1-99) to Chinese characters."""
-        digit_map = {'0': '零', '1': '一', '2': '二', '3': '三', '4': '四',
-                     '5': '五', '6': '六', '7': '七', '8': '八', '9': '九'}
-
-        pattern = re.compile(r'(?<!\d)(\d{1,2})(?!\d)')
-
-        def convert_number(match):
-            num = int(match.group(1))
-
-            if num == 0:
-                return '零'
-            elif 1 <= num <= 9:
-                return digit_map[str(num)]
-            elif num == 10:
-                return '十'
-            elif 11 <= num <= 19:
-                return '十' + digit_map[str(num % 10)]
-            elif 20 <= num <= 99:
-                tens, ones = divmod(num, 10)
-                if ones == 0:
-                    return digit_map[str(tens)] + '十'
-                else:
-                    return digit_map[str(tens)] + '十' + digit_map[str(ones)]
-            else:
-                return match.group(1)
-
-        return pattern.sub(convert_number, text)
-
-    def
-        """Convert Chinese characters in text to Cangjie tokens."""
-
-        text = re.sub('(。|…)', '.', text)
-        text = self._normalize_numbers(text)
-
-        # Skip segmentation for simple sequences (numbers, punctuation, short phrases)
-        if self.segmenter is not None:
-
-
-
-
-            )
-
-            # Only segment complex Chinese text (longer sentences without enumeration patterns)
-            if not is_simple_sequence and len(text) > 10:
-                chinese_chars = sum(1 for c in text if category(c) == "Lo")
-                total_chars = len([c for c in text if c.strip()])
-
-                if chinese_chars > 5 and chinese_chars / total_chars > 0.7:
-                    segmented_words = self.segmenter.cut(text)
-                    text = " ".join(segmented_words)
-
-            cangjie = self._cangjie_encode(char)
-            if cangjie is None:
-                output.append(
-                continue
-
-
-
-
-
-            else:
-                output.append(
-
-        return "".join(output)
+    def __call__(self, text):
+        """Convert Chinese characters in text to Cangjie tokens."""
+        output = []
+        if self.segmenter is not None:
+            segmented_words = self.segmenter.cut(text)
+            full_text = " ".join(segmented_words)
+        else:
+            full_text = text
+
+        for t in full_text:
+            if category(t) == "Lo":
+                cangjie = self._cangjie_encode(t)
+                if cangjie is None:
+                    output.append(t)
+                    continue
+                code = []
+                for c in cangjie:
+                    code.append(f"[cj_{c}]")
+                code.append("[cj_.]")
+                code = "".join(code)
+                output.append(code)
+            else:
+                output.append(t)
+        return "".join(output)
 

@@ -299,7 +255,7 @@ class MTLTokenizer:
     def encode(self, txt: str, language_id: str = None):
         # Language-specific text processing
         if language_id == 'zh':
-            txt = self.cangjie_converter
+            txt = self.cangjie_converter(txt)
         elif language_id == 'ja':
             txt = hiragana_normalize(txt)
         elif language_id == 'he':
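The rewritten __call__ drops the number-normalization and conditional-segmentation heuristics: text is always segmented when a segmenter exists, and every Lo-category glyph is Cangjie-encoded, each code letter emitted as a [cj_*] token with [cj_.] closing the glyph. A toy end-to-end sketch (the two-entry tables are illustrative stand-ins for the converter's real word2cj/cj2word data):

# Toy stand-ins for the converter's real word2cj / cj2word tables.
word2cj = {"你": "onf", "好": "vnd"}
cj2word = {"onf": ["你"], "vnd": ["好"]}

def encode_glyph(glyph):
    code = word2cj.get(glyph)
    if code is None:                      # e.g. kana or Latin letters
        return None
    index = cj2word[code].index(glyph)    # disambiguate glyphs sharing a code
    return code + (str(index) if index > 0 else "")

out = []
for ch in "你好!":
    code = encode_glyph(ch)
    out.append(ch if code is None else "".join(f"[cj_{c}]" for c in code) + "[cj_.]")
print("".join(out))  # -> [cj_o][cj_n][cj_f][cj_.][cj_v][cj_n][cj_d][cj_.]!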
src/chatterbox/mtl_tts.py CHANGED

@@ -83,7 +83,7 @@ def punc_norm(text: str) -> str:
 
     # Add full stop if no ending punc
     text = text.rstrip(" ")
-    sentence_enders = {".", "!", "?", "-", ","}
+    sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
     if not any(text.endswith(p) for p in sentence_enders):
         text += "."
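With the full-width CJK marks in the set, text that already ends in, say, "。" or "?" no longer gets an extra ASCII period tacked on. A quick standalone check of just this branch:

def needs_full_stop(text: str) -> bool:
    # The ending-punctuation test from punc_norm, including the CJK enders.
    sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
    return not any(text.endswith(p) for p in sentence_enders)

assert not needs_full_stop("我们达到了一个新的里程碑。")  # full-width stop accepted
assert needs_full_stop("Hello world")                      # still gets "." appended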