mereith's picture
feat: stop chat
a1207af
raw
history blame
7.34 kB
import asyncio
import json
import logging
import os
import uuid
from typing import Any, AsyncIterable, AsyncIterator, Dict, Optional

import aiohttp
logger = logging.getLogger(__name__)
class SSEClient:
    """Async SSE client for streaming chat API requests.

    The endpoint URL is read from the ``API_ENDPOINT`` environment variable
    when the client is constructed.
    """

    def __init__(self):
        self.url = os.getenv("API_ENDPOINT")
        self.headers = {
            'Content-Type': 'application/json',
            'User-Agent': 'HuggingFace-Gradio-Demo'
        }

    async def stream_chat(self, query: str,
                          deep_thinking_mode: bool = False,
                          search_before_planning: bool = False,
                          debug: bool = False,
                          chat_id: Optional[str] = None) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return streaming data with event parsing

        Args:
            query: User query content
            deep_thinking_mode: Whether to enable deep thinking mode, default False
            search_before_planning: Whether to search before planning, default False
            debug: Whether to enable debug mode, default False
            chat_id: Chat ID, will be auto-generated if not provided

        Yields:
            Dict[str, Any]: SSE event data with 'event' and 'data' fields
            ('data' is the raw string payload).

        Raises:
            ValueError: If ``API_ENDPOINT`` was not set when the client was created.
            Exception: "SSE request error: ..." wrapping any request/stream failure
                (including a non-200 status); the original error is chained as
                ``__cause__``.
        """
        if not self.url:
            raise ValueError("API_ENDPOINT environment variable is not set")
        if chat_id is None:
            chat_id = self._generate_chat_id()
        # Build request data
        data = {
            "messages": [{
                "id": chat_id,
                "role": "user",
                "type": "text",
                "content": query
            }],
            "deep_thinking_mode": deep_thinking_mode,
            "search_before_planning": search_before_planning,
            "debug": debug,
            "chatId": chat_id
        }
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=None)  # No timeout limit: streams may run long
        ) as session:
            try:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=data
                ) as response:
                    if response.status != 200:
                        raise Exception(f"Request failed with status code: {response.status}")
                    async for event in self._parse_sse_lines(response.content):
                        yield event
            except asyncio.CancelledError:
                # Let cancellation propagate untouched (it is an Exception
                # subclass on Python < 3.8 and would otherwise be wrapped).
                raise
            except Exception as e:
                raise Exception(f"SSE request error: {str(e)}") from e

    @staticmethod
    def _field_value(line: str, field: str) -> Optional[str]:
        """Return the value of an SSE ``<field>:`` line, or None if ``line`` is another field.

        Per the SSE format, one space after the colon is optional and is not
        part of the value, so ``data:x`` and ``data: x`` both yield ``"x"``.
        """
        prefix = field + ':'
        if not line.startswith(prefix):
            return None
        value = line[len(prefix):]
        return value[1:] if value.startswith(' ') else value

    @staticmethod
    async def _parse_sse_lines(lines: AsyncIterable[bytes]) -> AsyncIterator[Dict[str, Any]]:
        """Convert raw SSE byte lines into ``{'event': ..., 'data': ...}`` dicts.

        - ``event:`` lines set the event type for the following data line.
        - ``data:`` lines are yielded (empty payloads and ``[DONE]`` are skipped)
          with the current event type, defaulting to ``'message'``.
        - Blank lines terminate an event and reset the event type.
        - Comment lines starting with ``:`` (e.g. keep-alive pings) are ignored.
        - Any other non-empty line is yielded verbatim with event type
          ``'data'`` (or the pending event type, if one was set).
        """
        current_event = None
        async for raw_line in lines:
            line = raw_line.decode('utf-8').strip()
            if not line:
                # Empty line indicates end of event, reset current_event
                current_event = None
                continue
            if line.startswith(':'):
                # SSE comment / keep-alive; carries no payload
                continue
            event_name = SSEClient._field_value(line, 'event')
            if event_name is not None:
                current_event = event_name
                continue
            data_content = SSEClient._field_value(line, 'data')
            if data_content is not None:
                if data_content and data_content != '[DONE]':
                    yield {
                        'event': current_event or 'message',
                        'data': data_content
                    }
                # Reset event for next message
                current_event = None
                continue
            # Handle other formats or raw data
            yield {
                'event': current_event or 'data',
                'data': line
            }
            current_event = None

    def _generate_chat_id(self) -> str:
        """Generate a 21-character lowercase-hex chat ID from a random UUID."""
        return uuid.uuid4().hex[:21]

    async def stream_chat_parsed(self, query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return parsed JSON data with event structure

        Args:
            query: User query content
            **kwargs: Other parameters passed to stream_chat

        Yields:
            Dict[str, Any]: Event data with 'event' and 'data' fields, where 'data'
            contains parsed JSON, or the original string if it is not valid JSON.
        """
        async for event_data in self.stream_chat(query, **kwargs):
            try:
                # Try to parse the data field as JSON
                parsed_data = json.loads(event_data['data'])
            except json.JSONDecodeError:
                # If data is not valid JSON, keep original data
                yield event_data
                continue
            except (KeyError, TypeError):
                # If event_data doesn't have expected structure, skip
                continue
            yield {
                'event': event_data['event'],
                'data': parsed_data
            }
