# FYP_ASR_Service / analyzer / ASR_jp_jp.py
# Author: HK0712 — commit a6526f0: "CHANGE: keep load in ram"
# =======================================================================
# 1. 匯入區 (Imports)
# - 新增了 pyopenjtalk 和 MeCab
# =======================================================================
import torch
import soundfile as sf
import librosa
from transformers import Wav2Vec2Processor, HubertForCTC
import os
import pyopenjtalk
import MeCab
import numpy as np
from datetime import datetime, timezone
import re
# =======================================================================
# 2. Global variables & configuration
#    The processor/model are no longer module globals; analyze() keeps
#    them in its own cache.
# =======================================================================
# Choose the compute device once, at import time.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"INFO: ASR_jp_jp.py is configured to use device: {DEVICE}")

# Hugging Face id of the Japanese phoneme-level CTC ASR model.
MODEL_NAME = "prj-beatrice/japanese-hubert-base-phoneme-ctc-v3"

# MeCab word segmenter. "-Owakati" makes parse() return the words separated
# by spaces. If construction fails the tagger stays None, and
# _get_target_phonemes_by_word raises RuntimeError when it is called.
mecab_tagger = None
try:
    mecab_tagger = MeCab.Tagger("-Owakati")
except RuntimeError:
    print("ERROR: MeCab Tagger 初始化失敗。請確保 mecab 和 mecab-ipadic-utf8 已正確安裝。")
# =======================================================================
# 3. 核心業務邏輯區 (Core Business Logic)
# =======================================================================
# -----------------------------------------------------------------------
# 3.1. 模型載入函數
# 【已刪除】舊的 load_model() 函數已被移除。
# -----------------------------------------------------------------------
# -----------------------------------------------------------------------
# 3.2. 日語 G2P 輔助函數 (此檔案最核心的修改)
# 【保持不變】
# -----------------------------------------------------------------------
def _get_target_phonemes_by_word(text: str) -> tuple[list[str], list[list[str]]]:
    """
    Segment Japanese text into words and convert each word to phonemes.

    The text is split with the module-level MeCab tagger (wakati mode). Each
    resulting word is then converted on its own with
    ``pyopenjtalk.g2p(word, kana=False)``. Words whose G2P output contains no
    phoneme tokens are dropped, so both returned lists stay index-aligned.

    NOTE(review): G2P runs on each word in isolation, so context-dependent
    readings may differ from whole-sentence G2P — TODO confirm for particles.

    Args:
        text: the Japanese target sentence.

    Returns:
        ``(words, phonemes_per_word)`` where ``phonemes_per_word[k]`` is the
        list of phoneme tokens for ``words[k]``.

    Raises:
        RuntimeError: if the MeCab tagger failed to initialise.
    """
    if not mecab_tagger:
        raise RuntimeError("MeCab Tagger 未初始化,無法處理日語文本。")

    kept_words: list[str] = []
    kept_phonemes: list[list[str]] = []
    for token in mecab_tagger.parse(text).strip().split():
        # g2p(kana=False) returns phoneme symbols separated by whitespace;
        # collapse the whitespace before splitting into tokens.
        raw = pyopenjtalk.g2p(token, kana=False)
        phonemes = re.sub(r'\s+', ' ', raw).strip().split()
        if not phonemes:
            continue
        kept_words.append(token)
        kept_phonemes.append(phonemes)
    return kept_words, kept_phonemes
