File size: 3,169 Bytes
d79115c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from typing import List, Dict, Any, Optional
import asyncio
import json
from fastapi import FastAPI, WebSocket
from pydantic import BaseModel
import torch

class PeerNetwork:
    def __init__(self, host: str = "localhost", port: int = 8000):
        self.app = FastAPI()
        self.active_peers: Dict[str, WebSocket] = {}
        self.host = host
        self.port = port
        
        # Register WebSocket endpoint
        @self.app.websocket("/ws/{peer_id}")
        async def websocket_endpoint(websocket: WebSocket, peer_id: str):
            await self.connect_peer(websocket, peer_id)
            try:
                while True:
                    data = await websocket.receive_text()
                    await self.broadcast(data, peer_id)
            except Exception:
                await self.disconnect_peer(peer_id)
                
    async def connect_peer(self, websocket: WebSocket, peer_id: str):
        """Connect a new peer to the network"""
        await websocket.accept()
        self.active_peers[peer_id] = websocket
        
    async def disconnect_peer(self, peer_id: str):
        """Remove a peer from the network"""
        if peer_id in self.active_peers:
            await self.active_peers[peer_id].close()
            del self.active_peers[peer_id]
            
    async def broadcast(self, message: str, sender_id: str):
        """Broadcast a message to all peers except the sender"""
        for peer_id, websocket in self.active_peers.items():
            if peer_id != sender_id:
                await websocket.send_text(message)
                
class OpenPeerClient:
    def __init__(self, network_url: str):
        self.network_url = network_url
        self.websocket: Optional[WebSocket] = None
        self.peer_id: Optional[str] = None
        
    async def connect(self, peer_id: str):
        """Connect to the peer network"""
        self.peer_id = peer_id
        self.websocket = await WebSocket.connect(f"{self.network_url}/ws/{peer_id}")
        
    async def send_model_update(self, model_state: Dict[str, torch.Tensor]):
        """Send model state updates to the network"""
        if not self.websocket:
            raise RuntimeError("Not connected to network")
            
        serialized_state = {
            "type": "model_update",
            "peer_id": self.peer_id,
            "state": {k: v.cpu().numpy().tolist() for k, v in model_state.items()}
        }
        await self.websocket.send_text(json.dumps(serialized_state))
        
    async def receive_updates(self):
        """Receive updates from the network"""
        if not self.websocket:
            raise RuntimeError("Not connected to network")
            
        while True:
            data = await self.websocket.receive_text()
            yield json.loads(data)
            
def create_peer_network(host: str = "localhost", port: int = 8000) -> PeerNetwork:
    """Create and start a peer network server"""
    network = PeerNetwork(host, port)
    import uvicorn
    uvicorn.run(network.app, host=host, port=port)
    return network