# Convenience functions
async def request_sse_stream(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """Stream raw SSE events for ``query`` through a freshly created ``SSEClient``.

    Thin convenience wrapper around ``SSEClient.stream_chat``.

    Args:
        query: User query content.
        **kwargs: Forwarded unchanged to ``SSEClient.stream_chat``.

    Yields:
        Dict[str, Any]: Events with 'event' and 'data' keys ('data' is a string).
    """
    raw_events = SSEClient().stream_chat(query, **kwargs)
    async for raw_event in raw_events:
        yield raw_event
async def request_sse_stream_parsed(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """Stream JSON-decoded SSE events for ``query`` through a fresh ``SSEClient``.

    Thin convenience wrapper around ``SSEClient.stream_chat_parsed``.

    Args:
        query: User query content.
        **kwargs: Forwarded unchanged to ``SSEClient.stream_chat_parsed``.

    Yields:
        Dict[str, Any]: Events with 'event' and 'data' keys.
    """
    parsed_events = SSEClient().stream_chat_parsed(query, **kwargs)
    async for parsed_event in parsed_events:
        yield parsed_event
async def stop_chat(chat_id: str) -> Any:
    """Ask the backend to stop an in-progress chat.

    POSTs ``{"chatId": chat_id}`` to the URL held in the
    ``STOP_CHAT_API_ENDPOINT`` environment variable.

    Args:
        chat_id: ID of the chat to stop.

    Returns:
        The JSON-decoded response body.

    Raises:
        ValueError: If ``STOP_CHAT_API_ENDPOINT`` is unset or empty.
        Exception: If the server responds with a non-200 status.
    """
    url = os.getenv('STOP_CHAT_API_ENDPOINT')
    if not url:
        # Formatting an unset variable would otherwise yield the URL "None".
        raise ValueError("STOP_CHAT_API_ENDPOINT environment variable is not set")
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={"chatId": chat_id}) as response:
            if response.status != 200:
                logger.error("Request failed with status code: %s", response.status)
                raise Exception(f"Request failed with status code: {response.status}")
            return await response.json()
# Example usage
async def main():
    """Demonstrate the parsed event stream by printing every event for a sample query."""
    sample_query = "Hello"
    separator = "-" * 40
    print("=== SSE Event Stream ===")
    async for item in request_sse_stream_parsed(sample_query):
        print(f"Event: {item.get('event', 'unknown')}")
        print(f"Data: {item.get('data', {})}")
        print(separator)


if __name__ == "__main__":
    asyncio.run(main())