# Source: Xxx / vincie_service.py
# Uploaded by XCarleX - "Upload 6 files" (commit c8a837a, verified)
#!/usr/bin/env python3
import os
import sys
import json
import subprocess
from pathlib import Path
from typing import List, Optional
from huggingface_hub import hf_hub_download
class VincieService:
    """Service wrapper around a local checkout of the VINCIE inference repo.

    It:
    - makes sure the VINCIE repository is present on disk;
    - downloads ``dit.pth`` and ``vae.pth`` via ``hf_hub_download`` (``local_dir``);
    - creates the symlink ``<repo_dir>/ckpt/VINCIE-3B -> <ckpt_dir>``;
    - runs ``main.py`` with Hydra/YACS-style overrides (multi-turn and
      multi-concept);
    - provides a fallback shim for ``apex.normalization`` when Apex cannot be
      imported.
    """

    # Checkpoint files fetched from the Hugging Face model repo.
    _CHECKPOINT_FILES = ("dit.pth", "vae.pth")
    # A checkpoint at or below this size is treated as missing/truncated.
    _MIN_CHECKPOINT_BYTES = 1_000_000

    # Source of the generated ``apex/normalization.py`` shim. The classes
    # subclass the torch modules directly (instead of wrapping them in an
    # attribute) so parameters are registered as ``weight``/``bias`` on the
    # module itself, like Apex's fused norms, and checkpoint keys line up.
    _APEX_SHIM_SOURCE = (
        "import torch\n"
        "import torch.nn as nn\n"
        "\n"
        "\n"
        "class FusedRMSNorm(nn.RMSNorm):\n"
        "    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True,\n"
        "                 memory_efficient=False):\n"
        "        super().__init__(normalized_shape, eps=eps,\n"
        "                         elementwise_affine=elementwise_affine)\n"
        "\n"
        "\n"
        "class FusedLayerNorm(nn.LayerNorm):\n"
        "    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True,\n"
        "                 memory_efficient=False):\n"
        "        super().__init__(normalized_shape, eps=eps,\n"
        "                         elementwise_affine=elementwise_affine)\n"
    )

    def __init__(
        self,
        repo_dir: str = "/app/VINCIE",
        ckpt_dir: str = "/app/ckpt/VINCIE-3B",
        python_bin: str = "python",
        repo_id: str = "ByteDance-Seed/VINCIE-3B",
        output_root: str = "/app/outputs",
    ):
        """Store paths and create the output and ``<repo_dir>/ckpt`` directories.

        Args:
            repo_dir: Location of the VINCIE git checkout.
            ckpt_dir: Directory that holds ``dit.pth`` and ``vae.pth``.
            python_bin: Interpreter used to launch ``main.py``.
            repo_id: Hugging Face repo the checkpoints are downloaded from.
            output_root: Parent directory for all generation outputs.
        """
        self.repo_dir = Path(repo_dir)
        self.ckpt_dir = Path(ckpt_dir)
        self.python = python_bin
        self.repo_id = repo_id
        self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
        self.assets_dir = self.repo_dir / "assets"
        self.output_root = Path(output_root)
        self.output_root.mkdir(parents=True, exist_ok=True)
        # NOTE: this also creates ``repo_dir`` itself, which is why
        # ``ensure_repo`` cannot use "directory exists" as its clone check.
        (self.repo_dir / "ckpt").mkdir(parents=True, exist_ok=True)

    # ---------- Setup ----------
    def ensure_repo(self, git_url: str = "https://github.com/ByteDance-Seed/VINCIE") -> None:
        """Clone the VINCIE repository unless a checkout is already present.

        A checkout is recognised by ``.git`` or ``configs/generate.yaml``
        inside ``repo_dir``; the bare existence of ``repo_dir`` is not enough
        because the constructor already creates ``<repo_dir>/ckpt``.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` fails.
        """
        if (self.repo_dir / ".git").exists() or self.generate_yaml.exists():
            return

        if not self.repo_dir.exists():
            subprocess.run(["git", "clone", git_url, str(self.repo_dir)], check=True)
            return

        # ``repo_dir`` exists but is not a checkout (e.g. it only holds the
        # ``ckpt`` directory). ``git clone`` refuses a non-empty destination,
        # so clone next to it and move the entries in, keeping what is there.
        import shutil
        import tempfile

        staging = Path(
            tempfile.mkdtemp(prefix="vincie_clone_", dir=str(self.repo_dir.parent))
        )
        try:
            checkout = staging / "repo"
            subprocess.run(["git", "clone", git_url, str(checkout)], check=True)
            for entry in checkout.iterdir():
                target = self.repo_dir / entry.name
                if target.exists() or target.is_symlink():
                    continue  # never overwrite pre-existing content
                shutil.move(str(entry), str(target))
        finally:
            shutil.rmtree(staging, ignore_errors=True)

    def ensure_model(self, hf_token: Optional[str] = None) -> None:
        """Download the required checkpoint files and link them into the repo.

        Only ``dit.pth`` and ``vae.pth`` are fetched from ``repo_id``, using
        ``hf_hub_download`` with ``local_dir=ckpt_dir``. A file that already
        exists and is larger than 1 MB is not downloaded again. Afterwards the
        compatibility symlink ``<repo_dir>/ckpt/VINCIE-3B`` is (re)created.

        Args:
            hf_token: Hugging Face token; falls back to the ``HF_TOKEN`` and
                ``HUGGINGFACE_TOKEN`` environment variables.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

        for fname in self._CHECKPOINT_FILES:
            if self._has_checkpoint(self.ckpt_dir / fname):
                continue
            print(f"Baixando {fname} de {self.repo_id} ...")
            hf_hub_download(
                repo_id=self.repo_id,
                filename=fname,
                local_dir=str(self.ckpt_dir),
                token=token,
                force_download=False,
                local_files_only=False,
            )

        self._link_checkpoints()

    def _has_checkpoint(self, path: Path) -> bool:
        """Return True when *path* exists and is larger than the size floor."""
        try:
            return path.stat().st_size > self._MIN_CHECKPOINT_BYTES
        except FileNotFoundError:
            return False

    def _link_checkpoints(self) -> None:
        """Best-effort creation of ``<repo_dir>/ckpt/VINCIE-3B -> ckpt_dir``.

        Failures are reported with a printed warning instead of raising.
        """
        link = self.repo_dir / "ckpt" / "VINCIE-3B"
        # Resolve so that a relative ``ckpt_dir`` does not produce a link that
        # is interpreted relative to the link's own directory.
        target = self.ckpt_dir.resolve()
        try:
            link.parent.mkdir(parents=True, exist_ok=True)
            if link.is_symlink():
                if link.resolve() == target:
                    return  # already points at the right place
                link.unlink()
            elif link.is_file():
                link.unlink()
            # A real directory at the link path is left untouched.
            if not link.exists():
                link.symlink_to(target, target_is_directory=True)
        except (OSError, RuntimeError) as e:
            print("Aviso: falha ao criar symlink de ckpt:", e)

    def ensure_apex(self, enable_shim: bool = True, shim_root: str = "/app/shims") -> None:
        """Install a minimal ``apex.normalization`` shim if Apex is unusable.

        When ``apex.normalization`` cannot be imported, a tiny ``apex``
        package providing ``FusedRMSNorm`` and ``FusedLayerNorm`` on top of
        ``torch.nn`` is written under *shim_root*, and that directory is made
        importable for this process (``sys.path``) and for subprocesses
        (``PYTHONPATH``).

        Args:
            enable_shim: When False, a failed Apex import is ignored and
                nothing is written.
            shim_root: Directory that receives the generated ``apex`` package.
        """
        import importlib

        try:
            importlib.import_module("apex.normalization")
            return
        except Exception:
            # Apex is absent, or its import failed for any other reason.
            if not enable_shim:
                return

        # Forget any partially imported ``apex`` modules so a later import in
        # this process resolves against the shim.
        for mod_name in [m for m in sys.modules if m == "apex" or m.startswith("apex.")]:
            del sys.modules[mod_name]

        shim_dir = Path(shim_root)
        apex_pkg = shim_dir / "apex"
        apex_pkg.mkdir(parents=True, exist_ok=True)
        (apex_pkg / "__init__.py").write_text("from .normalization import *\n")
        (apex_pkg / "normalization.py").write_text(self._APEX_SHIM_SOURCE)
        importlib.invalidate_caches()

        # Make the shim visible in this process and in subprocesses.
        shim_path = str(shim_dir)
        if shim_path not in sys.path:
            sys.path.insert(0, shim_path)
        inherited = [
            entry
            for entry in os.environ.get("PYTHONPATH", "").split(os.pathsep)
            if entry and entry != shim_path
        ]
        os.environ["PYTHONPATH"] = os.pathsep.join([shim_path, *inherited])

    def ready(self) -> bool:
        """Return True when the repo config and both checkpoint files exist."""
        if not (self.repo_dir.exists() and self.generate_yaml.exists()):
            return False
        return all((self.ckpt_dir / fname).exists() for fname in self._CHECKPOINT_FILES)

    # ---------- Core runner ----------
    def _run_vincie(self, overrides: List[str], work_output: Path) -> None:
        """Run ``main.py`` inside the repo directory with the given overrides.

        The output directory is created first and appended as the final
        ``generation.output.dir=...`` override.

        Raises:
            subprocess.CalledProcessError: If the process exits non-zero.
        """
        work_output.mkdir(parents=True, exist_ok=True)
        cmd = [
            self.python,
            "main.py",
            str(self.generate_yaml),
            *overrides,
            f"generation.output.dir={work_output}",
        ]
        # Copy the environment so PYTHONPATH changes made by ensure_apex are
        # passed on to the child process.
        subprocess.run(cmd, cwd=str(self.repo_dir), check=True, env=os.environ.copy())

    # ---------- Multi-turn editing ----------
    def multi_turn_edit(
        self,
        input_image: str,
        turns: List[str],
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """Run multi-turn editing of one image and return the output directory.

        Builds the overrides
        ``generation.positive_prompt.image_path=[...]`` and
        ``generation.positive_prompt.prompts=[...]`` (JSON-encoded lists).

        Args:
            input_image: Path of the image to edit.
            turns: One prompt per editing turn.
            out_dir_name: Output sub-directory name; defaults to
                ``multi_turn_<slug of input_image>``.
        """
        out_dir = self.output_root / (out_dir_name or f"multi_turn_{self._slug(input_image)}")
        overrides = [
            f"generation.positive_prompt.image_path={json.dumps([str(input_image)])}",
            f"generation.positive_prompt.prompts={json.dumps(list(turns))}",
            f"ckpt.path={self.ckpt_dir}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    # ---------- Multi-concept composition ----------
    def multi_concept_compose(
        self,
        concept_images: List[str],
        concept_prompts: List[str],
        final_prompt: str,
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """Compose several concept images and return the output directory.

        ``image_path`` is the list of concept images and ``prompts`` is
        ``[p1, p2, ..., final_prompt]``.

        Args:
            concept_images: Paths of the concept images.
            concept_prompts: Prompts for the concepts (any sequence).
            final_prompt: Prompt appended after the concept prompts.
            out_dir_name: Output sub-directory name; defaults to
                ``multi_concept``.
        """
        out_dir = self.output_root / (out_dir_name or "multi_concept")
        images_json = json.dumps([str(img) for img in concept_images])
        # Unpacking (rather than ``+``) also accepts tuples and other sequences.
        prompts_json = json.dumps([*concept_prompts, final_prompt])
        overrides = [
            f"generation.positive_prompt.image_path={images_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            # Key spelling is kept exactly as in the original override.
            "generation.pad_img_placehoder=False",
            f"ckpt.path={self.ckpt_dir}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    # ---------- Helpers ----------
    @staticmethod
    def _slug(path_or_text: str) -> str:
        """Turn a path or free text into a short filesystem-friendly token.

        If the argument names an existing path its stem is used, otherwise the
        text itself. Every character that is not alphanumeric or one of
        ``-_.`` becomes ``_``; the result is cut to 64 characters.
        """
        text = str(path_or_text)
        try:
            is_existing_path = Path(text).exists()
        except (OSError, ValueError):
            # e.g. text longer than the filesystem's name limit
            is_existing_path = False
        base = Path(text).stem if is_existing_path else text
        cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in base)
        return cleaned[:64]