wav2vec2 / app.py
ThanhNguyen1811's picture
Upload app.py
b7f8cd0 verified
import gradio as gr
import torch
import torchaudio
import pandas as pd
import os
import torch.nn as nn
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer
# Import các class mô hình từ file models.py
from models import MultimodalClassifier, TextClassifier
# --- 1. Thiết lập và Tải Mô hình (Tải một lần khi app khởi động) ---
print("Đang thiết lập thiết bị...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Sử dụng thiết bị: {device}")
# Định nghĩa nhãn
LABELS_A = {0: "Tức giận", 1: "Bình thường", 2: "Vui vẻ"}
LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường", 3: "Trung tính", 4: "Tích cực", 5: "Vui vẻ", 6: "Châm Biếm"}
# Đường dẫn (Tương đối với thư mục gốc của Space)
MODEL_A_PATH = "saved_models/best_model_A.pth"
MODEL_B_PATH = "saved_models/best_model_B.pth"
FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác
# Tải các mô hình nền (từ Hugging Face Hub)
print("Đang tải các mô hình nền (STT, PhoBERT)...")
audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h")
stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device)
text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device)
# Tải các mô hình đã huấn luyện (từ file .pth)
print("Đang tải các mô hình đã huấn luyện (A & B)...")
model_A = MultimodalClassifier(num_classes=len(LABELS_A)).to(device)
model_A.load_state_dict(torch.load(MODEL_A_PATH, map_location=device))
model_A.eval()
model_B = TextClassifier(n_classes=len(LABELS_B)).to(device)
model_B.load_state_dict(torch.load(MODEL_B_PATH, map_location=device))
model_B.eval()
# Đặt các mô hình nền sang chế độ eval
stt_model.eval()
text_feature_extractor.eval()
# Tải luật fuzzy
print("Đang tải luật fuzzy...")
try:
fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';')
fuzzy_rules = {}
for _, row in fuzzy_rules_df.iterrows():
# Đảm bảo tên cột khớp với file CSV của bạn
fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label']
print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.")
except Exception as e:
print(f"Lỗi khi tải luật fuzzy: {e}. Sử dụng luật dự phòng.")
fuzzy_rules = {("Bình thường", "Tiêu cực thông thường"): "Nguy cơ thấp (Dự phòng)"}
print("Tất cả mô hình đã sẵn sàng.")
# --- 2. Định nghĩa Hàm Dự đoán ---
# Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit"
def predict_sentiment(audio_input):
if audio_input is None:
return "[Chưa có âm thanh]", "N/A", "N/A", "N/A"
sample_rate, waveform_numpy = audio_input
# Đảm bảo waveform là tensor float
waveform = torch.from_numpy(waveform_numpy).float()
# Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo
if waveform.ndim > 1:
waveform = waveform[0]
# Thêm chiều batch (1,)
waveform = waveform.unsqueeze(0)
# --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio ---
try:
# 1a. Resample
if sample_rate != 16000:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
# 1b. Chuẩn bị input audio
input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device)
with torch.no_grad():
audio_outputs = stt_model(input_values, output_hidden_states=True)
# 2a. Trích xuất Văn bản (STT)
logits = audio_outputs.logits
predicted_ids = torch.argmax(logits, dim=-1)
transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower()
if not transcribed_text:
transcribed_text = "[Không nhận diện được giọng nói]"
# 2b. Trích xuất Đặc trưng Audio (cho Model A)
audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1)
except Exception as e:
return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio"
# --- Bước 3: Đặc trưng Text và Dự đoán Model B ---
try:
inputs_text = text_tokenizer(
transcribed_text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=256
).to(device)
with torch.no_grad():
# 3a. Đặc trưng Text (cho Model A)
text_outputs = text_feature_extractor(**inputs_text)
text_feat_A = text_outputs.pooler_output
# 3b. Dự đoán Model B
output_B = model_B(inputs_text['input_ids'], inputs_text['attention_mask'])
pred_idx_B = torch.argmax(output_B, dim=1).item()
pred_label_B = LABELS_B.get(pred_idx_B, f"Lỗi Nhãn B ({pred_idx_B})")
except Exception as e:
return f"[Lỗi xử lý text: {e}]", "Lỗi Text", "Lỗi Text", "Lỗi Text"
# --- Bước 4: Dự đoán Model A ---
try:
with torch.no_grad():
output_A = model_A(text_feat_A, audio_feat_A)
pred_idx_A = torch.argmax(output_A, dim=1).item()
pred_label_A = LABELS_A.get(pred_idx_A, f"Lỗi Nhãn A ({pred_idx_A})")
except Exception as e:
return transcribed_text, "Lỗi Model A", pred_label_B, f"[Lỗi Model A: {e}]"
# --- Bước 5: Kết hợp Fuzzy Logic ---
final_prediction = fuzzy_rules.get((pred_label_A, pred_label_B), "Không có luật")
# Trả về các giá trị cho các ô output của Gradio
return transcribed_text, pred_label_A, pred_label_B, final_prediction
# --- 3. Xây dựng Giao diện Gradio ---
print("Đang xây dựng giao diện Gradio...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# Ứng dụng Phân tích Cảm xúc Đa phương tiện")
gr.Markdown("Tải lên một tệp âm thanh (.wav, .mp3, v.v.) **hoặc ghi âm trực tiếp** để dự đoán cảm xúc.")
with gr.Row():
with gr.Column(scale=2):
# === BỔ SUNG TÍNH NĂNG ===
# Thêm "microphone" vào sources để cho phép ghi âm
audio_in = gr.Audio(
sources=["upload", "microphone"], # Cho phép cả tải lên và ghi âm
type="numpy",
label="Tải lên tệp âm thanh hoặc Ghi âm"
)
submit_btn = gr.Button("Phân tích", variant="primary")
with gr.Column(scale=3):
gr.Markdown("### Kết quả Phân tích")
# Các ô output
text_out = gr.Textbox(label="Văn bản được nhận diện (STT)")
final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)")
with gr.Accordion("Xem chi tiết dự đoán của từng mô hình", open=False):
pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)")
pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)")
# Liên kết nút bấm với hàm dự đoán
submit_btn.click(
fn=predict_sentiment,
inputs=audio_in,
outputs=[text_out, pred_A_out, pred_B_out, final_pred_out]
)
gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.")
print("Đang khởi chạy demo...")
demo.launch() # Không cần (share=True) khi chạy trên Spaces