datbkpro commited on
Commit
2f406aa
·
verified ·
1 Parent(s): cc8629b

Update services/streaming_voice_service.py

Browse files
Files changed (1) hide show
  1. services/streaming_voice_service.py +110 -97
services/streaming_voice_service.py CHANGED
@@ -8,7 +8,6 @@ from typing import Optional, Dict, Any
8
  from config.settings import settings
9
  from core.rag_system import EnhancedRAGSystem
10
  from core.tts_service import EnhancedTTSService
11
- from core.speechbrain_vad import SpeechBrainVAD # THÊM IMPORT
12
 
13
 
14
  class StreamingVoiceService:
@@ -17,85 +16,12 @@ class StreamingVoiceService:
17
  self.rag_system = rag_system
18
  self.tts_service = tts_service
19
 
20
- # Khởi tạo VAD
21
- self.vad_processor = SpeechBrainVAD()
22
-
23
  # Conversation context
24
  self.conversation_history = []
25
  self.current_transcription = ""
26
- self.is_listening = False
27
-
28
- def start_listening(self) -> bool:
29
- """Bắt đầu lắng nghe với VAD"""
30
- if self.is_listening:
31
- return False
32
-
33
- success = self.vad_processor.start_stream(self._on_speech_detected)
34
- if success:
35
- self.is_listening = True
36
- print("🎙️ Đã bắt đầu lắng nghe với VAD")
37
- return success
38
-
39
- def stop_listening(self):
40
- """Dừng lắng nghe"""
41
- self.vad_processor.stop_stream()
42
- self.is_listening = False
43
- print("🛑 Đã dừng lắng nghe")
44
-
45
- def process_audio_chunk(self, audio_data: tuple) -> Dict[str, Any]:
46
- """Xử lý audio chunk với VAD (dùng cho real-time streaming)"""
47
- if not audio_data or not self.is_listening:
48
- return {
49
- 'transcription': "",
50
- 'response': "",
51
- 'tts_audio': None
52
- }
53
-
54
- try:
55
- sample_rate, audio_array = audio_data
56
-
57
- # Xử lý với VAD
58
- self.vad_processor.process_stream(audio_array, sample_rate)
59
-
60
- return {
61
- 'transcription': "Đang lắng nghe...",
62
- 'response': "",
63
- 'tts_audio': None
64
- }
65
-
66
- except Exception as e:
67
- print(f"❌ Lỗi xử lý audio chunk: {e}")
68
- return {
69
- 'transcription': "",
70
- 'response': "",
71
- 'tts_audio': None
72
- }
73
-
74
- def _on_speech_detected(self, speech_audio: np.ndarray, sample_rate: int):
75
- """Callback khi VAD phát hiện speech"""
76
- print(f"🎯 VAD phát hiện speech segment: {len(speech_audio)/sample_rate:.2f}s")
77
-
78
- # Chuyển đổi speech thành text
79
- transcription = self._transcribe_audio(speech_audio, sample_rate)
80
-
81
- if not transcription or len(transcription.strip()) < 2:
82
- print("⚠️ Transcription quá ngắn hoặc trống")
83
- return
84
-
85
- print(f"📝 VAD Transcription: {transcription}")
86
- self.current_transcription = transcription
87
-
88
- # Tạo phản hồi AI
89
- response = self._generate_ai_response(transcription)
90
-
91
- # Tạo TTS
92
- tts_audio_path = self._text_to_speech(response)
93
-
94
- # Có thể gửi kết quả đến UI thông qua callback
95
- # (cần tích hợp với Gradio events)
96
 
97
  def process_streaming_audio(self, audio_data: tuple) -> Dict[str, Any]:
98
- """Xử lý audio streaming (phương thức cho compatibility)"""
99
  if not audio_data:
100
  return {
101
  'transcription': "❌ Không có dữ liệu âm thanh",
@@ -104,29 +30,58 @@ class StreamingVoiceService:
104
  }
105
 
106
  try:
 
107
  sample_rate, audio_array = audio_data
108
 
