# NOTE: the three lines that used to be here ("Spaces: / Running / Running") were
# Hugging Face Spaces page-header residue from a copy-paste, not program text.
# Standard library
import base64
import io
import os

# Third-party
import numpy as np
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from insightface.app import FaceAnalysis
from torchvision import transforms

# Project-local
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
class IPFacePlusV2Pipeline:
    """Face-identity-conditioned image generation (Stable Diffusion + IP-Adapter FaceID).

    Decodes a base64-encoded face photo, extracts an InsightFace identity
    embedding, and generates a new image conditioned on that identity plus a
    text prompt. Results are returned as PNG data URLs.
    """

    def __init__(self, model_path):
        """Load the diffusion pipeline, face analyzer, and IP-Adapter.

        Args:
            model_path: Filesystem path to the IP-Adapter FaceID checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_model_path = "runwayml/stable-diffusion-v1-5"
        self.ip_adapter_ckpt = model_path
        # HF_HUB_CACHE is expected to be set by the environment; fall back to /tmp.
        self.cache_dir = os.environ.get("HF_HUB_CACHE", "/tmp/hf/hub")

        # The previous version loaded fp16 on CUDA, moved the pipe to the device
        # twice, then immediately force-cast everything back to float32 — a wasted
        # fp16 load plus a full-model cast. Loading directly in float32 matches
        # the original net behavior without the churn.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.float32,
            cache_dir=self.cache_dir,
        ).to(self.device)

        # Face embedding extractor (InsightFace buffalo_l model pack).
        self.face_app = FaceAnalysis(
            name="buffalo_l",
            root=self.cache_dir,
            providers=["CUDAExecutionProvider" if self.device == "cuda" else "CPUExecutionProvider"],
        )
        self.face_app.prepare(ctx_id=0 if self.device == "cuda" else -1)

        # NOTE(review): IPAdapterFaceIDXL is the SDXL variant of the adapter,
        # but the base model above is SD 1.5 and generation below requests
        # 1024x1024 (SDXL's native resolution). Confirm this pairing is
        # intentional; the non-XL IPAdapterFaceID at 512x512 may be what was meant.
        self.ip_adapter = IPAdapterFaceIDXL(
            self.pipe,
            self.ip_adapter_ckpt,
            device=self.device,
        )

        # Image transform (kept for completeness; generate_image does not use it).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _decode_base64_image(face_image_base64: str):
        """Decode a base64 string (optionally a data URL) into an RGB PIL image."""
        # split(",")[-1] tolerates a leading "data:image/...;base64," prefix.
        image_data = base64.b64decode(face_image_base64.split(",")[-1])
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    @staticmethod
    def _encode_image_base64(image) -> str:
        """Encode a PIL image as a PNG data URL string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        encoded_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{encoded_img}"

    def _pipe_dtype(self) -> torch.dtype:
        """Return the pipeline's compute dtype, falling back to the UNet's.

        The previous expression evaluated ``getattr(self.pipe, "unet", None).dtype``
        eagerly as the getattr *default*, so any pipe lacking a ``unet`` raised
        AttributeError even when ``pipe.dtype`` existed. Evaluate lazily instead,
        defaulting to float32 when neither attribute is available.
        """
        dtype = getattr(self.pipe, "dtype", None)
        if dtype is None:
            unet = getattr(self.pipe, "unet", None)
            dtype = unet.dtype if unet is not None else torch.float32
        return dtype

    def generate_image(self, face_image_base64: str, prompt: str, scale: float = 1.0):
        """Generate an identity-conditioned image from a face photo and prompt.

        Args:
            face_image_base64: Base64-encoded photo containing at least one
                detectable face (a ``data:image/...;base64,`` prefix is tolerated).
            prompt: Text prompt for generation.
            scale: IP-Adapter conditioning strength.

        Returns:
            A ``data:image/png;base64,...`` data URL of the generated image.

        Raises:
            RuntimeError: If no face is detected in the input image.
        """
        print(f"🚀 Torch device: {self.device}, CUDA available: {torch.cuda.is_available()}")
        print(f"🚀 CUDA device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only'}")

        print("🧪 [pipeline] Decoding image...")
        image = self._decode_base64_image(face_image_base64)

        print("🎨 [pipeline] 顔認識")
        # Extract the face identity embedding; fail loudly if detection found nothing.
        faces = self.face_app.get(np.array(image))
        if not faces:
            raise RuntimeError("No face detected in the image.")

        print("🎨 [pipeline] tensorタイプ変更")
        face_embedding = torch.tensor(faces[0].normed_embedding).to(self.device)

        print("🎨 [pipeline] dtype チェック")
        # Match the embedding dtype to the pipeline so the adapter sees one dtype.
        face_embedding = face_embedding.to(dtype=self._pipe_dtype())

        print("🎨 [pipeline] Generate image")
        result = self.ip_adapter.generate(
            prompt=prompt,
            scale=scale,
            faceid_embeds=face_embedding.unsqueeze(0),  # add batch dimension
            height=1024,
            width=1024,
        )[0]  # take the first generated image

        print("🎨 [pipeline] Convert PIL image to base64")
        encoded = self._encode_image_base64(result)
        print("📸 [pipeline] Encoding result to base64...")
        return encoded