"""
CouchDB client for distributed coordination.
"""

import uuid
from datetime import datetime, timezone
from typing import Dict, List, Optional, Any

import couchdb

from .config import settings


class CouchDBClient:
    """Client for interacting with CouchDB for distributed coordination.

    The client works with four databases -- ``agents``, ``jobs``,
    ``gradients`` and ``model_state`` -- and creates any that are missing
    when it is constructed. Timestamps are written as naive UTC ISO-8601
    strings (e.g. ``2024-01-01T12:00:00.123456``).
    """

    # Databases this client reads/writes; created on startup if missing.
    REQUIRED_DATABASES = ('agents', 'jobs', 'gradients', 'model_state')

    # Total read-modify-write attempts before a revision conflict is re-raised.
    MAX_UPDATE_ATTEMPTS = 3

    def __init__(self):
        """Connect to ``settings.COUCHDB_URL`` using the configured credentials.

        Performs network I/O: any missing required database is created.
        """
        self.server = couchdb.Server(settings.COUCHDB_URL)
        self.server.resource.credentials = (
            settings.COUCHDB_USER,
            settings.COUCHDB_PASSWORD,
        )
        self._ensure_databases()

    # -- internal helpers ----------------------------------------------------

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Produces the same text as the deprecated
        ``datetime.utcnow().isoformat()``. The ``+00:00`` offset is stripped,
        so new timestamps keep the format of documents written earlier and
        still sort correctly against them as plain strings.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _iter_docs(db):
        """Yield every regular (non-design) document stored in *db*.

        Reads the whole database through ``_all_docs``, so the cost is linear
        in the number of documents.
        """
        for row in db.view('_all_docs', include_docs=True):
            doc = row.doc
            # Design documents hold views/validation code, not data records.
            if doc is None or str(doc.get('_id', '')).startswith('_design/'):
                continue
            yield doc

    def _update_doc(self, db, doc_id: str, mutate) -> bool:
        """Fetch ``doc_id`` from *db*, call ``mutate(doc)`` and save it.

        If another client saves the document between our read and our write,
        CouchDB rejects the stale revision with a conflict. The document is
        then re-read and *mutate* re-applied, up to ``MAX_UPDATE_ATTEMPTS``
        attempts in total.

        Args:
            db: Database handle obtained from ``self.server[...]``.
            doc_id: ``_id`` of the document to modify.
            mutate: Callable that edits the fetched document in place.

        Returns:
            True once the document is saved; False if it does not exist.

        Raises:
            couchdb.http.ResourceConflict: if every attempt conflicted.
        """
        attempts = max(1, self.MAX_UPDATE_ATTEMPTS)
        for attempt in range(attempts):
            try:
                doc = db[doc_id]
                mutate(doc)
                db.save(doc)
            except couchdb.http.ResourceNotFound:
                return False
            except couchdb.http.ResourceConflict:
                if attempt == attempts - 1:
                    raise
            else:
                return True
        return False  # not reached: the last attempt returns or raises

    def _ensure_databases(self):
        """Create each database in ``REQUIRED_DATABASES`` that does not exist.

        Several clients may start at the same time. If another client creates
        a database after our existence check, CouchDB refuses our create with
        HTTP 412. couchdb-python raises that as ``PreconditionFailed``, and it
        is ignored here because the database now exists.
        """
        for db_name in self.REQUIRED_DATABASES:
            if db_name in self.server:
                continue
            try:
                self.server.create(db_name)
            except couchdb.http.PreconditionFailed:
                pass  # created concurrently by another client

    # -- agents ---------------------------------------------------------------

    def register_agent(self, agent_id: str, capabilities: Dict[str, Any]) -> bool:
        """Register an agent in the cluster.

        Args:
            agent_id: Used as the document ``_id`` in the ``agents`` database.
            capabilities: Stored verbatim under ``capabilities``.

        Returns:
            True if a new agent document was saved. False if CouchDB reported
            a conflict (a document with this ``_id`` already exists); the
            existing document is left untouched.
        """
        db = self.server['agents']
        doc = {
            '_id': agent_id,
            'status': 'active',
            'capabilities': capabilities,
            'last_heartbeat': self._utc_timestamp(),
            'current_job': None,
        }
        try:
            db.save(doc)
        except couchdb.http.ResourceConflict:
            return False
        return True

    def update_heartbeat(self, agent_id: str) -> bool:
        """Set ``last_heartbeat`` of *agent_id* to the current UTC time.

        Conflicting concurrent edits to the agent document are retried.

        Returns:
            True if the heartbeat was saved; False if the agent is unknown.
        """
        def touch(doc):
            doc['last_heartbeat'] = self._utc_timestamp()

        return self._update_doc(self.server['agents'], agent_id, touch)

    def get_active_agents(self) -> List[Dict[str, Any]]:
        """Return all agent documents whose ``status`` is ``'active'``.

        Only the stored status is checked; the age of ``last_heartbeat`` is
        not considered.
        """
        return [
            doc for doc in self._iter_docs(self.server['agents'])
            if doc.get('status') == 'active'
        ]

    # -- jobs -----------------------------------------------------------------

    def create_job(self, job_type: str, params: Dict[str, Any]) -> str:
        """Add a new ``pending`` job to the queue.

        Returns:
            The generated job id (a UUID4 string, also the document ``_id``).
        """
        db = self.server['jobs']
        job_id = str(uuid.uuid4())
        db.save({
            '_id': job_id,
            'type': job_type,
            'params': params,
            'status': 'pending',
            'created_at': self._utc_timestamp(),
            'assigned_to': None,
        })
        return job_id

    def claim_job(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Attempt to claim a pending job for *agent_id*.

        Scans the ``jobs`` database for a document whose status is
        ``pending``, marks it ``in_progress``, assigns it to *agent_id* and
        saves it. The save carries the document's current ``_rev``. If another
        agent saved the job first, CouchDB reports a conflict and the next
        pending job is tried instead.

        Returns:
            The claimed job document, or None if no job could be claimed.
        """
        db = self.server['jobs']
        for doc in self._iter_docs(db):
            if doc.get('status') != 'pending':
                continue
            doc['status'] = 'in_progress'
            doc['assigned_to'] = agent_id
            doc['claimed_at'] = self._utc_timestamp()
            try:
                db.save(doc)
            except couchdb.http.ResourceConflict:
                continue  # another agent claimed this job first
            return doc
        return None

    def update_job_status(
        self,
        job_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Update a job's status and optionally store its result.

        Conflicting concurrent edits to the job document are retried.

        Args:
            job_id: ``_id`` of the job document.
            status: New value for ``status``.
            result: Stored under ``result`` unless it is None. An empty dict
                is a valid result and is stored as well.

        Returns:
            True if the job was updated; False if it does not exist.
        """
        def apply(doc):
            doc['status'] = status
            if result is not None:
                doc['result'] = result
            doc['updated_at'] = self._utc_timestamp()

        return self._update_doc(self.server['jobs'], job_id, apply)

    # -- training artefacts -------------------------------------------------

    def store_gradients(self, job_id: str, gradients: Dict[str, Any]) -> str:
        """Store gradients computed for *job_id*; return the new document id."""
        db = self.server['gradients']
        gradient_id = str(uuid.uuid4())
        db.save({
            '_id': gradient_id,
            'job_id': job_id,
            'gradients': gradients,
            'timestamp': self._utc_timestamp(),
        })
        return gradient_id

    def store_model_state(self, state: Dict[str, Any]) -> str:
        """Store a snapshot of the model state; return the new document id."""
        db = self.server['model_state']
        state_id = str(uuid.uuid4())
        db.save({
            '_id': state_id,
            'state': state,
            'timestamp': self._utc_timestamp(),
        })
        return state_id

    def get_latest_model_state(self) -> Optional[Dict[str, Any]]:
        """Return the most recently stored model-state document, or None.

        State documents have random UUID4 ids, so ``_all_docs`` order (by
        ``_id``) says nothing about their age. The document with the greatest
        ``timestamp`` is chosen instead. The naive ISO-8601 strings written by
        this class sort chronologically when compared as plain strings.
        Documents without a ``timestamp`` are ignored.

        NOTE(review): this reads the whole ``model_state`` database; if it
        grows large, consider a view keyed on ``timestamp``.
        """
        db = self.server['model_state']
        return max(
            (doc for doc in self._iter_docs(db) if doc.get('timestamp')),
            key=lambda doc: doc['timestamp'],
            default=None,
        )