Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import uuid | |
| from datetime import UTC, datetime | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from typing_extensions import TypedDict | |
| from .models import MessageType, ParticipantRole | |
| logger = logging.getLogger(__name__) | |
class ConnectionMetadata(TypedDict):
    """Bookkeeping record kept for each registered WebSocket participant.

    An entry is created once a participant's join has been accepted and is
    stored in a registry keyed by participant ID; it is deleted again when
    the connection is cleaned up.
    """

    workspace_id: str  # workspace that contains the joined room
    room_id: str  # room the participant joined
    participant_id: str  # same ID that keys the connection registry
    role: ParticipantRole  # role taken from the participant's join message
    connected_at: datetime  # timestamp taken when the connection is registered
    last_activity: datetime  # refreshed each time a message from the participant is handled
    message_count: int  # starts at 0 and is incremented per handled message
| # ============= SIMPLIFIED ROOM SYSTEM ============= | |
class RoboticsRoom:
    """State container for one room: a single producer slot, its consumers, and joint values."""

    def __init__(self, room_id: str, workspace_id: str):
        """Create an empty room identified by *room_id* inside *workspace_id*."""
        # Identity
        self.workspace_id = workspace_id
        self.id = room_id

        # Latest known value per joint name
        self.joints: dict[str, float] = dict()

        # Membership: at most one producer, any number of consumers
        self.consumers: list[str] = list()
        self.producer: str | None = None