109
  print(f"🎯 Nhận audio: {len(audio_array)} samples, SR: {sample_rate}")
110
 
111
- # Sử dụng VAD để kiểm tra speech
112
- if not self.vad_processor.is_speech(audio_array, sample_rate):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  return {
114
- 'transcription': "❌ Không phát hiện giọng nói",
115
- 'response': "Vui lòng nói hơn",
116
  'tts_audio': None
117
  }
118
 
119
  # Chuyển đổi thành văn bản
120
  transcription = self._transcribe_audio(audio_array, sample_rate)
121
 
122
- if not transcription or len(transcription.strip()) < 2:
123
  return {
124
  'transcription': "❌ Không nghe rõ",
125
  'response': "Xin vui lòng nói lại rõ hơn",
126
  'tts_audio': None
127
  }
128
 
 
 
 
 
 
 
 
 
129
  print(f"📝 Đã chuyển đổi: {transcription}")
 
 
130
  self.current_transcription = transcription
131
 
132
  # Tạo phản hồi AI
@@ -143,38 +98,57 @@ class StreamingVoiceService:
143
 
144
  except Exception as e:
145
  print(f"❌ Lỗi xử lý streaming audio: {e}")
 
146
  return {
147
  'transcription': f"❌ Lỗi: {str(e)}",
148
- 'response': "Xin lỗi, có lỗi xảy ra",
149
  'tts_audio': None
150
  }
151
 
152
  def _transcribe_audio(self, audio_data: np.ndarray, sample_rate: int) -> Optional[str]:
153
- """Chuyển audio -> text (giữ nguyên)"""
154
- # ... giữ nguyên code cũ ...
155
  try:
 
 
 
 
 
 
 
 
156
  if audio_data.ndim > 1:
157
- audio_data = np.mean(audio_data, axis=1)
158
 
159
- audio_max = np.max(np.abs(audio_data))
160
- if audio_max > 0.1:
161
- audio_data = audio_data / audio_max * 0.9
 
 
 
162
 
163
- max_duration = 15
 
164
  max_samples = sample_rate * max_duration
165
  if len(audio_data) > max_samples:
166
  audio_data = audio_data[:max_samples]
 
167
 
168
- min_duration = 1.0
169
- min_samples = sample_rate * min_duration
 
170
  if len(audio_data) < min_samples:
171
- padding = np.zeros(min_samples - len(audio_data))
 
172
  audio_data = np.concatenate([audio_data, padding])
 
 
 
173
 
174
  buffer = io.BytesIO()
175
  sf.write(buffer, audio_data, sample_rate, format='wav', subtype='PCM_16')
176
  buffer.seek(0)
177
 
 
178
  transcription = self.client.audio.transcriptions.create(
179
  model=settings.WHISPER_MODEL,
180
  file=("speech.wav", buffer.read(), "audio/wav"),
@@ -183,6 +157,7 @@ class StreamingVoiceService:
183
  temperature=0.0,
184
  )
185
 
 
186
  if hasattr(transcription, 'text'):
187
  result = transcription.text.strip()
188
  elif isinstance(transcription, str):
@@ -190,28 +165,65 @@ class StreamingVoiceService:
190
  else:
191
  result = str(transcription).strip()
192
 
 
193
  return result
194
 
195
  except Exception as e:
196
  print(f"❌ Lỗi transcription: {e}")
 
197
  return None
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  def _generate_ai_response(self, user_input: str) -> str:
200
- """Sinh phản hồi AI (giữ nguyên)"""
201
- # ... giữ nguyên code cũ ...
202
  try:
 
203
  self.conversation_history.append({"role": "user", "content": user_input})
204
 
 
205
  rag_results = self.rag_system.semantic_search(user_input, top_k=2)
206
  context_text = "\n".join([f"- {result.get('text', str(result))}" for result in rag_results]) if rag_results else ""
207
 
208
  system_prompt = f"""Bạn là trợ lý AI thông minh chuyên về tiếng Việt.
209
- Hãy trả lời ngắn gọn, tự nhiên và hữu ích.
210
  Thông tin tham khảo:
211
  {context_text}
212
  """
