|
|
from typing import List, Dict, Any, Optional
|
|
|
import asyncio
|
|
|
import json
|
|
|
from fastapi import FastAPI, WebSocket
|
|
|
from pydantic import BaseModel
|
|
|
import torch
|
|
|
|
|
|
class PeerNetwork:
    """WebSocket relay that forwards each peer's text messages to all other peers.

    Peers connect at ``/ws/{peer_id}``. Every text frame a peer sends is
    re-sent verbatim to every other peer currently registered in
    ``active_peers``.
    """

    def __init__(self, host: str = "localhost", port: int = 8000):
        """Build the FastAPI app and register the ``/ws/{peer_id}`` endpoint.

        Args:
            host: Stored on the instance; this class does not use it itself.
            port: Stored on the instance; this class does not use it itself.
        """
        self.app = FastAPI()
        # peer_id -> server-side socket currently registered under that id.
        self.active_peers: Dict[str, WebSocket] = {}
        self.host = host
        self.port = port

        @self.app.websocket("/ws/{peer_id}")
        async def websocket_endpoint(websocket: WebSocket, peer_id: str):
            await self.connect_peer(websocket, peer_id)
            try:
                while True:
                    data = await websocket.receive_text()
                    await self.broadcast(data, peer_id)
            except Exception:
                # Any receive failure -- including an ordinary client
                # disconnect -- ends this peer's session; cleanup is below.
                pass
            finally:
                # `finally` so cleanup also runs on task cancellation
                # (CancelledError is not an Exception subclass). Only remove
                # the entry if it still refers to *this* socket: a newer
                # connection that reused the same peer_id must not be dropped.
                if self.active_peers.get(peer_id) is websocket:
                    await self.disconnect_peer(peer_id)

    async def connect_peer(self, websocket: WebSocket, peer_id: str):
        """Accept ``websocket`` and register it under ``peer_id``.

        If ``peer_id`` is already registered, the previous socket is replaced
        in the registry (it is not closed here).
        """
        await websocket.accept()
        self.active_peers[peer_id] = websocket

    async def disconnect_peer(self, peer_id: str):
        """Remove ``peer_id`` from the network and close its socket.

        Unknown ids are ignored. The entry is removed *before* closing so that
        a failing ``close()`` cannot leave a stale socket registered.
        """
        websocket = self.active_peers.pop(peer_id, None)
        if websocket is None:
            return
        try:
            await websocket.close()
        except Exception:
            # Best-effort: close() may fail if the connection is already gone;
            # the peer has been unregistered either way.
            pass

    async def broadcast(self, message: str, sender_id: str):
        """Send ``message`` to every registered peer except ``sender_id``.

        A recipient whose send fails is unregistered instead of aborting
        delivery to the remaining peers (or disconnecting the sender).
        """
        # Snapshot the registry: the awaits below yield to other tasks that may
        # connect or disconnect peers, and mutating a dict mid-iteration raises.
        recipients = [
            (peer_id, websocket)
            for peer_id, websocket in self.active_peers.items()
            if peer_id != sender_id
        ]
        for peer_id, websocket in recipients:
            try:
                await websocket.send_text(message)
            except Exception:
                # Drop the unreachable recipient, but only if it is still the
                # socket registered under that id.
                if self.active_peers.get(peer_id) is websocket:
                    await self.disconnect_peer(peer_id)
|
|
|
|
|
|
class OpenPeerClient:
    """Client-side helper for exchanging model updates with a peer network.

    Messages are JSON objects exchanged as text frames over the connection
    stored in ``self.websocket`` (only its ``send_text`` / ``receive_text``
    coroutines are used).
    """

    def __init__(self, network_url: str):
        """Store the base URL; no connection is opened until :meth:`connect`.

        Args:
            network_url: Base URL that ``/ws/{peer_id}`` is appended to
                (presumably ``ws://`` / ``wss://`` -- confirm with callers).
        """
        self.network_url = network_url
        self.websocket: Optional[WebSocket] = None
        self.peer_id: Optional[str] = None

    async def connect(self, peer_id: str):
        """Connect to the network as ``peer_id``.

        ``websocket`` and ``peer_id`` are recorded only after the connection
        call returns, so a failed attempt leaves the client's state unchanged.
        """
        url = f"{self.network_url}/ws/{peer_id}"
        # NOTE(review): fastapi's ``WebSocket`` is Starlette's server-side
        # connection object and does not appear to provide a ``connect``
        # classmethod, so this call is expected to raise AttributeError. A real
        # client library is needed here; clients such as the ``websockets``
        # package expose ``send``/``recv`` rather than ``send_text``/
        # ``receive_text``, so an adapter would also be required. TODO fix.
        websocket = await WebSocket.connect(url)
        self.websocket = websocket
        self.peer_id = peer_id

    async def send_model_update(self, model_state: Dict[str, torch.Tensor]):
        """Serialize ``model_state`` to JSON and send it as a ``model_update``.

        Each tensor is converted to (nested) Python lists / scalars.

        Raises:
            RuntimeError: if no connection has been established.
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to network")

        serialized_state = {
            "type": "model_update",
            "peer_id": self.peer_id,
            # detach(): ``.numpy()`` refuses tensors that require grad, and
            # ``Tensor.tolist()`` also handles dtypes NumPy lacks (bfloat16).
            "state": {
                name: tensor.detach().cpu().tolist()
                for name, tensor in model_state.items()
            },
        }
        await self.websocket.send_text(json.dumps(serialized_state))

    async def receive_updates(self):
        """Yield each incoming text frame decoded as JSON, indefinitely.

        Raises:
            RuntimeError: if no connection has been established (raised on the
                first iteration, since this is an async generator).
        """
        if self.websocket is None:
            raise RuntimeError("Not connected to network")
        while True:
            yield json.loads(await self.websocket.receive_text())
|
|
|
|
|
|
def create_peer_network(host: str = "localhost", port: int = 8000) -> PeerNetwork:
    """Build a ``PeerNetwork`` and serve its FastAPI app with uvicorn.

    Args:
        host: Interface passed to uvicorn (and stored on the network).
        port: Port passed to uvicorn (and stored on the network).

    Returns:
        The ``PeerNetwork`` whose app was served.

    NOTE(review): ``uvicorn.run`` is expected to block while the server runs,
    so the network is returned only after serving stops -- confirm before
    relying on the return value for a live server.
    """
    server = PeerNetwork(host, port)
    # Deferred import: uvicorn is only needed when actually serving.
    from uvicorn import run as serve_app

    serve_app(server.app, host=host, port=port)
    return server