Spaces:
Running
Running
| """ | |
| Style-Bert-VITS2-Editor用のサーバー。 | |
| 次のリポジトリ | |
| https://github.com/litagin02/Style-Bert-VITS2-Editor | |
| をビルドしてできあがったファイルをWebフォルダに入れて実行する。 | |
| TODO: リファクタリングやドキュメンテーションやAPI整理、辞書周りの改善などが必要。 | |
| """ | |
| import argparse | |
| import io | |
| import shutil | |
| import sys | |
| import webbrowser | |
| import zipfile | |
| from datetime import datetime | |
| from io import BytesIO | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import requests | |
| import torch | |
| import uvicorn | |
| from fastapi import APIRouter, FastAPI, HTTPException, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, Response | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from scipy.io import wavfile | |
| from config import get_path_config | |
| from style_bert_vits2.constants import ( | |
| DEFAULT_ASSIST_TEXT_WEIGHT, | |
| DEFAULT_NOISE, | |
| DEFAULT_NOISEW, | |
| DEFAULT_SDP_RATIO, | |
| DEFAULT_STYLE, | |
| DEFAULT_STYLE_WEIGHT, | |
| VERSION, | |
| Languages, | |
| ) | |
| from style_bert_vits2.logging import logger | |
| from style_bert_vits2.nlp import bert_models | |
| from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk | |
| from style_bert_vits2.nlp.japanese.g2p_utils import g2kata_tone, kata_tone2phone_tone | |
| from style_bert_vits2.nlp.japanese.normalizer import normalize_text | |
| from style_bert_vits2.nlp.japanese.user_dict import ( | |
| apply_word, | |
| delete_word, | |
| read_dict, | |
| rewrite_word, | |
| update_dict, | |
| ) | |
| from style_bert_vits2.tts_model import TTSModelHolder, TTSModelInfo | |
# --- Frontend (editor build files) handling ---
# Directory that holds the editor's build files.
STATIC_DIR = Path("static")
# File storing the published_at timestamp of the most recently downloaded
# editor build.
LAST_DOWNLOAD_FILE = STATIC_DIR.joinpath("last_download.txt")
def download_static_files(user, repo, asset_name):
    """Download and unpack the latest Style-Bert-VITS2 Editor build zip.

    Looks up the latest GitHub release of ``user/repo``. If it was published
    after the last recorded download, the release asset named ``asset_name``
    is downloaded and its contents replace ``STATIC_DIR``.

    Download and extraction failures are logged as warnings and the function
    returns normally so the server can still start. The archive is extracted
    into a staging directory and only swapped in once extraction has
    finished, so a failed download leaves the previously installed files in
    place instead of deleting them first.
    """
    import tempfile

    logger.info("Checking for new release...")
    latest_release = get_latest_release(user, repo)
    if latest_release is None:
        logger.warning(
            "Failed to fetch the latest release. Proceeding without static files."
        )
        return
    if not new_release_available(latest_release):
        logger.info("No new release available. Proceeding with existing static files.")
        return

    logger.info("New release available. Downloading static files...")
    asset_url = get_asset_url(latest_release, asset_name)
    if not asset_url:
        logger.warning("Asset not found. Proceeding without static files.")
        return

    # Stage next to STATIC_DIR so the final rename stays on one filesystem.
    staging_dir = Path(
        tempfile.mkdtemp(prefix=".static_download_", dir=STATIC_DIR.parent)
    )
    try:
        download_and_extract(asset_url, staging_dir)
    except (requests.RequestException, zipfile.BadZipFile, OSError) as e:
        shutil.rmtree(staging_dir, ignore_errors=True)
        logger.warning(
            f"Failed to download static files: {e}. "
            "Proceeding with existing static files."
        )
        return

    # Only now that the new build is fully extracted, replace the old one.
    if STATIC_DIR.exists():
        shutil.rmtree(STATIC_DIR)
    staging_dir.rename(STATIC_DIR)
    save_last_download(latest_release)
def get_latest_release(user, repo, *, timeout=10.0):
    """Return the JSON description of the latest GitHub release of ``user/repo``.

    Args:
        user: GitHub account that owns the repository.
        repo: Repository name.
        timeout: Seconds to wait for the GitHub API before giving up, so a
            stalled connection cannot block server startup indefinitely.

    Returns:
        The decoded release JSON, or ``None`` when the request fails (network
        error, timeout, non-2xx status, or a body that is not valid JSON).
    """
    url = f"https://api.github.com/repos/{user}/{repo}/releases/latest"
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decode errors on older requests versions,
        # where they are not RequestException subclasses.
        return None
