#!/usr/bin/env python3
"""Thin service wrapper that prepares and drives a local VINCIE checkout."""
import importlib
import json
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import List, Optional, Sequence

try:
    from huggingface_hub import hf_hub_download
except ImportError:  # only needed when a checkpoint actually has to be downloaded
    hf_hub_download = None


# Source of the fallback ``apex.normalization`` module written by ``ensure_apex``.
# Parameters are registered directly on the module (``weight`` / ``bias``), not
# on a wrapped sub-module, so state-dict keys match what Apex modules expose.
_APEX_SHIM_SOURCE = (
    "import numbers\n"
    "\n"
    "import torch\n"
    "import torch.nn as nn\n"
    "\n"
    "\n"
    "class FusedRMSNorm(nn.Module):\n"
    "    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True, **_unused):\n"
    "        super().__init__()\n"
    "        if isinstance(normalized_shape, numbers.Integral):\n"
    "            normalized_shape = (normalized_shape,)\n"
    "        self.normalized_shape = tuple(normalized_shape)\n"
    "        self.eps = eps\n"
    "        self.elementwise_affine = elementwise_affine\n"
    "        if elementwise_affine:\n"
    "            self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n"
    "        else:\n"
    "            self.register_parameter('weight', None)\n"
    "\n"
    "    def forward(self, x):\n"
    "        dims = tuple(range(-len(self.normalized_shape), 0))\n"
    "        x32 = x.float()\n"
    "        out = x32 * torch.rsqrt(x32.pow(2).mean(dim=dims, keepdim=True) + self.eps)\n"
    "        out = out.to(x.dtype)\n"
    "        if self.weight is not None:\n"
    "            out = out * self.weight\n"
    "        return out\n"
    "\n"
    "\n"
    "class FusedLayerNorm(nn.LayerNorm):\n"
    "    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, **_unused):\n"
    "        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)\n"
)


