# Spaces:
# Running
# Running
import logging
from typing import Dict, List, Optional, Tuple

from fastapi import WebSocket
# Module-level logger named after this module; used by ConnectionManager below.
logger = logging.getLogger(__name__)
class ConnectionManager:
    """Manages WebSocket connections for masters and slaves.

    Each robot (keyed by ``robot_id``) has at most one master connection and
    any number of slave connections. Every connection is also recorded in
    ``connection_registry`` under the caller-supplied ``connection_id`` so it
    can later be removed by that id.

    Websockets are matched by identity (``is``), never by ``==``.
    """

    def __init__(self):
        # robot_id -> the single master websocket for that robot
        self.master_connections: Dict[str, WebSocket] = {}
        # robot_id -> list of slave websockets, in connection order
        self.slave_connections: Dict[str, List[WebSocket]] = {}
        # connection_id -> (robot_id, websocket), for masters and slaves alike
        self.connection_registry: Dict[str, Tuple[str, WebSocket]] = {}

    @staticmethod
    def _remove_by_identity(sockets: List[WebSocket], websocket: WebSocket) -> bool:
        """Delete the first element of ``sockets`` that *is* ``websocket``.

        Returns True if an element was removed, False if none matched.
        Identity is used instead of ``list.remove`` because Starlette's
        WebSocket (re-exported by FastAPI) is a Mapping over its ASGI scope,
        so ``==`` compares scope contents rather than connection identity.
        """
        for index, candidate in enumerate(sockets):
            if candidate is websocket:
                del sockets[index]
                return True
        return False

    async def connect_master(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Register ``websocket`` as the master for ``robot_id``.

        Only one master per robot is allowed. An existing master is first
        removed from the bookkeeping and its websocket is closed.
        """
        if robot_id in self.master_connections:
            logger.warning(f"Disconnecting existing master for robot {robot_id}")
            await self.disconnect_master_by_robot(robot_id)
        self.master_connections[robot_id] = websocket
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Master {connection_id} connected to robot {robot_id}")

    async def connect_slave(self, connection_id: str, robot_id: str, websocket: WebSocket):
        """Add ``websocket`` to the slave list of ``robot_id``."""
        slaves = self.slave_connections.setdefault(robot_id, [])
        slaves.append(websocket)
        self.connection_registry[connection_id] = (robot_id, websocket)
        logger.info(f"Slave {connection_id} connected to robot {robot_id} ({len(slaves)} total slaves)")

    async def disconnect_master(self, connection_id: str):
        """Forget the master registered under ``connection_id``.

        The websocket itself is not closed. Unknown ids are ignored. The
        robot's master slot is cleared only if it still holds *this*
        websocket, so a stale or non-master ``connection_id`` cannot evict a
        different master.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        if self.master_connections.get(robot_id) is websocket:
            del self.master_connections[robot_id]
        logger.info(f"Master {connection_id} disconnected from robot {robot_id}")

    async def disconnect_master_by_robot(self, robot_id: str):
        """Remove and close the master connected to ``robot_id``, if any.

        Bookkeeping is updated before the websocket is closed, so coroutines
        that run during the ``await`` see a consistent state. Errors raised
        by ``close()`` are logged, not propagated.
        """
        if robot_id not in self.master_connections:
            return
        websocket = self.master_connections.pop(robot_id)
        # Drop the registry entry that points at this exact websocket.
        for conn_id, (r_id, ws) in list(self.connection_registry.items()):
            if r_id == robot_id and ws is websocket:
                del self.connection_registry[conn_id]
                break
        try:
            await websocket.close()
        except Exception as e:
            logger.error(f"Error closing master websocket for robot {robot_id}: {e}")

    async def disconnect_slave(self, connection_id: str):
        """Forget the slave registered under ``connection_id``.

        The websocket itself is not closed. Unknown ids are ignored. The
        robot's slave list is deleted once it becomes empty.
        """
        entry = self.connection_registry.pop(connection_id, None)
        if entry is None:
            return
        robot_id, websocket = entry
        slaves = self.slave_connections.get(robot_id)
        if slaves is not None:
            if self._remove_by_identity(slaves, websocket):
                if not slaves:  # Remove empty list
                    del self.slave_connections[robot_id]
            else:
                logger.warning(f"Slave websocket not found in connections for robot {robot_id}")
        logger.info(f"Slave {connection_id} disconnected from robot {robot_id}")

    def get_master_connection(self, robot_id: str) -> Optional[WebSocket]:
        """Return the master websocket for ``robot_id``, or None if there is none."""
        return self.master_connections.get(robot_id)

    def get_slave_connections(self, robot_id: str) -> List[WebSocket]:
        """Return a snapshot list of the slave websockets for ``robot_id``.

        A copy is returned so callers can iterate it (e.g. while awaiting
        sends) even if slaves connect or disconnect meanwhile.
        """
        return list(self.slave_connections.get(robot_id, ()))

    def get_connection_count(self) -> int:
        """Return the total number of masters plus slaves across all robots."""
        master_count = len(self.master_connections)
        slave_count = sum(len(slaves) for slaves in self.slave_connections.values())
        return master_count + slave_count

    def get_robot_connection_info(self, robot_id: str) -> dict:
        """Summarize the master/slave connections of a single robot."""
        has_master = robot_id in self.master_connections
        slave_count = len(self.slave_connections.get(robot_id, []))
        return {
            "robot_id": robot_id,
            "has_master": has_master,
            "slave_count": slave_count,
            "total_connections": (1 if has_master else 0) + slave_count,
        }

    async def cleanup_robot_connections(self, robot_id: str):
        """Close and forget every connection (master and slaves) for ``robot_id``.

        All bookkeeping for the robot is detached before the first ``await``.
        Disconnect handlers that run while sockets are closing therefore find
        nothing left to remove. They can neither raise ``KeyError`` here nor
        cause slaves to be skipped. Close errors are logged, not propagated.
        """
        master = self.master_connections.pop(robot_id, None)
        slaves = self.slave_connections.pop(robot_id, [])
        stale_ids = [
            conn_id
            for conn_id, (r_id, _) in self.connection_registry.items()
            if r_id == robot_id
        ]
        for conn_id in stale_ids:
            del self.connection_registry[conn_id]

        # Close master connection
        if master is not None:
            try:
                await master.close()
            except Exception as e:
                logger.error(f"Error closing master connection for robot {robot_id}: {e}")
        # Close slave connections
        for websocket in slaves:
            try:
                await websocket.close()
            except Exception as e:
                logger.error(f"Error closing slave connection for robot {robot_id}: {e}")
        logger.info(f"Cleaned up all connections for robot {robot_id}")

    def list_all_connections(self) -> dict:
        """List all active connections for debugging."""
        return {
            "masters": {robot_id: "connected" for robot_id in self.master_connections},
            "slaves": {robot_id: len(slaves) for robot_id, slaves in self.slave_connections.items()},
            "total_connections": self.get_connection_count(),
        }