import gradio as gr import torch import torchaudio import pandas as pd import os import torch.nn as nn from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, AutoModel, AutoTokenizer # Import các class mô hình từ file models.py from models import MultimodalClassifier, TextClassifier # --- 1. Thiết lập và Tải Mô hình (Tải một lần khi app khởi động) --- print("Đang thiết lập thiết bị...") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Sử dụng thiết bị: {device}") # Định nghĩa nhãn LABELS_A = {0: "Tức giận", 1: "Bình thường", 2: "Vui vẻ"} LABELS_B = {0: "Đe dọa", 1: "Tức giận", 2: "Tiêu cực thông thường", 3: "Trung tính", 4: "Tích cực", 5: "Vui vẻ", 6: "Châm Biếm"} # Đường dẫn (Tương đối với thư mục gốc của Space) MODEL_A_PATH = "saved_models/best_model_A.pth" MODEL_B_PATH = "saved_models/best_model_B.pth" FUZZY_RULES_PATH = "data/datafuzzy29d.csv" # Đảm bảo tên file này chính xác # Tải các mô hình nền (từ Hugging Face Hub) print("Đang tải các mô hình nền (STT, PhoBERT)...") audio_processor = Wav2Vec2Processor.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h") stt_model = Wav2Vec2ForCTC.from_pretrained("nguyenvulebinh/wav2vec2-base-vietnamese-250h").to(device) text_tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base") text_feature_extractor = AutoModel.from_pretrained("vinai/phobert-base").to(device) # Tải các mô hình đã huấn luyện (từ file .pth) print("Đang tải các mô hình đã huấn luyện (A & B)...") model_A = MultimodalClassifier(num_classes=len(LABELS_A)).to(device) model_A.load_state_dict(torch.load(MODEL_A_PATH, map_location=device)) model_A.eval() model_B = TextClassifier(n_classes=len(LABELS_B)).to(device) model_B.load_state_dict(torch.load(MODEL_B_PATH, map_location=device)) model_B.eval() # Đặt các mô hình nền sang chế độ eval stt_model.eval() text_feature_extractor.eval() # Tải luật fuzzy print("Đang tải luật fuzzy...") try: fuzzy_rules_df = pd.read_csv(FUZZY_RULES_PATH, sep=';') fuzzy_rules = {} for _, row in fuzzy_rules_df.iterrows(): # Đảm bảo tên cột khớp với file CSV của bạn fuzzy_rules[(row['model_a_label'], row['model_b_label'])] = row['final_label'] print(f"Đã tải {len(fuzzy_rules)} luật fuzzy.") except Exception as e: print(f"Lỗi khi tải luật fuzzy: {e}. Sử dụng luật dự phòng.") fuzzy_rules = {("Bình thường", "Tiêu cực thông thường"): "Nguy cơ thấp (Dự phòng)"} print("Tất cả mô hình đã sẵn sàng.") # --- 2. Định nghĩa Hàm Dự đoán --- # Hàm này sẽ được Gradio gọi mỗi khi người dùng nhấn "Submit" def predict_sentiment(audio_input): if audio_input is None: return "[Chưa có âm thanh]", "N/A", "N/A", "N/A" sample_rate, waveform_numpy = audio_input # Đảm bảo waveform là tensor float waveform = torch.from_numpy(waveform_numpy).float() # Đảm bảo là 1D (mono) hoặc lấy kênh đầu tiên nếu là stereo if waveform.ndim > 1: waveform = waveform[0] # Thêm chiều batch (1,) waveform = waveform.unsqueeze(0) # --- Bước 1 & 2 (Gộp): STT và Đặc trưng Audio --- try: # 1a. Resample if sample_rate != 16000: resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000) waveform = resampler(waveform) # 1b. Chuẩn bị input audio input_values = audio_processor(waveform.squeeze(), return_tensors="pt", sampling_rate=16000).input_values.to(device) with torch.no_grad(): audio_outputs = stt_model(input_values, output_hidden_states=True) # 2a. Trích xuất Văn bản (STT) logits = audio_outputs.logits predicted_ids = torch.argmax(logits, dim=-1) transcribed_text = audio_processor.batch_decode(predicted_ids)[0].lower() if not transcribed_text: transcribed_text = "[Không nhận diện được giọng nói]" # 2b. Trích xuất Đặc trưng Audio (cho Model A) audio_feat_A = torch.mean(audio_outputs.hidden_states[-1], dim=1) except Exception as e: return f"[Lỗi xử lý audio: {e}]", "Lỗi Audio", "Lỗi Audio", "Lỗi Audio" # --- Bước 3: Đặc trưng Text và Dự đoán Model B --- try: inputs_text = text_tokenizer( transcribed_text, return_tensors="pt", padding=True, truncation=True, max_length=256 ).to(device) with torch.no_grad(): # 3a. Đặc trưng Text (cho Model A) text_outputs = text_feature_extractor(**inputs_text) text_feat_A = text_outputs.pooler_output # 3b. Dự đoán Model B output_B = model_B(inputs_text['input_ids'], inputs_text['attention_mask']) pred_idx_B = torch.argmax(output_B, dim=1).item() pred_label_B = LABELS_B.get(pred_idx_B, f"Lỗi Nhãn B ({pred_idx_B})") except Exception as e: return f"[Lỗi xử lý text: {e}]", "Lỗi Text", "Lỗi Text", "Lỗi Text" # --- Bước 4: Dự đoán Model A --- try: with torch.no_grad(): output_A = model_A(text_feat_A, audio_feat_A) pred_idx_A = torch.argmax(output_A, dim=1).item() pred_label_A = LABELS_A.get(pred_idx_A, f"Lỗi Nhãn A ({pred_idx_A})") except Exception as e: return transcribed_text, "Lỗi Model A", pred_label_B, f"[Lỗi Model A: {e}]" # --- Bước 5: Kết hợp Fuzzy Logic --- final_prediction = fuzzy_rules.get((pred_label_A, pred_label_B), "Không có luật") # Trả về các giá trị cho các ô output của Gradio return transcribed_text, pred_label_A, pred_label_B, final_prediction # --- 3. Xây dựng Giao diện Gradio --- print("Đang xây dựng giao diện Gradio...") with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# Ứng dụng Phân tích Cảm xúc Đa phương tiện") gr.Markdown("Tải lên một tệp âm thanh (.wav, .mp3, v.v.) **hoặc ghi âm trực tiếp** để dự đoán cảm xúc.") with gr.Row(): with gr.Column(scale=2): # === BỔ SUNG TÍNH NĂNG === # Thêm "microphone" vào sources để cho phép ghi âm audio_in = gr.Audio( sources=["upload", "microphone"], # Cho phép cả tải lên và ghi âm type="numpy", label="Tải lên tệp âm thanh hoặc Ghi âm" ) submit_btn = gr.Button("Phân tích", variant="primary") with gr.Column(scale=3): gr.Markdown("### Kết quả Phân tích") # Các ô output text_out = gr.Textbox(label="Văn bản được nhận diện (STT)") final_pred_out = gr.Label(label="Kết quả cuối cùng (Nguy cơ)") with gr.Accordion("Xem chi tiết dự đoán của từng mô hình", open=False): pred_A_out = gr.Textbox(label="Dự đoán Model A (Đa phương tiện)") pred_B_out = gr.Textbox(label="Dự đoán Model B (Chỉ văn bản)") # Liên kết nút bấm với hàm dự đoán submit_btn.click( fn=predict_sentiment, inputs=audio_in, outputs=[text_out, pred_A_out, pred_B_out, final_pred_out] ) gr.Markdown("Lưu ý: Mô hình STT được tối ưu cho tiếng Việt.") print("Đang khởi chạy demo...") demo.launch() # Không cần (share=True) khi chạy trên Spaces