# -----------------------------------------------------------------------
# 3.3. 音素切分函數 (用於處理 ASR 的輸出)
# 【保持不變】
# -----------------------------------------------------------------------
def _tokenize_asr_output(phoneme_string: str) -> list:
"""
將 ASR 模型輸出的音素字串切分為列表。
此模型的輸出是單字元音素,以空格分隔。
"""
return phoneme_string.split()
# -----------------------------------------------------------------------
# 3.4. 核心分析函數 (主入口)
# 【已修改】將模型載入和快取邏輯合併至此。
# -----------------------------------------------------------------------
def analyze(audio_file_path: str, target_sentence: str, cache: dict = {}) -> dict:
    """
    Analyse how well an audio recording matches a Japanese target sentence.

    The processor and model are loaded lazily on the first call and kept in
    ``cache``. The ``{}`` default is evaluated once, at definition time. So
    when callers do not pass their own dict, every call shares this single
    function-private cache and the model stays resident in RAM. This is
    intentional; pass an explicit dict to isolate state.

    Args:
        audio_file_path: path of an audio file readable by ``soundfile``.
            Multi-channel audio is down-mixed to mono, and any sample rate
            other than 16 kHz is resampled to 16 kHz.
        target_sentence: the Japanese sentence the speaker was asked to say.
        cache: dict holding ``"processor"`` and ``"model"``. Both entries are
            written only after loading and moving the model to DEVICE succeed.

    Returns:
        The result dict built by ``_format_to_json_structure``.

    Raises:
        RuntimeError: the model could not be loaded, or MeCab is unavailable.
        IOError: the audio could not be read or inference failed.
    """
    # Load only on a cache miss; later calls reuse the cached objects.
    if "model" not in cache:
        _load_asr_model_into_cache(cache)
    processor = cache["processor"]
    model = cache["model"]

    # Step 1: G2P on the target sentence.
    target_words_original, target_ipa_by_word = _get_target_phonemes_by_word(target_sentence)
    if not target_words_original:
        print("警告: G2P 處理後目標句子為空。")
        return _format_to_json_structure([], target_sentence, [])

    # Step 2: ASR.
    user_ipa_full = _transcribe_audio(audio_file_path, processor, model)

    # Step 3: alignment.
    word_alignments = _get_phoneme_alignments_by_word(user_ipa_full, target_ipa_by_word)

    # Step 4: formatting.
    return _format_to_json_structure(word_alignments, target_sentence, target_words_original)


