File size: 3,885 Bytes
8bc2033
 
 
 
52b171e
 
 
 
 
 
 
 
9b6d335
52b171e
 
9b6d335
52b171e
8bc2033
 
 
52b171e
 
 
8bc2033
 
52b171e
 
1d5bcb3
 
 
52b171e
8bc2033
 
 
 
 
52b171e
 
 
 
 
 
 
 
 
8bc2033
52b171e
 
 
 
9509867
afb5de2
3f1043d
31967f1
 
3f1043d
afb5de2
 
 
 
3f1043d
afb5de2
 
 
 
 
3f1043d
afb5de2
 
74415da
3f1043d
74415da
 
 
 
 
 
 
afb5de2
3f1043d
afb5de2
 
 
 
1f8023d
e3513cb
 
74415da
afb5de2
3f1043d
afb5de2
 
 
 
3f1043d
74415da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import io
import base64
import numpy as np
import torch
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDXL
from insightface.app import FaceAnalysis
from PIL import Image

class IPFacePlusV2Pipeline:
    """Face-conditioned image generation via IP-Adapter FaceID.

    Decodes a base64 face photo, extracts an identity embedding with
    InsightFace, and conditions a Stable Diffusion pipeline on it through an
    IP-Adapter FaceID checkpoint to generate a prompt-guided image.
    """

    def __init__(self, model_path):
        """Load the diffusion pipeline, the face detector, and the IP-Adapter.

        Args:
            model_path: Filesystem path to the IP-Adapter FaceID checkpoint.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.base_model_path = "runwayml/stable-diffusion-v1-5"
        self.ip_adapter_ckpt = model_path

        # Assumes HF_HUB_CACHE was set up front (environment variable);
        # falls back to a /tmp cache otherwise.
        self.cache_dir = os.environ.get("HF_HUB_CACHE", "/tmp/hf/hub")

        # Load the base pipeline; fp16 weights only when CUDA is available.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.base_model_path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            cache_dir=self.cache_dir,
        ).to(self.device)

        # NOTE(review): casting the whole pipeline back to float32 discards the
        # fp16 weights loaded above and roughly doubles GPU memory use. The
        # original marked this as a deliberately added workaround, so it is
        # kept — confirm it is still required.
        self.pipe.to(dtype=torch.float32)

        # Face embedding extractor (InsightFace "buffalo_l" model pack).
        self.face_app = FaceAnalysis(
            name="buffalo_l",
            root=self.cache_dir,
            providers=["CUDAExecutionProvider" if self.device == "cuda" else "CPUExecutionProvider"],
        )
        self.face_app.prepare(ctx_id=0 if self.device == "cuda" else -1)

        # NOTE(review): IPAdapterFaceIDXL is the SDXL variant of the adapter,
        # but the base model above is SD 1.5 and generate_image() renders at
        # 1024x1024 (SDXL-native, not SD 1.5-native) — verify this pairing and
        # resolution are intentional.
        self.ip_adapter = IPAdapterFaceIDXL(
            self.pipe,
            self.ip_adapter_ckpt,
            device=self.device,
        )

        # Image transform (unused by generate_image; kept for completeness).
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def generate_image(self, face_image_base64: str, prompt: str, scale: float = 1.0):
        """Generate an image conditioned on a face photo and a text prompt.

        Args:
            face_image_base64: Face photo as base64, with or without a
                "data:image/...;base64," data-URI prefix.
            prompt: Text prompt guiding the generation.
            scale: IP-Adapter conditioning strength.

        Returns:
            The generated image as a "data:image/png;base64,..." string.

        Raises:
            RuntimeError: If no face is detected in the input image.
        """
        print(f"🚀 Torch device: {self.device}, CUDA available: {torch.cuda.is_available()}")
        print(f"🚀 CUDA device name: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU only'}")
        print("🧪 [pipeline] Decoding image...")
        # split(",")[-1] strips an optional data-URI prefix before decoding.
        image_data = base64.b64decode(face_image_base64.split(",")[-1])
        image = Image.open(io.BytesIO(image_data)).convert("RGB")

        print("🎨 [pipeline] 顔認識")
        # Extract the identity embedding of the first detected face.
        faces = self.face_app.get(np.array(image))
        if not faces:
            raise RuntimeError("No face detected in the image.")

        print("🎨 [pipeline] tensorタイプ変更")
        face_embedding = torch.tensor(faces[0].normed_embedding).to(self.device)

        print("🎨 [pipeline] dtype チェック")
        # Safe dtype lookup: only consult the UNet when the pipeline itself
        # does not expose a dtype. (The previous getattr-inside-getattr form
        # evaluated the fallback eagerly, raising AttributeError whenever
        # `pipe.unet` was missing even though `pipe.dtype` existed.)
        pipe_dtype = getattr(self.pipe, "dtype", None)
        if pipe_dtype is None:
            pipe_dtype = self.pipe.unet.dtype

        # Match the embedding dtype to the pipeline's compute dtype.
        target_dtype = torch.float16 if pipe_dtype == torch.float16 else torch.float32
        face_embedding = face_embedding.to(dtype=target_dtype)

        print("🎨 [pipeline] Generate image")
        image = self.ip_adapter.generate(
            prompt=prompt,
            scale=scale,
            faceid_embeds=face_embedding.unsqueeze(0),  # add batch dimension
            height=1024,
            width=1024,
        )[0]  # keep the first generated image

        print("🎨 [pipeline] Convert PIL image to base64")
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        encoded_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("📸 [pipeline] Encoding result to base64...")
        return f"data:image/png;base64,{encoded_img}"