Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| import shlex | |
| import subprocess | |
| import tempfile | |
| import warnings | |
| from pathlib import Path | |
| from typing import Optional | |
| import fastapi | |
| import fastapi.middleware.cors | |
| import tyro | |
| import uvicorn | |
| from attr import dataclass | |
| from fastapi import Request | |
| from fastapi.responses import Response | |
| from fam.llm.fast_inference import TTS | |
| from fam.llm.utils import check_audio_file | |
# FastAPI application; the __main__ block attaches middleware and serves it.
app = fastapi.FastAPI()

# Module-level logger used by the request handlers to report failures.
logger = logging.getLogger(__name__)
@dataclass
class ServingConfig:
    """Command-line configuration for the TTS server, parsed via ``tyro.cli``.

    The string literal under each field is its attribute docstring; these are
    kept in that form because tyro presumably surfaces them as ``--help`` text.

    NOTE(review): only ``seed`` and ``port`` are consumed in this file;
    ``huggingface_repo_id`` and ``temperature`` are never passed to ``TTS`` —
    confirm whether they are meant to be wired through.
    """

    huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
    """Hugging Face Hub repository id of the model."""

    temperature: float = 1.0
    """Temperature for sampling applied to both models."""

    seed: int = 1337
    """Random seed for sampling."""

    port: int = 58003
    """Port the HTTP server listens on."""
class _GlobalState:
    """Process-wide state shared by the request handlers.

    The attributes are only declared here; both are assigned in the
    ``__main__`` block before the server starts.
    """

    config: ServingConfig  # parsed command-line configuration
    tts: TTS  # synthesis engine used by text_to_speech


# Singleton holder for the state above.
GlobalState: _GlobalState = _GlobalState()
@dataclass
class TTSRequest:
    """Parameters of one synthesis request, decoded from the ``X-Payload`` header JSON."""

    # Text to synthesise.
    text: str
    # Reference audio location; when None, the request body is used as reference audio.
    speaker_ref_path: Optional[str] = None
    # Forwarded to synthesise() as guidance_scale. Must be None when no speaker
    # reference is available, hence Optional.
    guidance: Optional[float] = 3.0
    # Forwarded to synthesise() as top_p.
    top_p: float = 0.95
    # NOTE(review): accepted from the payload but never forwarded to synthesise()
    # in this file — confirm whether it should be.
    top_k: Optional[int] = None
async def health_check():
    """Liveness probe that always reports the service as up."""
    # NOTE(review): no route registration for this handler is visible in this
    # source — confirm a decorator was not lost.
    healthy = dict(status="ok")
    return healthy
async def text_to_speech(req: Request):
    """Synthesise speech for one request and return it as WAV audio.

    The JSON-encoded ``TTSRequest`` fields are read from the ``X-Payload``
    header. If the payload has no ``speaker_ref_path``, the raw request body is
    converted to WAV with ffmpeg and used as the speaker reference. An empty
    body means "no speaker reference", which requires ``guidance`` to be null.

    Returns:
        A ``Response`` with ``media_type="audio/wav"`` on success. On any
        failure, the exception is logged and a plain-text 500 response is
        returned. The synthesised output file is always deleted afterwards.
    """
    # NOTE(review): no route registration for this handler is visible in this
    # source — confirm a decorator was not lost.
    audiodata = await req.body()
    payload = None
    wav_out_path = None
    try:
        payload = json.loads(req.headers["X-Payload"])
        tts_req = TTSRequest(**payload)

        with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
            if tts_req.speaker_ref_path is None:
                wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
                # An empty body yields None ("no speaker reference"); only
                # validate audio that actually exists.
                if wav_path is not None:
                    check_audio_file(wav_path)
            else:
                # TODO: fix
                # NOTE(review): this client-supplied value is forwarded to
                # synthesise() unchecked; if it is treated as a local path,
                # callers could make the server read arbitrary files — validate.
                wav_path = tts_req.speaker_ref_path

            if wav_path is None:
                warnings.warn("Running without speaker reference")
                # Explicit check rather than `assert`, so it still applies under `python -O`.
                if tts_req.guidance is not None:
                    raise ValueError("guidance must be null when no speaker reference is provided")

            wav_out_path = GlobalState.tts.synthesise(
                text=tts_req.text,
                spk_ref_path=wav_path,
                top_p=tts_req.top_p,
                guidance_scale=tts_req.guidance,
            )

        with open(wav_out_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")
    except Exception:
        logger.exception("Error processing request %s", payload)
        return Response(
            content="Something went wrong. Please try again in a few mins or contact us on Discord",
            status_code=500,
        )
    finally:
        if wav_out_path is not None:
            Path(wav_out_path).unlink(missing_ok=True)
| def _convert_audiodata_to_wav_path(audiodata, wav_tmp): | |
| with tempfile.NamedTemporaryFile() as unknown_format_tmp: | |
| if unknown_format_tmp.write(audiodata) == 0: | |
| return None | |
| unknown_format_tmp.flush() | |
| subprocess.check_output( | |
| # arbitrary 2 minute cutoff | |
| shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}") | |
| ) | |
| return wav_tmp.name | |
if __name__ == "__main__":
    # Raise every logger registered so far, plus the root logger, to INFO.
    # The loop variable must not be named `logger`: at module scope that would
    # rebind the module-level logger that text_to_speech uses to report errors.
    # list(...) snapshots the names, since getLogger() may insert entries.
    for logger_name in list(logging.root.manager.loggerDict):
        logging.getLogger(logger_name).setLevel(logging.INFO)
    logging.root.setLevel(logging.INFO)

    GlobalState.config = tyro.cli(ServingConfig)
    GlobalState.tts = TTS(seed=GlobalState.config.seed)

    # NOTE(review): "*" already admits every origin, so the explicit localhost
    # entries are redundant. Combined with allow_credentials=True this is a very
    # permissive CORS policy — confirm it is intended.
    app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=GlobalState.config.port,
        log_level="info",
    )