def _load_asr_model_into_cache(cache: dict) -> None:
    """Load the processor and model, move the model to DEVICE, then publish both into ``cache``.

    Nothing is written to ``cache`` unless every step succeeds. A failure
    part-way (e.g. during ``.to(DEVICE)``) therefore cannot leave a stale
    entry that later calls would pick up and use on the wrong device.
    """
    print(f"快取未命中 (ASR_jp_jp)。正在載入模型 '{MODEL_NAME}'...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
        model = HubertForCTC.from_pretrained(MODEL_NAME)
        model.to(DEVICE)
        # Inference only: make sure training-time layers (dropout) are off.
        model.eval()
    except Exception as e:
        print(f"處理或載入模型 '{MODEL_NAME}' 時發生錯誤: {e}")
        raise RuntimeError(f"Failed to load model '{MODEL_NAME}': {e}") from e
    cache["processor"] = processor
    cache["model"] = model
    print(f"模型 '{MODEL_NAME}' 已載入並快取。")


def _transcribe_audio(audio_file_path: str, processor, model) -> str:
    """Read audio, normalise it to mono 16 kHz, and return the greedy CTC decoding.

    Returns ``""`` for an empty audio file. Any failure while reading or
    running inference is re-raised as IOError, with the original exception
    chained as its cause.
    """
    try:
        speech, sample_rate = sf.read(audio_file_path)
        if len(speech) == 0:
            print("警告: 音訊檔案為空。")
            return ""
        # soundfile returns a (frames, channels) array for multi-channel
        # files. Down-mix so resampling and feature extraction see a 1-D
        # signal rather than treating channels as time or as a batch.
        if speech.ndim > 1:
            speech = speech.mean(axis=1)
        if sample_rate != 16000:
            speech = librosa.resample(y=speech, orig_sr=sample_rate, target_sr=16000)
        input_values = processor(speech, sampling_rate=16000, return_tensors="pt").input_values
        input_values = input_values.to(DEVICE)
        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        return processor.decode(predicted_ids[0])
    except Exception as e:
        raise IOError(f"讀取或處理音訊時發生錯誤: {e}") from e
# =======================================================================
# 4. 對齊與格式化函數區 (Alignment & Formatting)
# 【保持不變】
# =======================================================================
# -----------------------------------------------------------------------
# 4.1. 對齊函數 (語言無關)
# -----------------------------------------------------------------------
def _get_phoneme_alignments_by_word(user_phoneme_str, target_words_ipa_tokenized):
"""
使用動態規劃執行音素對齊。此函數是語言無關的。
"""
user_phonemes = [char for word in user_phoneme_str.split() for char in word]
target_phonemes_flat = []
word_boundaries_indices = []
current_idx = 0
for word_ipa_tokens in target_words_ipa_tokenized:
flat_tokens = [char for word in word_ipa_tokens for char in word]
target_phonemes_flat.extend(flat_tokens)
current_idx += len(flat_tokens)
word_boundaries_indices.append(current_idx - 1)
if not target_phonemes_flat:
return []
dp = np.zeros((len(user_phonemes) + 1, len(target_phonemes_flat) + 1))
for i in range(1, len(user_phonemes) + 1): dp[i][0] = i
for j in range(1, len(target_phonemes_flat) + 1): dp[0][j] = j
for i in range(1, len(user_phonemes) + 1):
for j in range(1, len(target_phonemes_flat) + 1):
cost = 0 if user_phonemes[i-1] == target_phonemes_flat[j-1] else 1
dp[i][j] = min(dp[i-1][j] + 1, dp[i][j-1] + 1, dp[i-1][j-1] + cost)
i, j = len(user_phonemes), len(target_phonemes_flat)
user_path, target_path = [], []
while i > 0 or j > 0:
cost = float('inf')
if i > 0 and j > 0:
cost = 0 if user_phonemes[i-1] == target_phonemes_flat[j-1] else 1
if i > 0 and j > 0 and dp[i][j] == dp[i-1][j-1] + cost:
user_path.insert(0, user_phonemes[i-1]); target_path.insert(0, target_phonemes_flat[j-1]); i -= 1; j -= 1
elif i > 0 and (j == 0 or dp[i][j] == dp[i-1][j] + 1):
user_path.insert(0, user_phonemes[i-1]); target_path.insert(0, '-'); i -= 1
elif j > 0 and (i == 0 or dp[i][j] == dp[i][j-1] + 1):
user_path.insert(0, '-'); target_path.insert(0, target_phonemes_flat[j-1]); j -= 1
else:
break
alignments_by_word = []
word_start_idx_in_path = 0
target_phoneme_counter_in_path = 0
word_boundary_iter = iter(word_boundaries_indices)
current_word_boundary = next(word_boundary_iter, -1)
for path_idx, p in enumerate(target_path):
if p != '-':
if target_phoneme_counter_in_path == current_word_boundary:
target_alignment = target_path[word_start_idx_in_path : path_idx + 1]
user_alignment = user_path[word_start_idx_in_path : path_idx + 1]
alignments_by_word.append({
"target": target_alignment,
"user": user_alignment
})
word_start_idx_in_path = path_idx + 1
current_word_boundary = next(word_boundary_iter, -1)
target_phoneme_counter_in_path += 1
return alignments_by_word
# -----------------------------------------------------------------------
# 4.2. 格式化函數 (語言無關)
# -----------------------------------------------------------------------
def _format_to_json_structure(alignments, sentence, original_words) -> dict:
"""
將對齊結果格式化為最終的 JSON 結構。此函數是語言無關的。
"""
total_phonemes = 0
total_errors = 0
correct_words_count = 0
words_data = []
num_words_to_process = min(len(alignments), len(original_words))
for i in range(num_words_to_process):
alignment = alignments[i]
word_is_correct = True
phonemes_data = []
min_len = min(len(alignment['target']), len(alignment['user']))
for j in range(min_len):
target_phoneme = alignment['target'][j]
user_phoneme = alignment['user'][j]
is_match = (user_phoneme == target_phoneme)
phonemes_data.append({
"target": target_phoneme,
"user": user_phoneme,
"isMatch": is_match
})
if not is_match:
word_is_correct = False
if not (user_phoneme == '-' and target_phoneme == '-'):
total_errors += 1
if word_is_correct:
correct_words_count += 1
words_data.append({
"word": original_words[i],
"isCorrect": word_is_correct,
"phonemes": phonemes_data
})
total_phonemes += sum(1 for p in alignment['target'] if p != '-')
if len(alignments) < len(original_words):
for i in range(len(alignments), len(original_words)):
_, missed_word_ipa_list = _get_target_phonemes_by_word(original_words[i])
phonemes_data = []
if missed_word_ipa_list:
for p_ipa in missed_word_ipa_list[0]:
phonemes_data.append({"target": p_ipa, "user": "-", "isMatch": False})
total_errors += 1
total_phonemes += 1
words_data.append({
"word": original_words[i],
"isCorrect": False,
"phonemes": phonemes_data
})
total_words = len(original_words)
overall_score = (correct_words_count / total_words) * 100 if total_words > 0 else 0
phoneme_error_rate = (total_errors / total_phonemes) * 100 if total_phonemes > 0 else 0
final_result = {
"sentence": sentence,
"analysisTimestampUTC": datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S (UTC)'),
"summary": {
"overallScore": round(overall_score, 1),
"totalWords": total_words,
"correctWords": correct_words_count,
"phonemeErrorRate": round(phoneme_error_rate, 2),
"total_errors": total_errors,
"total_target_phonemes": total_phonemes
},
"words": words_data
}
return final_result