def get_asset_url(release, asset_name):
    """Return the download URL of the asset called ``asset_name`` in ``release``.

    ``release`` is a GitHub release JSON object; the first asset whose
    ``name`` equals ``asset_name`` wins. Returns ``None`` when no asset
    matches.
    """
    matching_urls = (
        entry["browser_download_url"]
        for entry in release["assets"]
        if entry["name"] == asset_name
    )
    return next(matching_urls, None)
def download_and_extract(url, extract_to: Path, timeout: float = 60.0):
    """Download the zip archive at ``url`` and extract it into ``extract_to``.

    If the extracted contents consist of exactly one top-level directory,
    that directory is unwrapped so its entries sit directly in
    ``extract_to``. A warning is logged when no ``index.html`` ends up at the
    top level.

    Args:
        url: Direct download URL of the zip archive.
        extract_to: Existing directory to extract into.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.RequestException: The download failed, timed out, or
            returned a non-2xx status.
        zipfile.BadZipFile: The downloaded body is not a valid zip archive.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(extract_to)

    # Unwrap a single top-level directory into extract_to itself.
    entries = list(extract_to.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        wrapper = entries[0]
        # Materialize the listing first: renaming entries out of a directory
        # while iterating it lazily may skip entries.
        for child in list(wrapper.iterdir()):
            child.rename(extract_to / child.name)
        wrapper.rmdir()

    # The editor is served from index.html, so warn if it is missing.
    if not (extract_to / "index.html").exists():
        logger.warning("index.html not found in the extracted files.")
def _parse_release_timestamp(stamp: str) -> datetime:
    """Parse a GitHub ISO-8601 timestamp such as ``2024-01-01T00:00:00Z``."""
    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on.
    return datetime.fromisoformat(stamp.replace("Z", "+00:00"))


def new_release_available(latest_release, record_file: Optional[Path] = None):
    """Return True if ``latest_release`` is newer than the last recorded download.

    Args:
        latest_release: GitHub release JSON; its ``published_at`` timestamp is
            compared against the recorded one.
        record_file: File holding the ``published_at`` value of the last
            download. Defaults to ``LAST_DOWNLOAD_FILE``.

    Returns:
        True when there is no record, the record cannot be parsed or compared,
        or the release was published strictly after the recorded timestamp;
        False otherwise.
    """
    path = LAST_DOWNLOAD_FILE if record_file is None else record_file
    if not path.exists():
        return True
    try:
        last_download = _parse_release_timestamp(path.read_text().strip())
        published_at = _parse_release_timestamp(latest_release["published_at"])
        return published_at > last_download
    except (ValueError, TypeError):
        # A corrupt or timezone-less record must not abort server startup;
        # treat it as "no valid record" and download again.
        return True
def save_last_download(latest_release):
    """Record the ``published_at`` timestamp of the release just installed.

    The value is written verbatim to ``LAST_DOWNLOAD_FILE``, overwriting any
    previous record; ``new_release_available`` reads it back later.
    """
    with LAST_DOWNLOAD_FILE.open("w") as record:
        record.write(latest_release["published_at"])
# --- End of frontend (editor build files) handling ---
# Everything below configures the API.
# Start the pyopenjtalk_worker.
## pyopenjtalk_worker is a TCP socket server, so it is started here.
pyopenjtalk.initialize_worker()
# Update the pyopenjtalk dictionary.
update_dict()
# Load the BERT model/tokenizer ahead of time.
## They would be loaded automatically when first needed, but that takes a while, so preloading gives a better experience.
## server_editor.py only supports Japanese, so only the Japanese BERT model/tokenizer are loaded.
bert_models.load_model(Languages.JP)
bert_models.load_tokenizer(Languages.JP)
class AudioResponse(Response):
    """Response class that declares WAV audio ("audio/wav") as its media type."""

    media_type = "audio/wav"
# Origins allowed to call the API cross-origin: ports 3000 and 8000 (8000 is
# the --port default) reached through either localhost or 127.0.0.1.
# NOTE(review): port 3000 is presumably a frontend dev server — confirm.
origins = [
    f"http://{host}:{port_number}"
    for host in ("localhost", "127.0.0.1")
    for port_number in (3000, 8000)
]
# Runtime configuration: command-line options and the TTS model holder.
path_config = get_path_config()
parser = argparse.ArgumentParser(description="Server for Style-Bert-VITS2-Editor.")
parser.add_argument(
    "--model_dir",
    type=str,
    default=path_config.assets_root,
    help="Directory containing the TTS models.",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    help="Inference device; 'cuda' falls back to 'cpu' when CUDA is unavailable.",
)
parser.add_argument("--port", type=int, default=8000, help="Port to listen on.")
parser.add_argument(
    "--inbrowser",
    action="store_true",
    help="Open the editor in a web browser on startup.",
)
parser.add_argument(
    "--line_length",
    type=int,
    default=None,
    help="Maximum number of characters per line (unlimited if omitted).",
)
parser.add_argument(
    "--line_count",
    type=int,
    default=None,
    help="Maximum number of lines per multi-line request (unlimited if omitted).",
)
parser.add_argument(
    "--skip_static_files",
    action="store_true",
    help="Do not check for or download the latest editor build.",
)
args = parser.parse_args()

device = args.device
if device == "cuda" and not torch.cuda.is_available():
    # Make the fallback visible; CPU inference is much slower.
    logger.warning("CUDA is not available. Falling back to CPU.")
    device = "cpu"
model_dir = Path(args.model_dir)
# argparse already yields an int / bool for these options.
port = args.port
skip_static_files = args.skip_static_files

model_holder = TTSModelHolder(model_dir, device)
if len(model_holder.model_names) == 0:
    logger.error(f"Models not found in {model_dir}.")
    sys.exit(1)
app = FastAPI()
# Allow the listed origins to call the API cross-origin, with credentials and
# any method or header.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=origins,
    allow_credentials=True,
)
# API endpoints are registered on this router, which is included into the app
# under the "/api" prefix further down.
router = APIRouter()
@router.get("/version")
def version() -> str:
    """Return the Style-Bert-VITS2 version string (``VERSION``)."""
    return VERSION
class MoraTone(BaseModel):
    # One mora and its tone: produced from g2kata_tone output in read_item and
    # consumed by kata_tone2phone_tone during synthesis.
    mora: str  # mora as katakana (see the error message in read_item)
    tone: int
class TextRequest(BaseModel):
    # Request body carrying a single piece of input text.
    text: str
@router.post("/g2p")
async def read_item(item: TextRequest):
    """Convert text into a list of mora/tone pairs.

    The text is normalized with ``normalize_text`` before being passed to
    ``g2kata_tone``.

    Raises:
        HTTPException: 400 if normalization or the conversion fails.
    """
    try:
        # Normalize first; otherwise the result is not consistent.
        text = normalize_text(item.text)
        kata_tone_list = g2kata_tone(text)
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to convert {item.text} to katakana and tone, {e}",
        ) from e
    return [MoraTone(mora=kata, tone=tone) for kata, tone in kata_tone_list]
@router.post("/normalize")
async def normalize(item: TextRequest):
    """Return ``item.text`` after applying ``normalize_text``."""
    return normalize_text(item.text)
@router.get("/models_info", response_model=list[TTSModelInfo])
def models_info():
    """Return ``model_holder.models_info`` for the models found in ``model_dir``."""
    return model_holder.models_info
class SynthesisRequest(BaseModel):
    # Request body for one synthesized line: used directly by synthesis and as
    # an element of MultiSynthesisRequest.lines.
    # Model name, passed to model_holder.get_model as model_name.
    model: str
    # Model file path, passed to model_holder.get_model as model_path_str.
    modelFile: str
    # Text to synthesize; its length is checked against --line_length.
    text: str
    # Mora/tone pairs; converted with kata_tone2phone_tone and the resulting
    # tones are passed to model.infer as given_tone.
    moraToneList: list[MoraTone]
    # Style name and weight (model.infer style / style_weight).
    style: str = DEFAULT_STYLE
    styleWeight: float = DEFAULT_STYLE_WEIGHT
    # Assist text and its weight; use_assist_text is enabled only when the
    # assist text is non-empty.
    assistText: str = ""
    assistTextWeight: float = DEFAULT_ASSIST_TEXT_WEIGHT
    # Speaking speed; model.infer receives length = 1 / speed, so it must be
    # non-zero.
    speed: float = 1.0
    # Passed to model.infer as noise / noise_w / sdp_ratio / language.
    noise: float = DEFAULT_NOISE
    noisew: float = DEFAULT_NOISEW
    sdpRatio: float = DEFAULT_SDP_RATIO
    language: Languages = Languages.JP
    # Seconds of silence inserted after this line by multi_synthesis (not
    # after the last line).
    silenceAfter: float = 0.5
    # Passed to model.infer as pitch_scale / intonation_scale.
    pitchScale: float = 1.0
    intonationScale: float = 1.0
    # Speaker name looked up in model.spk2id (in synthesis, None selects
    # speaker id 0).
    speaker: Optional[str] = None
@router.post("/synthesis", response_class=AudioResponse)
def synthesis(request: SynthesisRequest):
    """Synthesize a single line of text and return it as a WAV file.

    The tones from ``request.moraToneList`` are converted with
    ``kata_tone2phone_tone`` and passed to ``model.infer`` as ``given_tone``.

    Raises:
        HTTPException: 400 if the text exceeds ``--line_length``, ``speed``
            is not positive, or the speaker is unknown; 500 if the model
            cannot be loaded.
    """
    if args.line_length is not None and len(request.text) > args.line_length:
        raise HTTPException(
            status_code=400,
            detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
        )
    # model.infer receives length = 1 / speed: reject 0 (ZeroDivisionError),
    # negatives, and NaN (`not x > 0` is True for NaN).
    if not request.speed > 0:
        raise HTTPException(
            status_code=400,
            detail=f"speed must be positive, got {request.speed}",
        )
    try:
        model = model_holder.get_model(
            model_name=request.model, model_path_str=request.modelFile
        )
    except Exception as e:
        logger.error(e)
        raise HTTPException(
            status_code=500,
            detail=f"Failed to load model {request.model} from {request.modelFile}, {e}",
        ) from e

    kata_tone_list = [
        (mora_tone.mora, mora_tone.tone) for mora_tone in request.moraToneList
    ]
    phone_tone = kata_tone2phone_tone(kata_tone_list)
    tone = [t for _, t in phone_tone]

    try:
        sid = 0 if request.speaker is None else model.spk2id[request.speaker]
    except KeyError:
        raise HTTPException(
            status_code=400,
            detail=f"Speaker {request.speaker} not found in {model.spk2id}",
        ) from None

    sr, audio = model.infer(
        text=request.text,
        language=request.language,
        sdp_ratio=request.sdpRatio,
        noise=request.noise,
        noise_w=request.noisew,
        length=1 / request.speed,
        given_tone=tone,
        style=request.style,
        style_weight=request.styleWeight,
        assist_text=request.assistText,
        assist_text_weight=request.assistTextWeight,
        use_assist_text=bool(request.assistText),
        line_split=False,
        pitch_scale=request.pitchScale,
        intonation_scale=request.intonationScale,
        speaker_id=sid,
    )
    with BytesIO() as wav_content:
        wavfile.write(wav_content, sr, audio)
        return Response(content=wav_content.getvalue(), media_type="audio/wav")
class MultiSynthesisRequest(BaseModel):
    # Lines synthesized one by one and concatenated in this order by
    # multi_synthesis; the count is checked against --line_count.
    lines: list[SynthesisRequest]
@router.post("/multi_synthesis", response_class=AudioResponse)
def multi_synthesis(request: MultiSynthesisRequest):
    """Synthesize several lines and return them concatenated as one WAV file.

    Each line is synthesized with its own model and parameters, including its
    ``speaker``. ``silenceAfter`` seconds of silence are inserted after every
    line except the last. Request limits are checked before any inference
    runs, so an invalid late line does not waste work on the earlier ones.

    Raises:
        HTTPException: 400 if there are no lines or more than ``--line_count``,
            a line exceeds ``--line_length``, a ``speed`` is not positive, a
            speaker is unknown, or lines produce different sample rates;
            500 if a model cannot be loaded.
    """
    lines = request.lines
    if args.line_count is not None and len(lines) > args.line_count:
        raise HTTPException(
            status_code=400,
            detail=f"行数は{args.line_count}行以下にしてください。",
        )
    if not lines:
        # np.concatenate([]) below would otherwise fail with a 500.
        raise HTTPException(status_code=400, detail="No lines to synthesize.")
    for req in lines:
        if args.line_length is not None and len(req.text) > args.line_length:
            raise HTTPException(
                status_code=400,
                detail=f"1行の文字数は{args.line_length}文字以下にしてください。",
            )
        # model.infer receives length = 1 / speed: reject 0, negatives and NaN.
        if not req.speed > 0:
            raise HTTPException(
                status_code=400,
                detail=f"speed must be positive, got {req.speed}",
            )

    audios = []
    sr = None
    for i, req in enumerate(lines):
        try:
            model = model_holder.get_model(
                model_name=req.model, model_path_str=req.modelFile
            )
        except Exception as e:
            logger.error(e)
            raise HTTPException(
                status_code=500,
                detail=f"Failed to load model {req.model} from {req.modelFile}, {e}",
            ) from e
        kata_tone_list = [
            (mora_tone.mora, mora_tone.tone) for mora_tone in req.moraToneList
        ]
        phone_tone = kata_tone2phone_tone(kata_tone_list)
        tone = [t for _, t in phone_tone]
        # Honor the per-line speaker the same way synthesis() does.
        try:
            sid = 0 if req.speaker is None else model.spk2id[req.speaker]
        except KeyError:
            raise HTTPException(
                status_code=400,
                detail=f"Speaker {req.speaker} not found in {model.spk2id}",
            ) from None
        line_sr, audio = model.infer(
            text=req.text,
            language=req.language,
            sdp_ratio=req.sdpRatio,
            noise=req.noise,
            noise_w=req.noisew,
            length=1 / req.speed,
            given_tone=tone,
            style=req.style,
            style_weight=req.styleWeight,
            assist_text=req.assistText,
            assist_text_weight=req.assistTextWeight,
            use_assist_text=bool(req.assistText),
            line_split=False,
            pitch_scale=req.pitchScale,
            intonation_scale=req.intonationScale,
            speaker_id=sid,
        )
        if sr is None:
            sr = line_sr
        elif line_sr != sr:
            # One WAV header carries one sample rate; mixing would distort pitch.
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Sample rate mismatch: line {i + 1} is {line_sr} Hz, "
                    f"earlier lines are {sr} Hz."
                ),
            )
        audios.append(audio)
        if i < len(lines) - 1:
            # Negative silenceAfter would make np.zeros raise; clamp to 0.
            silence = max(0, int(sr * req.silenceAfter))
            audios.append(np.zeros(silence, dtype=audio.dtype))
    audio = np.concatenate(audios)
    with BytesIO() as wav_content:
        wavfile.write(wav_content, sr, audio)
        return Response(content=wav_content.getvalue(), media_type="audio/wav")
class UserDictWordRequest(BaseModel):
    # Request body for adding or rewriting a user dictionary word; the fields
    # are passed to apply_word / rewrite_word.
    surface: str  # written form of the word
    pronunciation: str  # reading (presumably katakana — confirm)
    accent_type: int  # accent nucleus position (0 if there is none; 1 means the first character)
    priority: int = 5
@router.get("/user_dict")
def get_user_dict():
    """Return the current user dictionary as loaded by ``read_dict``."""
    return read_dict()
@router.post("/user_dict_word")
def add_user_dict_word(request: UserDictWordRequest):
    """Add a word to the user dictionary.

    Returns:
        201 response with ``{"uuid": ...}`` holding the id from ``apply_word``.
    """
    uuid = apply_word(
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    # Refresh the pyopenjtalk dictionary after the change.
    update_dict()
    return JSONResponse(
        status_code=status.HTTP_201_CREATED,
        content={"uuid": uuid},
    )
@router.put("/user_dict_word/{uuid}")
def update_user_dict_word(uuid: str, request: UserDictWordRequest):
    """Rewrite the user dictionary word identified by ``uuid``.

    Returns:
        200 response with ``{"uuid": uuid}``.
    """
    rewrite_word(
        word_uuid=uuid,
        surface=request.surface,
        pronunciation=request.pronunciation,
        accent_type=request.accent_type,
        priority=request.priority,
    )
    # Refresh the pyopenjtalk dictionary after the change.
    update_dict()
    return JSONResponse(status_code=status.HTTP_200_OK, content={"uuid": uuid})
@router.delete("/user_dict_word/{uuid}")
def delete_user_dict_word(uuid: str):
    """Delete the user dictionary word identified by ``uuid``.

    Returns:
        200 response with ``{"uuid": uuid}``.
    """
    delete_word(uuid)
    # Refresh the pyopenjtalk dictionary after the change.
    update_dict()
    return JSONResponse(status_code=status.HTTP_200_OK, content={"uuid": uuid})
app.include_router(router, prefix="/api")

if __name__ == "__main__":
    if not skip_static_files:
        download_static_files("litagin02", "Style-Bert-VITS2-Editor", "out.zip")
    # StaticFiles raises at construction if the directory does not exist,
    # which happens when the download was skipped or failed on a fresh
    # install; create it so the API still starts.
    STATIC_DIR.mkdir(parents=True, exist_ok=True)
    app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
    if args.inbrowser:
        webbrowser.open(f"http://localhost:{port}")
    uvicorn.run(app, host="0.0.0.0", port=port)