# ip-adapter-faceid / pipeline.py
# Author: revi13
# Last commit: "Update pipeline.py" (e3513cb, verified)
import os
import io
import base64
import numpy as np
import torch
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
from insightface.app import FaceAnalysis
from PIL import Image
class IPFacePlusV2Pipeline:
    """Face-conditioned image generation.

    Extracts a face-identity embedding from an input photo with insightface
    and conditions a Stable Diffusion pipeline on it via IP-Adapter-FaceID.
    """

    def __init__(self, model_path):
        """Load the base diffusion pipeline, face analyzer, and IP-Adapter.

        Args:
            model_path: path to the IP-Adapter-FaceID checkpoint file.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # NOTE(review): an *XL* adapter (IPAdapterFaceIDXL) is paired with an
        # SD-1.5 base model — IP-Adapter-FaceID-XL normally requires an SDXL
        # base. Confirm this combination is intended.
        self.base_model_path = "runwayml/stable-diffusion-v1-5"
        self.ip_adapter_ckpt = model_path
        # Assumes the HF cache dir was configured up front via env var.
        self.cache_dir = os.environ.get("HF_HUB_CACHE", "/tmp/hf/hub")

        # Load the base pipeline (fp16 weights on GPU, fp32 on CPU).
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            cache_dir=self.cache_dir,
        ).to(self.device)
        # Cast the whole pipeline to fp32 afterwards. NOTE(review): this
        # overrides the fp16 load above (the original marked this line as a
        # deliberate addition, presumably a dtype-mismatch workaround) —
        # confirm it is still needed; if so, loading fp32 directly is cheaper.
        self.pipe.to(dtype=torch.float32)

        # Face-embedding extractor (insightface buffalo_l model pack).
        self.face_app = FaceAnalysis(
            name="buffalo_l",
            root=self.cache_dir,
            providers=["CUDAExecutionProvider" if self.device == "cuda" else "CPUExecutionProvider"],
        )
        self.face_app.prepare(ctx_id=0 if self.device == "cuda" else -1)

        # IP-Adapter wrapping the base pipeline.
        self.ip_adapter = IPAdapterFaceIDXL(
            self.pipe,
            self.ip_adapter_ckpt,
            device=self.device,
        )

        # Image transform (kept for completeness; not used by generate_image).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def generate_image(self, face_image_base64: str, prompt: str, scale: float = 1.0) -> str:
        """Generate an image conditioned on the face in *face_image_base64*.

        Args:
            face_image_base64: base64-encoded input photo, optionally with a
                "data:image/...;base64," prefix.
            prompt: text prompt for generation.
            scale: IP-Adapter conditioning scale.

        Returns:
            The generated image as a "data:image/png;base64,..." string.

        Raises:
            RuntimeError: if no face is detected in the input image.
        """
        print(f"🚀 Torch device: {self.device}, CUDA available: {torch.cuda.is_available()}")
        print(f"🚀 CUDA device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only'}")
        print("🧪 [pipeline] Decoding image...")
        # Strip an optional data-URI prefix before decoding.
        image_data = base64.b64decode(face_image_base64.split(",")[-1])
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        print("🎨 [pipeline] 顔認識")
        faces = self.face_app.get(np.array(image))
        if not faces:
            raise RuntimeError("No face detected in the image.")

        print("🎨 [pipeline] tensorタイプ変更")
        # Use the first detected face and move its identity embedding to the
        # compute device.
        face_embedding = torch.tensor(faces[0].normed_embedding).to(self.device)

        print("🎨 [pipeline] dtype チェック")
        # Resolve the pipeline dtype safely: prefer pipe.dtype, then fall back
        # to the UNet's dtype. The original evaluated the fallback eagerly
        # (getattr(..., None).dtype), which would raise AttributeError on None
        # if the pipe had neither attribute; resolve lazily instead.
        pipe_dtype = getattr(self.pipe, "dtype", None)
        if pipe_dtype is None:
            unet = getattr(self.pipe, "unet", None)
            pipe_dtype = unet.dtype if unet is not None else torch.float32
        # Match the embedding dtype to the pipeline (fp16 vs fp32).
        face_embedding = face_embedding.to(
            dtype=torch.float16 if pipe_dtype == torch.float16 else torch.float32
        )

        print("🎨 [pipeline] Generate image")
        image = self.ip_adapter.generate(
            prompt=prompt,
            scale=scale,
            faceid_embeds=face_embedding.unsqueeze(0),  # add batch dimension
            height=1024,  # raised from 512 in a prior revision
            width=1024,
        )[0]  # take the first generated image

        print("🎨 [pipeline] Convert PIL image to base64")
        # Serialize the PIL image to a PNG data URI.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        encoded_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("📸 [pipeline] Encoding result to base64...")
        return f"data:image/png;base64,{encoded_img}"