|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import subprocess |
|
|
from pathlib import Path |
|
|
from typing import List, Optional |
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
class VincieService:
    """Thin orchestration layer around a local VINCIE inference checkout.

    The service:

    - makes sure the VINCIE repository is present (``ensure_repo``);
    - downloads ``dit.pth`` and ``vae.pth`` with ``hf_hub_download`` using
      ``local_dir`` (``ensure_model``);
    - creates the compatibility symlink ``<repo_dir>/ckpt/VINCIE-3B`` pointing
      at ``ckpt_dir``;
    - runs the repository's ``main.py`` with command-line config overrides
      (multi-turn editing and multi-concept composition);
    - provides a fallback shim for ``apex.normalization`` when Apex cannot be
      imported (``ensure_apex``).
    """

    # Checkpoint files that must be present in ``ckpt_dir``.
    _CKPT_FILES = ("dit.pth", "vae.pth")

    # A checkpoint file is only considered usable when it is larger than this
    # many bytes; smaller files are treated as missing/truncated downloads.
    _MIN_CKPT_BYTES = 1_000_000

    # Source of the generated ``apex/normalization.py`` shim.
    #
    # The shim classes subclass the torch modules (instead of wrapping them in
    # a ``self.mod`` attribute) so that their parameters are named ``weight`` /
    # ``bias`` like the Apex modules they stand in for; a wrapper would rename
    # them to ``mod.weight`` / ``mod.bias`` and break ``load_state_dict`` for
    # checkpoints saved from Apex-based models.  ``nn.RMSNorm`` only exists in
    # recent PyTorch releases, hence the manual fallback.
    _APEX_SHIM_SOURCE = (
        "import torch\n"
        "import torch.nn as nn\n"
        "\n"
        "\n"
        "if hasattr(nn, 'RMSNorm'):\n"
        "\n"
        "    class FusedRMSNorm(nn.RMSNorm):\n"
        "        def __init__(self, normalized_shape, eps=1e-6,\n"
        "                     elementwise_affine=True, memory_efficient=False):\n"
        "            super().__init__(normalized_shape, eps=eps,\n"
        "                             elementwise_affine=elementwise_affine)\n"
        "\n"
        "else:\n"
        "\n"
        "    class FusedRMSNorm(nn.Module):\n"
        "        def __init__(self, normalized_shape, eps=1e-6,\n"
        "                     elementwise_affine=True, memory_efficient=False):\n"
        "            super().__init__()\n"
        "            if isinstance(normalized_shape, int):\n"
        "                normalized_shape = (normalized_shape,)\n"
        "            self.normalized_shape = tuple(normalized_shape)\n"
        "            self.eps = eps\n"
        "            self.elementwise_affine = elementwise_affine\n"
        "            if elementwise_affine:\n"
        "                self.weight = nn.Parameter(torch.ones(self.normalized_shape))\n"
        "            else:\n"
        "                self.register_parameter('weight', None)\n"
        "\n"
        "        def forward(self, x):\n"
        "            dims = tuple(range(-len(self.normalized_shape), 0))\n"
        "            variance = x.float().pow(2).mean(dims, keepdim=True)\n"
        "            out = (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)\n"
        "            if self.weight is not None:\n"
        "                out = out * self.weight\n"
        "            return out\n"
        "\n"
        "\n"
        "class FusedLayerNorm(nn.LayerNorm):\n"
        "    def __init__(self, normalized_shape, eps=1e-5,\n"
        "                 elementwise_affine=True, memory_efficient=False):\n"
        "        super().__init__(normalized_shape, eps=eps,\n"
        "                         elementwise_affine=elementwise_affine)\n"
    )

    def __init__(
        self,
        repo_dir: str = "/app/VINCIE",
        ckpt_dir: str = "/app/ckpt/VINCIE-3B",
        python_bin: str = "python",
        repo_id: str = "ByteDance-Seed/VINCIE-3B",
        output_root: str = "/app/outputs",
    ):
        """Store the configuration and create the working directories.

        Args:
            repo_dir: Directory of the VINCIE source checkout.
            ckpt_dir: Directory where the checkpoint files are stored.
            python_bin: Interpreter used to launch ``main.py``.
            repo_id: Hugging Face Hub repository holding the checkpoints.
            output_root: Directory under which every run gets its own output
                folder.

        Side effects:
            Creates ``output_root`` and ``<repo_dir>/ckpt`` if they are
            missing.
        """
        self.repo_dir = Path(repo_dir)
        self.ckpt_dir = Path(ckpt_dir)
        self.python = python_bin
        self.repo_id = repo_id

        self.generate_yaml = self.repo_dir / "configs" / "generate.yaml"
        self.assets_dir = self.repo_dir / "assets"
        self.output_root = Path(output_root)
        self.output_root.mkdir(parents=True, exist_ok=True)
        (self.repo_dir / "ckpt").mkdir(parents=True, exist_ok=True)

    def _repo_present(self) -> bool:
        """Return True when ``repo_dir`` looks like an actual checkout."""
        return (self.repo_dir / ".git").exists() or (self.repo_dir / "main.py").exists()

    def ensure_repo(self, git_url: str = "https://github.com/ByteDance-Seed/VINCIE") -> None:
        """Clone the VINCIE repository unless a checkout is already present.

        The constructor creates ``<repo_dir>/ckpt``, so the mere existence of
        ``repo_dir`` says nothing about whether the sources are there; a
        checkout is recognised by its ``.git`` directory or its ``main.py``.

        Args:
            git_url: URL passed to ``git clone``.

        Raises:
            subprocess.CalledProcessError: If ``git clone`` fails.
        """
        if self._repo_present():
            return

        if not self.repo_dir.exists() or not any(self.repo_dir.iterdir()):
            # git accepts a missing or empty destination directory.
            subprocess.run(["git", "clone", git_url, str(self.repo_dir)], check=True)
            return

        # The destination already holds something (typically the ``ckpt``
        # folder created by the constructor) and git refuses to clone into a
        # non-empty directory: clone next to it and move the entries in.
        import shutil
        import tempfile

        with tempfile.TemporaryDirectory(dir=str(self.repo_dir.parent)) as tmp:
            staging = Path(tmp) / "checkout"
            subprocess.run(["git", "clone", git_url, str(staging)], check=True)
            for entry in staging.iterdir():
                target = self.repo_dir / entry.name
                if target.exists() or target.is_symlink():
                    # Keep whatever is already in place (e.g. ``ckpt``).
                    continue
                shutil.move(str(entry), str(target))

    def _checkpoint_ok(self, path: Path) -> bool:
        """Return True when ``path`` exists and is larger than the size floor."""
        try:
            return path.stat().st_size > self._MIN_CKPT_BYTES
        except OSError:
            return False

    def ensure_model(self, hf_token: Optional[str] = None) -> None:
        """Download the required checkpoint files and link them into the repo.

        Only ``dit.pth`` and ``vae.pth`` are fetched from ``repo_id``, via
        ``hf_hub_download`` with ``local_dir=ckpt_dir``.  A file that already
        exists and passes the size check is not downloaded again.  Afterwards
        the compatibility symlink ``<repo_dir>/ckpt/VINCIE-3B`` is created.

        Args:
            hf_token: Hub access token.  When omitted, the ``HF_TOKEN`` and
                ``HUGGINGFACE_TOKEN`` environment variables are tried.
        """
        self.ckpt_dir.mkdir(parents=True, exist_ok=True)
        token = hf_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

        for fname in self._CKPT_FILES:
            if self._checkpoint_ok(self.ckpt_dir / fname):
                continue
            print(f"Baixando {fname} de {self.repo_id} ...")
            hf_hub_download(
                repo_id=self.repo_id,
                filename=fname,
                local_dir=str(self.ckpt_dir),
                token=token,
                force_download=False,
                local_files_only=False,
            )

        self._link_checkpoints()

    def _link_checkpoints(self) -> None:
        """Create ``<repo_dir>/ckpt/VINCIE-3B`` -> ``ckpt_dir`` (best effort).

        Failures are reported on stdout instead of raised, because every run
        also receives ``ckpt.path`` explicitly.
        """
        link = self.repo_dir / "ckpt" / "VINCIE-3B"
        # An absolute target keeps the link valid even when ``ckpt_dir`` was
        # given relative to the current working directory.
        target = Path(os.path.abspath(str(self.ckpt_dir)))
        try:
            link.parent.mkdir(parents=True, exist_ok=True)
            if link.exists() and link.resolve() == target.resolve():
                # Already pointing at (or identical to) the checkpoint dir.
                return
            if link.is_symlink() or link.is_file():
                # Stale/broken symlink or stray file: replace it.
                link.unlink()
            elif link.exists():
                # A real directory is never deleted automatically.
                print(f"Aviso: {link} já existe como diretório; symlink de ckpt não criado")
                return
            link.symlink_to(target, target_is_directory=True)
        except (OSError, RuntimeError) as e:
            print("Aviso: falha ao criar symlink de ckpt:", e)

    def ensure_apex(self, enable_shim: bool = True, shim_root: str = "/app/shims") -> None:
        """Install a pure-PyTorch stand-in for ``apex.normalization`` if needed.

        When ``apex.normalization`` cannot be imported, a minimal ``apex``
        package exposing ``FusedRMSNorm`` and ``FusedLayerNorm`` (built on
        ``torch.nn``) is written below ``shim_root``.  That directory is put
        at the front of ``sys.path`` and of ``PYTHONPATH``, so subprocesses
        started later with a copy of the environment pick it up as well.

        Args:
            enable_shim: When False, nothing is written even if Apex is
                missing.
            shim_root: Directory that receives the generated ``apex`` package.
        """
        import importlib

        try:
            importlib.import_module("apex.normalization")
        except Exception:
            # Deliberately broad: a broken Apex build can fail with errors
            # other than ImportError, and all of them mean "not usable".
            apex_usable = False
        else:
            apex_usable = True
        if apex_usable or not enable_shim:
            return

        shim_dir = Path(shim_root)
        apex_pkg = shim_dir / "apex"
        apex_pkg.mkdir(parents=True, exist_ok=True)
        (apex_pkg / "__init__.py").write_text("from .normalization import *\n")
        (apex_pkg / "normalization.py").write_text(self._APEX_SHIM_SOURCE)

        shim_path = str(shim_dir)
        if shim_path not in sys.path:
            sys.path.insert(0, shim_path)
        # Drop empty and duplicate entries: an empty PYTHONPATH component
        # (e.g. a trailing separator) silently adds the current directory.
        inherited = [
            part
            for part in os.environ.get("PYTHONPATH", "").split(os.pathsep)
            if part and part != shim_path
        ]
        os.environ["PYTHONPATH"] = os.pathsep.join([shim_path, *inherited])

    def ready(self) -> bool:
        """Report whether the repo, its config and the checkpoints are usable.

        Returns:
            True when ``repo_dir`` and ``configs/generate.yaml`` exist and
            every required checkpoint file passes the same size check that
            ``ensure_model`` uses to decide whether to download.
        """
        have_repo = self.repo_dir.exists() and self.generate_yaml.exists()
        return bool(
            have_repo
            and all(self._checkpoint_ok(self.ckpt_dir / name) for name in self._CKPT_FILES)
        )

    def _run_vincie(self, overrides: List[str], work_output: Path) -> None:
        """Run ``main.py`` inside the repository with the given overrides.

        Args:
            overrides: ``key=value`` arguments appended after the config path.
            work_output: Output directory; created if missing and passed as
                ``generation.output.dir``.

        Raises:
            subprocess.CalledProcessError: If the process exits non-zero.
        """
        work_output.mkdir(parents=True, exist_ok=True)

        cmd = [
            self.python,
            "main.py",
            str(self.generate_yaml),
            *overrides,
            f"generation.output.dir={work_output}",
        ]
        # Copy the environment at call time so that a PYTHONPATH set by
        # ensure_apex() reaches the child process.
        env = os.environ.copy()
        subprocess.run(cmd, cwd=str(self.repo_dir), check=True, env=env)

    def multi_turn_edit(
        self,
        input_image: str,
        turns: List[str],
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """Apply a sequence of edit prompts to a single input image.

        Passes the image as ``generation.positive_prompt.image_path=[...]``
        and the prompts as ``generation.positive_prompt.prompts=[...]``, both
        JSON-encoded.

        Args:
            input_image: Path of the image to edit.
            turns: Edit prompts, one per turn.
            out_dir_name: Folder name below ``output_root``; defaults to
                ``multi_turn_<slug of input_image>``.

        Returns:
            The output directory of the run.
        """
        out_dir = self.output_root / (out_dir_name or f"multi_turn_{self._slug(input_image)}")
        image_json = json.dumps([str(input_image)])
        prompts_json = json.dumps(list(turns))
        overrides = [
            f"generation.positive_prompt.image_path={image_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            f"ckpt.path={self.ckpt_dir}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    def multi_concept_compose(
        self,
        concept_images: List[str],
        concept_prompts: List[str],
        final_prompt: str,
        out_dir_name: Optional[str] = None,
    ) -> Path:
        """Compose several concept images into one result.

        ``image_path`` receives the concept images and ``prompts`` receives
        ``[p1, p2, ..., final_prompt]``.

        Args:
            concept_images: Paths of the concept images.
            concept_prompts: Prompts describing the concepts (any sequence).
            final_prompt: Prompt describing the composed result.
            out_dir_name: Folder name below ``output_root``; defaults to
                ``multi_concept``.

        Returns:
            The output directory of the run.
        """
        out_dir = self.output_root / (out_dir_name or "multi_concept")
        imgs_json = json.dumps([str(p) for p in concept_images])
        # Unpacking (rather than ``+``) also accepts tuples and other sequences.
        prompts_json = json.dumps([*concept_prompts, final_prompt])
        overrides = [
            f"generation.positive_prompt.image_path={imgs_json}",
            f"generation.positive_prompt.prompts={prompts_json}",
            # Key spelling is kept exactly as the original code passes it.
            "generation.pad_img_placehoder=False",
            f"ckpt.path={self.ckpt_dir}",
        ]
        self._run_vincie(overrides, out_dir)
        return out_dir

    @staticmethod
    def _slug(path_or_text: str) -> str:
        """Turn a path or free text into a short, filesystem-friendly token.

        If the argument names an existing path, its stem is used; otherwise
        the text itself.  Every character that is not alphanumeric or one of
        ``-_.`` becomes ``_`` and the result is cut to 64 characters.
        """
        text = str(path_or_text)
        try:
            names_existing_path = Path(text).exists()
        except (OSError, ValueError):
            # Over-long names or embedded NULs cannot be existing paths.
            names_existing_path = False
        base = Path(text).stem if names_existing_path else text
        cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in base)
        return cleaned[:64]
|
|
|