class VincieService:
    """
    Service that:
    - makes sure the VINCIE repository is present
    - downloads dit.pth and vae.pth via hf_hub_download (local_dir)
    - creates the symlink <repo_dir>/ckpt/VINCIE-3B -> <ckpt_dir>
    - runs main.py with config overrides (multi-turn and multi-concept)
    - provides a fallback (shim) for apex.normalization when Apex is missing
    """

    # Checkpoint files that must exist in ``ckpt_dir``.
    _CKPT_FILES = ("dit.pth", "vae.pth")
    # Anything at or below this size is treated as a missing/truncated checkpoint.
    _MIN_CKPT_BYTES = 1_000_000

    def __init__(
        self,
        repo_dir: str = "/app/VINCIE",
        ckpt_dir: str = "/app/ckpt/VINCIE-3B",
        python_bin: str = "python",
        repo_id: str = "ByteDance-Seed/VINCIE-3B",
        output_root: str = "/app/outputs",
    ):
        """
        Args:
            repo_dir: Directory of the VINCIE checkout (created by ``ensure_repo``).
            ckpt_dir: Directory where the checkpoint files are stored.
            python_bin: Interpreter used to launch ``main.py``.
            repo_id: Hugging Face repository the checkpoints are downloaded from.
            output_root: Directory under which per-run output folders are created.
        """
        self.repo_dir = Path(repo_dir)
        self.ckpt_dir = Path(ckpt_dir)
        self.python = python_bin
        self.repo_id = repo_id
        self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
        self.assets_dir = self.repo_dir / "assets"
        self.output_root = Path(output_root)
        self.output_root.mkdir(parents=True, exist_ok=True)
        # NOTE: repo_dir is deliberately NOT created here. Creating
        # <repo_dir>/ckpt up front made repo_dir always exist, so ensure_repo()
        # never cloned. The ckpt folder is created when the symlink is made.

    # ---------- Setup ----------
    def ensure_repo(self, git_url: str = "https://github.com/ByteDance-Seed/VINCIE") -> None:
        """Clone the VINCIE repository unless a checkout is already in place.

        A checkout is recognised by ``main.py`` in ``repo_dir`` (the script
        this service executes), not merely by the directory existing.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` fails.
        """
        if (self.repo_dir / "main.py").is_file():
            return
        self.repo_dir.parent.mkdir(parents=True, exist_ok=True)
        if self.repo_dir.is_dir() and any(self.repo_dir.iterdir()):
            # git refuses to clone into a non-empty directory (e.g. one that
            # already holds ckpt/), so clone next to it and merge the result in.
            with tempfile.TemporaryDirectory(dir=str(self.repo_dir.parent)) as staging_root:
                staging = Path(staging_root) / "clone"
                subprocess.run(["git", "clone", git_url, str(staging)], check=True)
                shutil.copytree(
                    str(staging), str(self.repo_dir), symlinks=True, dirs_exist_ok=True
                )
        else:
            subprocess.run(["git", "clone", git_url, str(self.repo_dir)], check=True)

    def ensure_model(self, hf_token: Optional[str] = None) -> None:
        """
        Download only the required files from the checkpoint repository
        (dit.pth and vae.pth) into ``ckpt_dir``, then create the compatibility
        symlink inside the repo.

        Args:
            hf_token: Hub token; falls back to the ``HF_TOKEN`` and then
                ``HUGGINGFACE_TOKEN`` environment variables.

        Raises:
            ImportError: If a download is needed but huggingface_hub is not installed.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

        for fname in self._CKPT_FILES:
            if not self._needs_download(self.ckpt_dir / fname):
                continue
            if hf_hub_download is None:
                raise ImportError(
                    f"huggingface_hub is required to download {fname} from {self.repo_id}"
                )
            print(f"Baixando {fname} de {self.repo_id} ...")
            hf_hub_download(
                repo_id=self.repo_id,
                filename=fname,
                local_dir=str(self.ckpt_dir),
                token=token,
                force_download=False,
                local_files_only=False,
            )

        self._link_checkpoints()

    @classmethod
    def _needs_download(cls, path: Path) -> bool:
        """Return True when ``path`` is missing or too small to be a real checkpoint."""
        try:
            return path.stat().st_size <= cls._MIN_CKPT_BYTES
        except FileNotFoundError:
            return True

    def _link_checkpoints(self) -> None:
        """Best-effort: make <repo_dir>/ckpt/VINCIE-3B a symlink to ``ckpt_dir``.

        A real directory at the link location is left untouched; failures are
        reported as a warning instead of being raised.
        """
        link = self.repo_dir / "ckpt" / "VINCIE-3B"
        try:
            # Absolute target: a relative one would be resolved against the
            # link's own folder rather than the process working directory.
            target = self.ckpt_dir.resolve()
            link.parent.mkdir(parents=True, exist_ok=True)
            if link.is_symlink():
                if link.resolve() == target:
                    return  # already correct, nothing to redo
                link.unlink()
            elif link.is_dir():
                return  # real directory (possibly ckpt_dir itself): keep it
            elif link.exists():
                link.unlink()
            link.symlink_to(target, target_is_directory=True)
        except (OSError, RuntimeError) as e:  # RuntimeError: symlink loop on resolve()
            print("Aviso: falha ao criar symlink de ckpt:", e)

    def ensure_apex(self, enable_shim: bool = True, shim_dir: str = "/app/shims") -> None:
        """
        If Apex cannot be imported, write a minimal ``apex.normalization``
        replacement (FusedRMSNorm / FusedLayerNorm built on torch) and expose
        it to this process and to subprocesses via ``PYTHONPATH``.

        Args:
            enable_shim: When False, nothing is written even if Apex is missing.
            shim_dir: Directory that receives the generated ``apex`` package.
        """
        try:
            importlib.import_module("apex.normalization")
            return
        except Exception:
            # Deliberately broad: any failure to import Apex, whatever the
            # exception type, means the fallback is wanted.
            if not enable_shim:
                return

        shim_root = Path(shim_dir)
        apex_pkg = shim_root / "apex"
        apex_pkg.mkdir(parents=True, exist_ok=True)
        (apex_pkg / "__init__.py").write_text("from .normalization import *\n")
        (apex_pkg / "normalization.py").write_text(_APEX_SHIM_SOURCE)

        # Make the shim visible in this process and in subprocesses, without
        # duplicating the entry or adding an empty (= current directory) one.
        shim_str = str(shim_root)
        if shim_str not in sys.path:
            sys.path.insert(0, shim_str)
        importlib.invalidate_caches()
        current = os.environ.get("PYTHONPATH", "")
        kept = [p for p in current.split(os.pathsep) if p != shim_str] if current else []
        os.environ["PYTHONPATH"] = os.pathsep.join([shim_str, *kept])

    def ready(self) -> bool:
        """Return True when the repo config and all required checkpoints exist."""
        have_repo = self.repo_dir.exists() and self.generate_yaml.exists()
        have_ckpts = all((self.ckpt_dir / name).exists() for name in self._CKPT_FILES)
        return bool(have_repo and have_ckpts)

    # ---------- Core runner ----------
    def _run_vincie(self, overrides: List[str], work_output: Path) -> None:
        """Run main.py with the given overrides from inside the repo directory.

        Raises:
            subprocess.CalledProcessError: If the process exits with a non-zero status.
        """
        work_output.mkdir(parents=True, exist_ok=True)
        cmd = [
            self.python,
            "main.py",
            str(self.generate_yaml),
            *overrides,
            f"generation.output.dir={work_output}",
        ]
        # Copy the environment at call time so PYTHONPATH changes made by
        # ensure_apex() reach the child process.
        env = os.environ.copy()
        subprocess.run(cmd, cwd=str(self.repo_dir), check=True, env=env)

    # ---------- Multi-turn editing ----------
    def multi_turn_edit(
        self,
        input_image: str,
        turns: List[str],
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """
        Edit one image through a sequence of prompts by setting:
          generation.positive_prompt.image_path=[...]
          generation.positive_prompt.prompts=[...]

        Args:
            input_image: Path of the image to edit.
            turns: One prompt per editing turn.
            out_dir_name: Output folder name under the output root; defaults
                to ``multi_turn_<slug of input_image>``.

        Returns:
            The output directory passed to the run.
        """
        out_dir = self.output_root / (out_dir_name or f"multi_turn_{self._slug(input_image)}")
        image_json = json.dumps([str(input_image)])
        prompts_json = json.dumps(list(turns))
        overrides = [
            f"generation.positive_prompt.image_path={image_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            f"ckpt.path={self.ckpt_dir}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    # ---------- Multi-concept composition ----------
    def multi_concept_compose(
        self,
        concept_images: Sequence[str],
        concept_prompts: Sequence[str],
        final_prompt: str,
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """
        Pass the concept images as image_path and prompts = [p1, p2, ..., final].

        Args:
            concept_images: Paths of the concept images.
            concept_prompts: Prompts for the concepts (any sequence).
            final_prompt: Prompt appended after the concept prompts.
            out_dir_name: Output folder name under the output root; defaults
                to ``multi_concept``.

        Returns:
            The output directory passed to the run.
        """
        out_dir = self.output_root / (out_dir_name or "multi_concept")
        imgs_json = json.dumps([str(p) for p in concept_images])
        # Unpacking accepts tuples and other sequences, not just lists.
        prompts_json = json.dumps([*concept_prompts, final_prompt])
        overrides = [
            f"generation.positive_prompt.image_path={imgs_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            "generation.pad_img_placehoder=False",
            f"ckpt.path={self.ckpt_dir}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    # ---------- Helpers ----------
    @staticmethod
    def _slug(path_or_text: str) -> str:
        """Return a filesystem-safe label of at most 64 characters.

        Uses the file stem when ``path_or_text`` names an existing path,
        otherwise the text itself; characters other than alphanumerics and
        ``-_.`` become ``_``.
        """
        text = str(path_or_text)
        candidate = Path(text)
        try:
            is_existing_path = candidate.exists()
        except (OSError, ValueError):
            # Free text can be an invalid path (e.g. a name that is too long).
            is_existing_path = False
        base = candidate.stem if is_existing_path else text
        keep = "".join(c if c.isalnum() or c in "-_." else "_" for c in base)
        return keep[:64]