|
|
import asyncio |
|
|
import json |
|
|
import os |
|
|
import uuid |
|
|
from typing import AsyncIterator, Dict, Any |
|
|
import aiohttp |
|
|
import logging |
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class SSEClient:
    """Async SSE client for streaming chat API requests.

    The endpoint URL is read from the ``API_ENDPOINT`` environment variable
    when the client is constructed and stored on ``self.url``.
    """

    def __init__(self):
        self.url = os.getenv("API_ENDPOINT")
        self.headers = {
            'Content-Type': 'application/json',
            'User-Agent': 'HuggingFace-Gradio-Demo'
        }

    @staticmethod
    def _parse_sse_line(line: str, current_event):
        """Interpret one decoded, whitespace-stripped line of the stream.

        Args:
            line: The stripped text of a single line.
            current_event: Event name set by a preceding ``event:`` line, or
                None when no name is pending.

        Returns:
            A ``(emitted, next_event)`` tuple. ``emitted`` is the dict to
            yield to the caller, or None when the line produces no output.
            ``next_event`` is the pending event name to carry to the next line.
        """
        if not line:
            # A blank line ends the current event, so drop any pending name.
            return None, None

        if line.startswith(':'):
            # SSE comment (commonly a keep-alive ping): carries no data and
            # must not disturb the pending event name.
            return None, current_event

        field, colon, value = line.partition(':')
        if colon and field in ('event', 'data'):
            # The single space after the colon is optional in the SSE format.
            if value.startswith(' '):
                value = value[1:]

            if field == 'event':
                return None, value

            # Empty payloads and the '[DONE]' sentinel are consumed silently.
            if value and value != '[DONE]':
                return {'event': current_event or 'message', 'data': value}, None
            return None, None

        # Not a recognised SSE field: hand the raw line to the caller as-is.
        return {'event': current_event or 'data', 'data': line}, None

    async def stream_chat(self, query: str,
                          deep_thinking_mode: bool = False,
                          search_before_planning: bool = False,
                          debug: bool = False,
                          chat_id: str = None) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return streaming data with event parsing

        One dict is yielded per non-empty ``data:`` line (the ``[DONE]``
        sentinel is skipped). Lines that are not SSE fields are passed through
        with the raw line as ``data``. Comment lines starting with ``:`` are
        ignored.

        Args:
            query: User query content
            deep_thinking_mode: Whether to enable deep thinking mode, default False
            search_before_planning: Whether to search before planning, default False
            debug: Whether to enable debug mode, default False
            chat_id: Chat ID, will be auto-generated if not provided

        Yields:
            Dict[str, Any]: SSE event data with 'event' and 'data' fields

        Raises:
            Exception: If no endpoint URL is configured, if the server answers
                with a non-200 status, or if the request/stream fails (the
                original error is chained as ``__cause__``).
            asyncio.CancelledError: Re-raised untouched when the task is cancelled.
        """
        if not self.url:
            # Fail early with a clear message instead of an opaque error from
            # the HTTP layer when API_ENDPOINT was never set.
            raise Exception(
                "SSE request error: no endpoint URL configured "
                "(set the API_ENDPOINT environment variable)"
            )

        if chat_id is None:
            chat_id = self._generate_chat_id()

        payload = {
            "messages": [{
                "id": chat_id,
                "role": "user",
                "type": "text",
                "content": query
            }],
            "deep_thinking_mode": deep_thinking_mode,
            "search_before_planning": search_before_planning,
            "debug": debug,
            "chatId": chat_id
        }

        # total=None disables the overall timeout so long-lived streams are
        # not cut off mid-response.
        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=None)
        ) as session:
            try:
                async with session.post(
                    self.url,
                    headers=self.headers,
                    json=payload
                ) as response:
                    if response.status != 200:
                        raise Exception(f"Request failed with status code: {response.status}")

                    current_event = None
                    async for raw_line in response.content:
                        line = raw_line.decode('utf-8').strip()
                        emitted, current_event = self._parse_sse_line(line, current_event)
                        if emitted is not None:
                            yield emitted
            except asyncio.CancelledError:
                # Cancellation must propagate unchanged.
                raise
            except Exception as e:
                raise Exception(f"SSE request error: {str(e)}") from e

    def _generate_chat_id(self) -> str:
        """Generate chat ID: the first 21 hex characters of a random UUID4."""
        return uuid.uuid4().hex[:21]

    async def stream_chat_parsed(self, query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
        """
        Async request to SSE interface and return parsed JSON data with event structure

        Args:
            query: User query content
            **kwargs: Other parameters passed to stream_chat

        Yields:
            Dict[str, Any]: Event data with 'event' and 'data' fields, where
            'data' contains parsed JSON. When 'data' is not valid JSON the
            event is yielded unchanged (data stays a string).
        """
        async for event_data in self.stream_chat(query, **kwargs):
            # Only the parsing step is guarded; the yield happens outside the
            # try so errors raised at the yield point are never misread as
            # parse failures.
            try:
                parsed_data = json.loads(event_data['data'])
                result = {
                    'event': event_data['event'],
                    'data': parsed_data
                }
            except json.JSONDecodeError:
                # Not JSON: pass the raw event through untouched.
                result = event_data
            except (KeyError, TypeError):
                # Malformed event structure: skip it (best-effort stream).
                continue
            yield result
|
|
|
|
|
|
|
|
|
|
|
async def request_sse_stream(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Shortcut for streaming raw SSE events without managing a client.

    Builds a fresh SSEClient and relays everything its ``stream_chat``
    produces.

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to ``SSEClient.stream_chat``

    Yields:
        Dict[str, Any]: Raw event data with 'event' and 'data' fields (data as string)
    """
    raw_stream = SSEClient().stream_chat(query, **kwargs)
    async for raw_event in raw_stream:
        yield raw_event
