minhvtt commited on
Commit
a2f6bc2
·
verified ·
1 Parent(s): 6597779

Update agent_chat_stream.py

Browse files
Files changed (1) hide show
  1. agent_chat_stream.py +100 -99
agent_chat_stream.py CHANGED
@@ -1,99 +1,100 @@
1
- """
2
- Agent Chat Streaming Endpoint
3
- SSE-based real-time streaming for Sales & Feedback agents
4
- """
5
- from typing import AsyncGenerator
6
- from stream_utils import format_sse, EVENT_STATUS, EVENT_TOKEN, EVENT_DONE, EVENT_ERROR, EVENT_METADATA
7
- from datetime import datetime
8
-
9
-
10
- async def agent_chat_stream(
11
- request,
12
- agent_service,
13
- conversation_service
14
- ) -> AsyncGenerator[str, None]:
15
- """
16
- Stream agent responses in real-time (SSE format)
17
-
18
- Args:
19
- request: ChatRequest with message, session_id, mode, user_id
20
- agent_service: AgentService instance
21
- conversation_service: ConversationService instance
22
-
23
- Yields SSE events:
24
- - status: Processing updates
25
- - token: Text chunks
26
- - metadata: Session info
27
- - done: Completion signal
28
- - error: Error messages
29
- """
30
- try:
31
- # === SESSION MANAGEMENT ===
32
- session_id = request.session_id
33
- if not session_id:
34
- session_id = conversation_service.create_session(
35
- metadata={"user_agent": "api", "created_via": "agent_stream"},
36
- user_id=request.user_id
37
- )
38
- yield format_sse(EVENT_METADATA, {"session_id": session_id})
39
-
40
- # Get conversation history
41
- history = conversation_service.get_history(session_id)
42
-
43
- # Convert to messages format
44
- messages = []
45
- for h in history:
46
- messages.append({"role": h["role"], "content": h["content"]})
47
-
48
- # Determine mode
49
- mode = getattr(request, 'mode', 'sales') # Default to sales
50
-
51
- # === STATUS UPDATE ===
52
- if mode == 'feedback':
53
- yield format_sse(EVENT_STATUS, "Đang kiểm tra lịch sử sự kiện của bạn...")
54
- else:
55
- yield format_sse(EVENT_STATUS, "Đang tư vấn...")
56
-
57
- # === CALL AGENT ===
58
- result = await agent_service.chat(
59
- user_message=request.message,
60
- conversation_history=messages,
61
- mode=mode,
62
- user_id=request.user_id
63
- )
64
-
65
- agent_response = result["message"]
66
-
67
- # === STREAM RESPONSE TOKEN BY TOKEN ===
68
- # Simple character-by-character streaming
69
- chunk_size = 5 # Characters per chunk
70
- for i in range(0, len(agent_response), chunk_size):
71
- chunk = agent_response[i:i+chunk_size]
72
- yield format_sse(EVENT_TOKEN, chunk)
73
- # Small delay for smoother streaming
74
- import asyncio
75
- await asyncio.sleep(0.02)
76
-
77
- # === SAVE HISTORY ===
78
- conversation_service.add_message(
79
- session_id=session_id,
80
- role="user",
81
- content=request.message
82
- )
83
- conversation_service.add_message(
84
- session_id=session_id,
85
- role="assistant",
86
- content=agent_response
87
- )
88
-
89
- # === DONE ===
90
- yield format_sse(EVENT_DONE, {
91
- "session_id": session_id,
92
- "timestamp": datetime.utcnow().isoformat(),
93
- "mode": mode,
94
- "tool_calls": len(result.get("tool_calls", []))
95
- })
96
-
97
- except Exception as e:
98
- print(f"⚠️ Agent Stream Error: {e}")
99
- yield format_sse(EVENT_ERROR, str(e))
 
 
1
+ """
2
+ Agent Chat Streaming Endpoint
3
+ SSE-based real-time streaming for Sales & Feedback agents
4
+ """
5
+ from typing import AsyncGenerator
6
+ from stream_utils import format_sse, EVENT_STATUS, EVENT_TOKEN, EVENT_DONE, EVENT_ERROR, EVENT_METADATA
7
+ from datetime import datetime
8
+
9
+
10
+ async def agent_chat_stream(
11
+ request,
12
+ agent_service,
13
+ conversation_service
14
+ ) -> AsyncGenerator[str, None]:
15
+ """
16
+ Stream agent responses in real-time (SSE format)
17
+
18
+ Args:
19
+ request: ChatRequest with message, session_id, mode, user_id
20
+ agent_service: AgentService instance
21
+ conversation_service: ConversationService instance
22
+
23
+ Yields SSE events:
24
+ - status: Processing updates
25
+ - token: Text chunks
26
+ - metadata: Session info
27
+ - done: Completion signal
28
+ - error: Error messages
29
+ """
30
+ try:
31
+ # === SESSION MANAGEMENT ===
32
+ session_id = request.session_id
33
+ if not session_id:
34
+ session_id = conversation_service.create_session(
35
+ metadata={"user_agent": "api", "created_via": "agent_stream"},
36
+ user_id=request.user_id
37
+ )
38
+ yield format_sse(EVENT_METADATA, {"session_id": session_id})
39
+
40
+ # Get conversation history
41
+ history = conversation_service.get_conversation_history(session_id)
42
+
43
+ # Convert to messages format
44
+ messages = []
45
+ for h in history:
46
+ messages.append({"role": h["role"], "content": h["content"]})
47
+
48
+
49
+ # Determine mode
50
+ mode = getattr(request, 'mode', 'sales') # Default to sales
51
+
52
+ # === STATUS UPDATE ===
53
+ if mode == 'feedback':
54
+ yield format_sse(EVENT_STATUS, "Đang kiểm tra lịch sử sự kiện của bạn...")
55
+ else:
56
+ yield format_sse(EVENT_STATUS, "Đang tư vấn...")
57
+
58
+ # === CALL AGENT ===
59
+ result = await agent_service.chat(
60
+ user_message=request.message,
61
+ conversation_history=messages,
62
+ mode=mode,
63
+ user_id=request.user_id
64
+ )
65
+
66
+ agent_response = result["message"]
67
+
68
+ # === STREAM RESPONSE TOKEN BY TOKEN ===
69
+ # Simple character-by-character streaming
70
+ chunk_size = 5 # Characters per chunk
71
+ for i in range(0, len(agent_response), chunk_size):
72
+ chunk = agent_response[i:i+chunk_size]
73
+ yield format_sse(EVENT_TOKEN, chunk)
74
+ # Small delay for smoother streaming
75
+ import asyncio
76
+ await asyncio.sleep(0.02)
77
+
78
+ # === SAVE HISTORY ===
79
+ conversation_service.add_message(
80
+ session_id=session_id,
81
+ role="user",
82
+ content=request.message
83
+ )
84
+ conversation_service.add_message(
85
+ session_id=session_id,
86
+ role="assistant",
87
+ content=agent_response
88
+ )
89
+
90
+ # === DONE ===
91
+ yield format_sse(EVENT_DONE, {
92
+ "session_id": session_id,
93
+ "timestamp": datetime.utcnow().isoformat(),
94
+ "mode": mode,
95
+ "tool_calls": len(result.get("tool_calls", []))
96
+ })
97
+
98
+ except Exception as e:
99
+ print(f"⚠️ Agent Stream Error: {e}")
100
+ yield format_sse(EVENT_ERROR, str(e))