213
 
214
  messages = [{"role": "system", "content": system_prompt}]
 
215
  messages.extend(self.conversation_history[-4:])
216
 
217
  completion = self.client.chat.completions.create(
@@ -224,16 +236,17 @@ Thông tin tham khảo:
224
  response = completion.choices[0].message.content
225
  self.conversation_history.append({"role": "assistant", "content": response})
226
 
 
227
  if len(self.conversation_history) > 8:
228
  self.conversation_history = self.conversation_history[-8:]
229
 
230
  return response
231
 
232
  except Exception as e:
233
- return f"Xin lỗi, tôi gặp lỗi: {str(e)}"
234
 
235
  def _text_to_speech(self, text: str) -> Optional[str]:
236
- """Chuyển văn bản thành giọng nói (giữ nguyên)"""
237
  try:
238
  if not text or text.startswith("❌") or text.startswith("Xin lỗi"):
239
  return None
@@ -241,6 +254,7 @@ Thông tin tham khảo:
241
  tts_bytes = self.tts_service.text_to_speech(text, 'vi')
242
  if tts_bytes:
243
  audio_path = self.tts_service.save_audio_to_file(tts_bytes)
 
244
  return audio_path
245
  except Exception as e:
246
  print(f"❌ Lỗi TTS: {e}")
@@ -255,7 +269,6 @@ Thông tin tham khảo:
255
  def get_conversation_state(self) -> dict:
256
  """Lấy trạng thái hội thoại"""
257
  return {
258
- 'is_listening': self.is_listening,
259
  'history_length': len(self.conversation_history),
260
  'current_transcription': self.current_transcription,
261
  'last_update': time.strftime("%H:%M:%S")
 
8
  from config.settings import settings
9
  from core.rag_system import EnhancedRAGSystem
10
  from core.tts_service import EnhancedTTSService
 
11
 
12
 
13
  class StreamingVoiceService:
 
16
  self.rag_system = rag_system
17
  self.tts_service = tts_service
18
 
 
 
 
19
  # Conversation context
20
  self.conversation_history = []
21
  self.current_transcription = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  def process_streaming_audio(self, audio_data: tuple) -> Dict[str, Any]:
24
+ """Xử lý audio streaming từ Gradio microphone component"""
25
  if not audio_data:
26
  return {
27
  'transcription': "❌ Không có dữ liệu âm thanh",
 
30
  }
31
 
32
  try:
33
+ # Lấy dữ liệu audio từ Gradio
34
  sample_rate, audio_array = audio_data
35
 
36
  print(f"🎯 Nhận audio: {len(audio_array)} samples, SR: {sample_rate}")
37
 
38
+ # Kiểm tra kiểu dữ liệu chuyển đổi nếu cần
39
+ if isinstance(audio_array, np.ndarray):
40
+ if audio_array.dtype == np.float32 or audio_array.dtype == np.float64:
41
+ # Chuyển từ float sang int16
42
+ audio_array = (audio_array * 32767).astype(np.int16)
43
+
44
+ # Kiểm tra audio có dữ liệu không
45
+ if len(audio_array) == 0:
46
+ return {
47
+ 'transcription': "❌ Âm thanh trống",
48
+ 'response': "Vui lòng nói lại",
49
+ 'tts_audio': None
50
+ }
51
+
52
+ # Tính toán âm lượng
53
+ audio_abs = np.abs(audio_array.astype(np.float32))
54
+ audio_rms = np.sqrt(np.mean(audio_abs**2)) / 32767.0
55
+ print(f"📊 Âm lượng RMS: {audio_rms:.4f}")
56
+
57
+ if audio_rms < 0.005:
58
  return {
59
+ 'transcription': "❌ Âm thanh quá yếu",
60
+ 'response': "Xin vui lòng nói to hơn",
61
  'tts_audio': None
62
  }
63
 
64
  # Chuyển đổi thành văn bản
65
  transcription = self._transcribe_audio(audio_array, sample_rate)
66
 
67
+ if not transcription or len(transcription.strip()) == 0:
68
  return {
69
  'transcription': "❌ Không nghe rõ",
70
  'response': "Xin vui lòng nói lại rõ hơn",
71
  'tts_audio': None
72
  }
73
 
74
+ # Kiểm tra nếu transcription quá ngắn
75
+ if len(transcription.strip()) < 2:
76
+ return {
77
+ 'transcription': "❌ Câu nói quá ngắn",
78
+ 'response': "Xin vui lòng nói câu dài hơn",
79
+ 'tts_audio': None
80
+ }
81
+
82
  print(f"📝 Đã chuyển đổi: {transcription}")
83
+
84
+ # Cập nhật transcription hiện tại
85
  self.current_transcription = transcription
86
 
87
  # Tạo phản hồi AI
 
98
 
99
  except Exception as e:
100
  print(f"❌ Lỗi xử lý streaming audio: {e}")
101
+ print(f"Chi tiết lỗi: {traceback.format_exc()}")
102
  return {
103
  'transcription': f"❌ Lỗi: {str(e)}",
104
+ 'response': "Xin lỗi, có lỗi xảy ra trong quá trình xử lý",
105
  'tts_audio': None
106
  }
107
 
108
  def _transcribe_audio(self, audio_data: np.ndarray, sample_rate: int) -> Optional[str]:
109
+ """Chuyển audio -> text với xử lý sample rate"""
 
110
  try:
111
+ # Đảm bảo kiểu dữ liệu là int16
112
+ if audio_data.dtype != np.int16:
113
+ if audio_data.dtype in [np.float32, np.float64]:
114
+ audio_data = (audio_data * 32767).astype(np.int16)
115
+ else:
116
+ audio_data = audio_data.astype(np.int16)
117
+
118
+ # Chuẩn hóa audio data
119
  if audio_data.ndim > 1:
120
+ audio_data = np.mean(audio_data, axis=1).astype(np.int16) # Chuyển sang mono
121
 
122
+ # Resample nếu sample rate không phải 16000Hz (Whisper yêu cầu)
123
+ target_sample_rate = 16000
124
+ if sample_rate != target_sample_rate:
125
+ audio_data = self._resample_audio(audio_data, sample_rate, target_sample_rate)
126
+ sample_rate = target_sample_rate
127
+ print(f"🔄 Đã resample từ {sample_rate}Hz xuống {target_sample_rate}Hz")
128
 
129
+ # Giới hạn độ dài audio
130
+ max_duration = 10 # giây
131
  max_samples = sample_rate * max_duration
132
  if len(audio_data) > max_samples:
133
  audio_data = audio_data[:max_samples]
134
+ print(f"⚠️ Cắt audio xuống còn {max_duration} giây")
135
 
136
+ # Đảm bảo audio đủ dài
137
+ min_duration = 0.5 # giây
138
+ min_samples = int(sample_rate * min_duration)
139
  if len(audio_data) < min_samples:
140
+ # Pad audio nếu quá ngắn
141
+ padding = np.zeros(min_samples - len(audio_data), dtype=np.int16)
142
  audio_data = np.concatenate([audio_data, padding])
143
+ print(f"⚠️ Đã pad audio lên {min_duration} giây")
144
+
145
+ print(f"🔊 Gửi audio đến Whisper: {len(audio_data)} samples, {sample_rate}Hz")
146
 
147
  buffer = io.BytesIO()
148
  sf.write(buffer, audio_data, sample_rate, format='wav', subtype='PCM_16')
149
  buffer.seek(0)
150
 
151
+ # Gọi API Whisper
152
  transcription = self.client.audio.transcriptions.create(
153
  model=settings.WHISPER_MODEL,
154
  file=("speech.wav", buffer.read(), "audio/wav"),
 
157
  temperature=0.0,
158
  )
159
 
160
+ # Xử lý response
161
  if hasattr(transcription, 'text'):
162
  result = transcription.text.strip()
163
  elif isinstance(transcription, str):
 
165
  else:
166
  result = str(transcription).strip()
167
 
168
+ print(f"✅ Transcription thành công: '{result}'")
169
  return result
170
 
171
  except Exception as e:
172
  print(f"❌ Lỗi transcription: {e}")
173
+ print(f"Audio details: dtype={audio_data.dtype}, shape={audio_data.shape}, sr={sample_rate}")
174
  return None
175
 
176
+ def _resample_audio(self, audio_data: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
177
+ """Resample audio sử dụng scipy"""
178
+ try:
179
+ from scipy import signal
180
+
181
+ # Tính số samples mới
182
+ duration = len(audio_data) / orig_sr
183
+ new_length = int(duration * target_sr)
184
+
185
+ # Resample sử dụng scipy.signal.resample
186
+ resampled_audio = signal.resample(audio_data, new_length)
187
+
188
+ # Chuyển lại về int16
189
+ resampled_audio = resampled_audio.astype(np.int16)
190
+
191
+ return resampled_audio
192
+
193
+ except ImportError:
194
+ print("⚠️ Không có scipy, sử dụng simple resampling")
195
+ # Simple resampling bằng interpolation
196
+ orig_length = len(audio_data)
197
+ new_length = int(orig_length * target_sr / orig_sr)
198
+
199
+ # Linear interpolation
200
+ x_old = np.linspace(0, 1, orig_length)
201
+ x_new = np.linspace(0, 1, new_length)
202
+ resampled_audio = np.interp(x_new, x_old, audio_data).astype(np.int16)
203
+
204
+ return resampled_audio
205
+ except Exception as e:
206
+ print(f"❌ Lỗi resample: {e}")
207
+ return audio_data
208
+
209
  def _generate_ai_response(self, user_input: str) -> str:
210
+ """Sinh phản hồi AI"""
 
211
  try:
212
+ # Thêm vào lịch sử
213
  self.conversation_history.append({"role": "user", "content": user_input})
214
 
215
+ # Tìm kiếm RAG
216
  rag_results = self.rag_system.semantic_search(user_input, top_k=2)
217
  context_text = "\n".join([f"- {result.get('text', str(result))}" for result in rag_results]) if rag_results else ""
218
 
219
  system_prompt = f"""Bạn là trợ lý AI thông minh chuyên về tiếng Việt.
220
+ Hãy trả lời ngắn gọn, tự nhiên và hữu ích (dưới 100 từ).
221
  Thông tin tham khảo:
222
  {context_text}
223
  """
224
 
225
  messages = [{"role": "system", "content": system_prompt}]
226
+ # Giữ lại 4 tin nhắn gần nhất
227
  messages.extend(self.conversation_history[-4:])
228
 
229
  completion = self.client.chat.completions.create(
 
236
  response = completion.choices[0].message.content
237
  self.conversation_history.append({"role": "assistant", "content": response})
238
 
239
+ # Giới hạn lịch sử
240
  if len(self.conversation_history) > 8:
241
  self.conversation_history = self.conversation_history[-8:]
242
 
243
  return response
244
 
245
  except Exception as e:
246
+ return f"Xin lỗi, tôi gặp lỗi khi tạo phản hồi: {str(e)}"
247
 
248
  def _text_to_speech(self, text: str) -> Optional[str]:
249
+ """Chuyển văn bản thành giọng nói"""
250
  try:
251
  if not text or text.startswith("❌") or text.startswith("Xin lỗi"):
252
  return None
 
254
  tts_bytes = self.tts_service.text_to_speech(text, 'vi')
255
  if tts_bytes:
256
  audio_path = self.tts_service.save_audio_to_file(tts_bytes)
257
+ print(f"✅ Đã tạo TTS: {audio_path}")
258
  return audio_path
259
  except Exception as e:
260
  print(f"❌ Lỗi TTS: {e}")
 
269
  def get_conversation_state(self) -> dict:
270
  """Lấy trạng thái hội thoại"""
271
  return {
 
272
  'history_length': len(self.conversation_history),
273
  'current_transcription': self.current_transcription,
274
  'last_update': time.strftime("%H:%M:%S")