|
|
|
|
|
|
|
|
async def request_sse_stream_parsed(query: str, **kwargs) -> AsyncIterator[Dict[str, Any]]:
    """
    Shortcut for streaming structured SSE events without managing a client.

    Builds a fresh SSEClient and relays everything its ``stream_chat_parsed``
    produces.

    Args:
        query: User query content
        **kwargs: Forwarded unchanged to ``SSEClient.stream_chat_parsed``

    Yields:
        Dict[str, Any]: Event data with 'event' and 'data' fields
    """
    parsed_stream = SSEClient().stream_chat_parsed(query, **kwargs)
    async for parsed_event in parsed_stream:
        yield parsed_event
|
|
|
|
|
|
|
|
async def stop_chat(chat_id: str):
    """
    POST a stop request for the given chat to the stop-chat endpoint.

    The endpoint URL is read from the ``STOP_CHAT_API_ENDPOINT`` environment
    variable and the request body is the JSON object ``{"chatId": chat_id}``.

    Args:
        chat_id: ID of the chat, sent as the ``chatId`` field

    Returns:
        The response body decoded as JSON.

    Raises:
        ValueError: If ``STOP_CHAT_API_ENDPOINT`` is unset or empty.
        Exception: If the server answers with a status other than 200.
    """
    url = os.getenv('STOP_CHAT_API_ENDPOINT')
    if not url:
        # Formatting a missing variable into a string would give the literal
        # URL "None"; fail early with a clear message instead.
        raise ValueError("STOP_CHAT_API_ENDPOINT environment variable is not set")

    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={"chatId": chat_id}) as response:
            if response.status != 200:
                message = f"Request failed with status code: {response.status}"
                logger.error(message)
                raise Exception(message)
            return await response.json()
|
|
|
|
|
|
|
|
async def main():
    """Demo driver: stream a sample query and print each event as it arrives."""
    print("=== SSE Event Stream ===")

    divider = "-" * 40
    async for item in request_sse_stream_parsed("Hello"):
        # Fall back to placeholders when a field is missing from the event.
        kind = item.get('event', 'unknown')
        payload = item.get('data', {})
        print(f"Event: {kind}")
        print(f"Data: {payload}")
        print(divider)
|
|
|
|
|
|
|
|
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|
|
|