| class RoboticsCore: | |
| """Core robotics system - simplified and merged with workspace support""" | |
| def __init__(self): | |
| # Nested structure: workspace_id -> room_id -> RoboticsRoom | |
| self.workspaces: dict[str, dict[str, RoboticsRoom]] = {} | |
| self.connections: dict[str, WebSocket] = {} # participant_id -> websocket | |
| self.connection_metadata: dict[ | |
| str, ConnectionMetadata | |
| ] = {} # participant_id -> metadata | |
| # ============= ROOM MANAGEMENT ============= | |
| def create_room( | |
| self, workspace_id: str | None = None, room_id: str | None = None | |
| ) -> tuple[str, str]: | |
| """Create a new room and return (workspace_id, room_id)""" | |
| workspace_id = workspace_id or str(uuid.uuid4()) | |
| room_id = room_id or str(uuid.uuid4()) | |
| # Initialize workspace if it doesn't exist | |
| if workspace_id not in self.workspaces: | |
| self.workspaces[workspace_id] = {} | |
| room = RoboticsRoom(room_id, workspace_id) | |
| self.workspaces[workspace_id][room_id] = room | |
| logger.info(f"Created room {room_id} in workspace {workspace_id}") | |
| return workspace_id, room_id | |
| def list_rooms(self, workspace_id: str) -> list[dict]: | |
| """List all rooms in a specific workspace""" | |
| if workspace_id not in self.workspaces: | |
| return [] | |
| return [ | |
| { | |
| "id": room.id, | |
| "workspace_id": room.workspace_id, | |
| "participants": { | |
| "producer": room.producer, | |
| "consumers": room.consumers, | |
| "total": len(room.consumers) + (1 if room.producer else 0), | |
| }, | |
| "joints_count": len(room.joints), | |
| } | |
| for room in self.workspaces[workspace_id].values() | |
| ] | |
| def delete_room(self, workspace_id: str, room_id: str) -> bool: | |
| """Delete a room from a workspace""" | |
| if ( | |
| workspace_id not in self.workspaces | |
| or room_id not in self.workspaces[workspace_id] | |
| ): | |
| return False | |
| room = self.workspaces[workspace_id][room_id] | |
| # Disconnect all participants | |
| for consumer_id in room.consumers[:]: | |
| self.leave_room(workspace_id, room_id, consumer_id) | |
| if room.producer: | |
| self.leave_room(workspace_id, room_id, room.producer) | |
| del self.workspaces[workspace_id][room_id] | |
| logger.info(f"Deleted room {room_id} from workspace {workspace_id}") | |
| return True | |
| def get_room_state(self, workspace_id: str, room_id: str) -> dict: | |
| """Get detailed room state""" | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| return {"error": "Room not found"} | |
| return { | |
| "room_id": room_id, | |
| "workspace_id": workspace_id, | |
| "joints": room.joints, | |
| "participants": { | |
| "producer": room.producer, | |
| "consumers": room.consumers, | |
| "total": len(room.consumers) + (1 if room.producer else 0), | |
| }, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| } | |
| def get_room_info(self, workspace_id: str, room_id: str) -> dict: | |
| """Get basic room info""" | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| return {"error": "Room not found"} | |
| return { | |
| "id": room.id, | |
| "workspace_id": room.workspace_id, | |
| "participants": { | |
| "producer": room.producer, | |
| "consumers": room.consumers, | |
| "total": len(room.consumers) + (1 if room.producer else 0), | |
| }, | |
| "joints_count": len(room.joints), | |
| "has_producer": room.producer is not None, | |
| "active_consumers": len(room.consumers), | |
| } | |
| def _get_room(self, workspace_id: str, room_id: str) -> RoboticsRoom | None: | |
| """Get room by workspace and room ID""" | |
| if workspace_id not in self.workspaces: | |
| return None | |
| return self.workspaces[workspace_id].get(room_id) | |
| # ============= PARTICIPANT MANAGEMENT ============= | |
| def join_room( | |
| self, | |
| workspace_id: str, | |
| room_id: str, | |
| participant_id: str, | |
| role: ParticipantRole, | |
| ) -> bool: | |
| """Join room as producer or consumer""" | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| return False | |
| if role == ParticipantRole.PRODUCER: | |
| if room.producer is None: | |
| room.producer = participant_id | |
| logger.info( | |
| f"Producer {participant_id} joined room {room_id} in workspace {workspace_id}" | |
| ) | |
| return True | |
| # Room already has a producer, reject the new one | |
| logger.warning( | |
| f"Producer {participant_id} failed to join room {room_id} - room already has producer {room.producer}" | |
| ) | |
| return False | |
| if role == ParticipantRole.CONSUMER: | |
| if participant_id not in room.consumers: | |
| room.consumers.append(participant_id) | |
| logger.info( | |
| f"Consumer {participant_id} joined room {room_id} in workspace {workspace_id}" | |
| ) | |
| return True | |
| return False | |
| return False | |
| def leave_room(self, workspace_id: str, room_id: str, participant_id: str): | |
| """Leave room""" | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| return | |
| if room.producer == participant_id: | |
| room.producer = None | |
| logger.info( | |
| f"Producer {participant_id} left room {room_id} in workspace {workspace_id}" | |
| ) | |
| elif participant_id in room.consumers: | |
| room.consumers.remove(participant_id) | |
| logger.info( | |
| f"Consumer {participant_id} left room {room_id} in workspace {workspace_id}" | |
| ) | |
| # ============= JOINT CONTROL ============= | |
| def update_joints( | |
| self, workspace_id: str, room_id: str, joint_updates: list[dict] | |
| ) -> list[dict]: | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| msg = f"Room {room_id} not found in workspace {workspace_id}" | |
| raise ValueError(msg) | |
| changed_joints = [] | |
| for joint in joint_updates: | |
| name = joint["name"] | |
| value = joint["value"] | |
| # Only track actual changes | |
| if room.joints.get(name) != value: | |
| room.joints[name] = value | |
| changed_joints.append(joint) | |
| return changed_joints | |
| # ============= WEBSOCKET HANDLING ============= | |
| async def handle_websocket( | |
| self, websocket: WebSocket, workspace_id: str, room_id: str | |
| ): | |
| """Handle WebSocket connection""" | |
| await websocket.accept() | |
| participant_id: str | None = None | |
| role: ParticipantRole | None = None | |
| try: | |
| # Get join message | |
| data = await websocket.receive_text() | |
| join_msg = json.loads(data) | |
| participant_id = join_msg["participant_id"] | |
| role = ParticipantRole(join_msg["role"]) | |
| # Join room | |
| if not self.join_room(workspace_id, room_id, participant_id, role): | |
| await websocket.send_text( | |
| json.dumps({ | |
| "type": MessageType.ERROR.value, | |
| "message": "Cannot join room", | |
| }) | |
| ) | |
| await websocket.close() | |
| return | |
| self.connections[participant_id] = websocket | |
| # Store connection metadata | |
| self.connection_metadata[participant_id] = ConnectionMetadata( | |
| workspace_id=workspace_id, | |
| room_id=room_id, | |
| participant_id=participant_id, | |
| role=role, | |
| connected_at=datetime.now(tz=UTC), | |
| last_activity=datetime.now(tz=UTC), | |
| message_count=0, | |
| ) | |
| # Send current state to consumer | |
| if role == ParticipantRole.CONSUMER: | |
| room = self._get_room(workspace_id, room_id) | |
| if room: | |
| await websocket.send_text( | |
| json.dumps({ | |
| "type": MessageType.STATE_SYNC.value, | |
| "data": room.joints, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| }) | |
| ) | |
| # Send join confirmation | |
| await websocket.send_text( | |
| json.dumps({ | |
| "type": MessageType.JOINED.value, | |
| "room_id": room_id, | |
| "workspace_id": workspace_id, | |
| "role": role.value, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| }) | |
| ) | |
| # Handle messages | |
| async for message in websocket.iter_text(): | |
| try: | |
| msg = json.loads(message) | |
| await self._handle_message( | |
| workspace_id, room_id, participant_id, role, msg | |
| ) | |
| except json.JSONDecodeError: | |
| logger.exception(f"Invalid JSON from {participant_id}") | |
| except Exception: | |
| logger.exception("Message error") | |
| except WebSocketDisconnect: | |
| logger.info(f"WebSocket disconnected: {participant_id}") | |
| except Exception: | |
| logger.exception("WebSocket error") | |
| finally: | |
| # Cleanup | |
| if participant_id: | |
| metadata = self.connection_metadata.get(participant_id) | |
| if metadata: | |
| self.leave_room( | |
| metadata["workspace_id"], metadata["room_id"], participant_id | |
| ) | |
| if participant_id in self.connections: | |
| del self.connections[participant_id] | |
| if participant_id in self.connection_metadata: | |
| del self.connection_metadata[participant_id] | |
| async def _handle_message( | |
| self, | |
| workspace_id: str, | |
| room_id: str, | |
| participant_id: str, | |
| role: ParticipantRole, | |
| message: dict, | |
| ): | |
| """Handle incoming WebSocket message with structured handlers""" | |
| # Update activity tracking | |
| if participant_id in self.connection_metadata: | |
| self.connection_metadata[participant_id]["last_activity"] = datetime.now( | |
| tz=UTC | |
| ) | |
| self.connection_metadata[participant_id]["message_count"] += 1 | |
| try: | |
| msg_type = MessageType(message.get("type")) | |
| except ValueError: | |
| logger.warning( | |
| f"Unknown message type from {participant_id}: {message.get('type')}" | |
| ) | |
| await self._handle_error( | |
| participant_id, f"Unknown message type: {message.get('type')}" | |
| ) | |
| return | |
| # Dispatch to specific handlers | |
| if msg_type == MessageType.JOINT_UPDATE: | |
| await self._handle_joint_update( | |
| workspace_id, room_id, participant_id, role, message | |
| ) | |
| elif msg_type == MessageType.HEARTBEAT: | |
| await self._handle_heartbeat(participant_id) | |
| elif msg_type == MessageType.EMERGENCY_STOP: | |
| await self._handle_emergency_stop( | |
| workspace_id, room_id, participant_id, message | |
| ) | |
| else: | |
| logger.warning(f"Unhandled message type {msg_type} from {participant_id}") | |
| # ============= STRUCTURED MESSAGE HANDLERS ============= | |
| async def _handle_joint_update( | |
| self, | |
| workspace_id: str, | |
| room_id: str, | |
| participant_id: str, | |
| role: ParticipantRole, | |
| message: dict, | |
| ): | |
| """Handle joint update commands from producers""" | |
| if role != ParticipantRole.PRODUCER: | |
| logger.warning( | |
| f"Non-producer {participant_id} attempted to send joint update" | |
| ) | |
| return | |
| joints = message.get("data", []) | |
| if not joints: | |
| logger.warning(f"Empty joint data from producer {participant_id}") | |
| return | |
| try: | |
| changed_joints = self.update_joints(workspace_id, room_id, joints) | |
| if changed_joints: | |
| await self._broadcast_to_consumers( | |
| workspace_id, | |
| room_id, | |
| { | |
| "type": MessageType.JOINT_UPDATE.value, | |
| "data": changed_joints, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| "source": participant_id, | |
| }, | |
| ) | |
| logger.debug( | |
| f"Producer {participant_id} sent {len(changed_joints)} joint updates" | |
| ) | |
| except Exception: | |
| logger.exception(f"Error processing joint update from {participant_id}") | |
| await self._handle_error(participant_id, "Failed to process joint update") | |
| async def _handle_heartbeat(self, participant_id: str): | |
| """Handle heartbeat messages""" | |
| try: | |
| await self._send_to_participant( | |
| participant_id, | |
| { | |
| "type": MessageType.HEARTBEAT_ACK.value, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| }, | |
| ) | |
| logger.debug(f"Heartbeat acknowledged for {participant_id}") | |
| except Exception: | |
| logger.exception(f"Error handling heartbeat from {participant_id}") | |
| async def _handle_emergency_stop( | |
| self, workspace_id: str, room_id: str, participant_id: str, message: dict | |
| ): | |
| """Handle emergency stop messages""" | |
| try: | |
| reason = message.get("reason", f"Emergency stop from {participant_id}") | |
| emergency_message = { | |
| "type": MessageType.EMERGENCY_STOP.value, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| "reason": reason, | |
| "source": participant_id, | |
| } | |
| # Broadcast to all participants in room | |
| await self._broadcast_to_all_participants( | |
| workspace_id, room_id, emergency_message | |
| ) | |
| logger.warning( | |
| f"🚨 Emergency stop triggered by {participant_id} in room {room_id} (workspace {workspace_id})" | |
| ) | |
| except Exception: | |
| logger.exception(f"Error handling emergency stop from {participant_id}") | |
| async def _handle_error(self, participant_id: str, error_message: str): | |
| """Send error message to participant""" | |
| try: | |
| await self._send_to_participant( | |
| participant_id, | |
| { | |
| "type": MessageType.ERROR.value, | |
| "message": error_message, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| }, | |
| ) | |
| except Exception: | |
| logger.exception(f"Error sending error message to {participant_id}") | |
| async def _broadcast_to_consumers( | |
| self, workspace_id: str, room_id: str, message: dict | |
| ): | |
| """Send message to all consumers in room""" | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| return | |
| message_text = json.dumps(message) | |
| failed = [] | |
| for consumer_id in room.consumers: | |
| if consumer_id in self.connections: | |
| try: | |
| await self.connections[consumer_id].send_text(message_text) | |
| except Exception: | |
| logger.exception(f"Error sending message to {consumer_id}") | |
| failed.append(consumer_id) | |
| # Cleanup failed connections | |
| for consumer_id in failed: | |
| room.consumers.remove(consumer_id) | |
| if consumer_id in self.connections: | |
| del self.connections[consumer_id] | |
| if consumer_id in self.connection_metadata: | |
| del self.connection_metadata[consumer_id] | |
| async def _broadcast_to_all_participants( | |
| self, workspace_id: str, room_id: str, message: dict | |
| ): | |
| """Send message to all participants (producer + consumers) in room""" | |
| room = self._get_room(workspace_id, room_id) | |
| if not room: | |
| return | |
| message_text = json.dumps(message) | |
| participants = [] | |
| # Add producer if exists | |
| if room.producer: | |
| participants.append(room.producer) | |
| # Add all consumers | |
| participants.extend(room.consumers) | |
| failed = [] | |
| sent_count = 0 | |
| for participant_id in participants: | |
| if participant_id in self.connections: | |
| try: | |
| await self.connections[participant_id].send_text(message_text) | |
| sent_count += 1 | |
| except Exception: | |
| logger.exception(f"Error sending message to {participant_id}") | |
| failed.append(participant_id) | |
| # Cleanup failed connections | |
| for participant_id in failed: | |
| metadata = self.connection_metadata.get(participant_id) | |
| if metadata: | |
| self.leave_room( | |
| metadata["workspace_id"], metadata["room_id"], participant_id | |
| ) | |
| if participant_id in self.connections: | |
| del self.connections[participant_id] | |
| if participant_id in self.connection_metadata: | |
| del self.connection_metadata[participant_id] | |
| logger.debug( | |
| f"Broadcast message to {sent_count}/{len(participants)} participants in room {room_id}" | |
| ) | |
| async def _send_to_participant(self, participant_id: str, message: dict): | |
| """Send message to specific participant""" | |
| if participant_id in self.connections: | |
| try: | |
| await self.connections[participant_id].send_text(json.dumps(message)) | |
| except Exception: | |
| logger.exception(f"Error sending message to {participant_id}") | |
| if participant_id in self.connections: | |
| del self.connections[participant_id] | |
| # ============= CONNECTION MONITORING ============= | |
| def get_connection_stats(self) -> dict: | |
| """Get connection statistics and metadata""" | |
| stats = { | |
| "total_connections": len(self.connections), | |
| "total_workspaces": len(self.workspaces), | |
| "total_rooms": sum(len(rooms) for rooms in self.workspaces.values()), | |
| "connections_by_role": {"producer": 0, "consumer": 0}, | |
| "connections_by_workspace": {}, | |
| "active_connections": [], | |
| } | |
| # Count by role and collect active connections | |
| for participant_id, metadata in self.connection_metadata.items(): | |
| role = metadata["role"] | |
| workspace_id = metadata["workspace_id"] | |
| room_id = metadata["room_id"] | |
| stats["connections_by_role"][role.value] += 1 | |
| if workspace_id not in stats["connections_by_workspace"]: | |
| stats["connections_by_workspace"][workspace_id] = { | |
| "producer": 0, | |
| "consumer": 0, | |
| "rooms": 0, | |
| } | |
| stats["connections_by_workspace"][workspace_id][role.value] += 1 | |
| # Count unique rooms per workspace | |
| if workspace_id in self.workspaces: | |
| stats["connections_by_workspace"][workspace_id]["rooms"] = len( | |
| self.workspaces[workspace_id] | |
| ) | |
| stats["active_connections"].append({ | |
| "participant_id": participant_id, | |
| "workspace_id": workspace_id, | |
| "room_id": room_id, | |
| "role": role.value, | |
| "connected_at": metadata["connected_at"].isoformat(), | |
| "last_activity": metadata["last_activity"].isoformat(), | |
| "message_count": metadata["message_count"], | |
| }) | |
| return stats | |
| # ============= EXTERNAL API METHODS ============= | |
| async def send_command_to_room( | |
| self, workspace_id: str, room_id: str, joints: list[dict] | |
| ): | |
| changed_joints = self.update_joints(workspace_id, room_id, joints) | |
| if changed_joints: | |
| await self._broadcast_to_consumers( | |
| workspace_id, | |
| room_id, | |
| { | |
| "type": MessageType.JOINT_UPDATE.value, | |
| "data": changed_joints, | |
| "timestamp": datetime.now(tz=UTC).isoformat(), | |
| "source": "api", | |
| }, | |
| ) | |
| return len(changed_joints) | |