primerz committed on
Commit
aa69808
·
verified ·
1 Parent(s): 22858c3

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +71 -109
generator.py CHANGED
@@ -1,5 +1,9 @@
1
  """
2
  Generation logic for Pixagram AI Pixel Art Generator
 
 
 
 
3
  """
4
  import gc
5
  import torch
@@ -19,10 +23,10 @@ from utils import (
19
  draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
20
  )
21
  from models import (
22
- load_face_analysis, load_depth_detector, load_controlnets, load_image_encoder,
23
  load_sdxl_pipeline, load_loras, setup_ip_adapter,
24
- # --- START FIX: Import setup_cappella ---
25
- setup_cappella,
26
  # --- END FIX ---
27
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
28
  load_openpose_detector, load_mediapipe_face_detector
@@ -71,11 +75,9 @@ class RetroArtConverter:
71
  self.instantid_enabled = instantid_success
72
  self.models_loaded['instantid'] = instantid_success
73
 
74
- # Load image encoder
75
- if self.instantid_enabled:
76
- self.image_encoder = load_image_encoder()
77
- else:
78
- self.image_encoder = None
79
 
80
  # --- FIX START: Robust ControlNet Loading ---
81
  # Determine which controlnets to use
@@ -122,16 +124,18 @@ class RetroArtConverter:
122
  self.models_loaded['lora'] = lora_success
123
 
124
  # Setup IP-Adapter
125
- if self.instantid_active and self.image_encoder is not None: # <-- Check instantid_active
126
- self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
 
127
  self.models_loaded['ip_adapter'] = ip_adapter_success
 
128
  else:
129
- print("[INFO] Face preservation: IP-Adapter disabled (InstantID model failed or encoder failed)")
130
  self.models_loaded['ip_adapter'] = False
131
  self.image_proj_model = None
132
 
133
- # --- START FIX: Setup Cappella ---
134
- self.cappella, self.use_cappella = setup_cappella(self.pipe)
135
  # --- END FIX ---
136
 
137
  # Setup LCM scheduler
@@ -182,24 +186,21 @@ class RetroArtConverter:
182
 
183
  print("=== UPGRADE VERIFICATION ===")
184
  try:
185
- # --- FIX: Corrected import paths and class names ---
186
- from resampler import Resampler
187
- from attention_processor import IPAttnProcessor2_0
 
188
 
189
- resampler_check = isinstance(self.image_proj_model, Resampler) if hasattr(self, 'image_proj_model') and self.image_proj_model is not None else False
190
- custom_attn_check = any(isinstance(p, IPAttnProcessor2_0) for p in self.pipe.unet.attn_processors.values()) if hasattr(self, 'pipe') else False
191
- # --- END FIX ---
 
 
192
 
193
- print(f"Enhanced Perceiver Resampler: {'[OK] ACTIVE' if resampler_check else '[INFO] Not active'}")
194
- print(f"Enhanced IP-Adapter Attention: {'[OK] ACTIVE' if custom_attn_check else '[INFO] Not active'}")
 
195
 
196
- if resampler_check and custom_attn_check:
197
- print("[SUCCESS] Face preservation upgrade fully active")
198
- print(" Expected improvement: +10-15% face similarity")
199
- elif resampler_check or custom_attn_check:
200
- print("[PARTIAL] Some upgrades active")
201
- else:
202
- print("[INFO] Using standard components")
203
  except Exception as e:
204
  print(f"[INFO] Verification skipped: {e}")
205
  print("============================\n")
@@ -641,33 +642,16 @@ class RetroArtConverter:
641
  guidance_scale = adaptive_params['guidance_scale']
642
  lora_scale = adaptive_params['lora_scale']
643
 
644
- # Extract face embeddings
645
- face_embeddings_base = face.normed_embedding
 
 
646
 
647
  # Extract face crop
648
  bbox = face.bbox.astype(int)
649
  x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
650
  face_bbox_original = [x1, y1, x2, y2]
651
 
652
- # Add padding
653
- face_width = x2 - x1
654
- face_height = y2 - y1
655
- padding_x = int(face_width * 0.3)
656
- padding_y = int(face_height * 0.3)
657
- x1 = max(0, x1 - padding_x)
658
- y1 = max(0, y1 - padding_y)
659
- x2 = min(resized_image.width, x2 + padding_x)
660
- y2 = min(resized_image.height, y2 + padding_y)
661
-
662
- # Crop face region
663
- face_crop = resized_image.crop((x1, y1, x2, y2))
664
-
665
- # MULTI-SCALE PROCESSING
666
- face_embeddings = self.extract_multi_scale_face(face_crop, face)
667
-
668
- # Enhance face crop
669
- face_crop_enhanced = enhance_face_crop(face_crop)
670
-
671
  # Draw keypoints
672
  face_kps = face.kps
673
  face_kps_image = draw_kps(resized_image, face_kps)
@@ -677,7 +661,7 @@ class RetroArtConverter:
677
  facial_attrs = get_facial_attributes(face)
678
 
679
  # Update prompt with detected attributes
680
- prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD[lora_choice])
681
 
682
  # Legacy output for compatibility
683
  age = facial_attrs['age']
@@ -686,7 +670,7 @@ class RetroArtConverter:
686
 
687
  gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
688
  print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
689
- print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
690
  else:
691
  print("✗ InsightFace found no faces")
692
 
@@ -745,15 +729,20 @@ class RetroArtConverter:
745
  if adapter_name != "none" and self.loaded_loras.get(adapter_name, False):
746
  try:
747
  self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
748
- print(f"LORA: Set adapter '{adapter_name}' with scale: {lora_scale}")
 
 
749
  except Exception as e:
750
- print(f"Could not set LORA adapter '{adapter_name}': {e}")
 
751
  self.pipe.set_adapters([]) # Disable LORAs if setting failed
752
  else:
753
  if adapter_name == "none":
754
  print("LORAs disabled by user choice.")
755
  else:
756
  print(f"LORA '{adapter_name}' not loaded or available, disabling LORAs.")
 
 
757
  self.pipe.set_adapters([]) # Disable all LORAs
758
 
759
 
@@ -777,28 +766,33 @@ class RetroArtConverter:
777
 
778
  pipe_kwargs["generator"] = generator
779
 
780
- # --- START FIX: Use our new Cappella module ---
781
- if self.use_cappella and self.cappella is not None:
782
  try:
783
- print("Encoding prompts with Cappella...")
784
 
785
- # Call Cappella once. It handles truncation and padding.
786
- conditioning = self.cappella(prompt, negative_prompt)
 
 
787
 
788
- # Unpack the results
789
- pipe_kwargs["prompt_embeds"] = conditioning.embeds
790
- pipe_kwargs["pooled_prompt_embeds"] = conditioning.pooled_embeds
791
- pipe_kwargs["negative_prompt_embeds"] = conditioning.negative_embeds
792
- pipe_kwargs["negative_pooled_prompt_embeds"] = conditioning.negative_pooled_embeds
793
 
794
- print(f"[OK] Cappella encoded - Prompt: {pipe_kwargs['prompt_embeds'].shape}, Negative: {pipe_kwargs['negative_prompt_embeds'].shape}")
 
 
 
 
 
795
  except Exception as e:
796
- print(f"Cappella encoding failed, using standard prompts: {e}")
797
  traceback.print_exc()
798
  pipe_kwargs["prompt"] = prompt
799
  pipe_kwargs["negative_prompt"] = negative_prompt
800
  else:
801
- print("[WARNING] Cappella not found, using standard prompt encoding.")
802
  pipe_kwargs["prompt"] = prompt
803
  pipe_kwargs["negative_prompt"] = negative_prompt
804
  # --- END FIX ---
@@ -831,53 +825,21 @@ class RetroArtConverter:
831
  conditioning_scales.append(identity_control_scale)
832
  scale_debug_str.append(f"Identity: {identity_control_scale:.2f}")
833
 
834
- # Add face embeddings for IP-Adapter if available
835
- if face_embeddings is not None and self.models_loaded.get('ip_adapter', False) and face_crop_enhanced is not None:
836
- print(f"Processing InstantID face embeddings with Resampler...")
837
 
838
- with torch.no_grad():
839
- face_emb_tensor = torch.from_numpy(face_embeddings).to(device=self.device, dtype=self.dtype)
840
- face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
841
- face_proj_embeds = self.image_proj_model(face_emb_tensor)
842
-
843
- boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
844
- face_proj_embeds = face_proj_embeds * boosted_scale
845
-
846
- print(f" - Face embedding: {face_proj_embeds.shape}, Scale: {boosted_scale:.2f}")
847
-
848
- # --- START FIX: Your padding solution ---
849
- # This fixes the "109 vs 77" error
850
- if 'prompt_embeds' in pipe_kwargs:
851
- original_embeds = pipe_kwargs['prompt_embeds']
852
-
853
- # Concatenate face embeddings to POSITIVE prompt
854
- combined_embeds = torch.cat([original_embeds, face_proj_embeds], dim=1)
855
- pipe_kwargs['prompt_embeds'] = combined_embeds
856
-
857
- # CRITICAL: Pad negative_prompt_embeds by the same amount
858
- if 'negative_prompt_embeds' in pipe_kwargs:
859
- negative_embeds = pipe_kwargs['negative_prompt_embeds']
860
- # Create zero padding [1, 16, 2048]
861
- neg_padding = torch.zeros(
862
- (
863
- negative_embeds.shape[0], # 1
864
- face_proj_embeds.shape[1], # 16
865
- negative_embeds.shape[2], # 2048
866
- ),
867
- device=negative_embeds.device,
868
- dtype=negative_embeds.dtype
869
- )
870
- # Concatenate zero padding to NEGATIVE prompt
871
- pipe_kwargs['negative_prompt_embeds'] = torch.cat([negative_embeds, neg_padding], dim=1)
872
- print(f" [OK] Negative prompt padded to match: {pipe_kwargs['negative_prompt_embeds'].shape}")
873
-
874
- print(f" [OK] Face embeddings concatenated successfully! Prompt: {combined_embeds.shape}")
875
- else:
876
- print(f" [WARNING] Can't concatenate - no prompt_embeds (use Cappella)")
877
- # --- END FIX 2 ---
878
 
879
  elif has_detected_faces:
880
  print(" Face detected but IP-Adapter/embeddings unavailable, using keypoints only")
 
881
 
882
  else:
883
  # No face detected - blank map needed to maintain ControlNet list order
 
1
  """
2
  Generation logic for Pixagram AI Pixel Art Generator
3
+ --- UPGRADED VERSION ---
4
+ - Uses StableDiffusionXLInstantIDImg2ImgPipeline for native InstantID support.
5
+ - Replaces broken 'cappella' encoder with 'Compel' for robust prompt chunking.
6
+ - Fixes LoRA style conflicts by using the correct pipeline architecture.
7
  """
8
  import gc
9
  import torch
 
23
  draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
24
  )
25
  from models import (
26
+ load_face_analysis, load_depth_detector, load_controlnets,
27
  load_sdxl_pipeline, load_loras, setup_ip_adapter,
28
+ # --- START FIX: Import setup_compel ---
29
+ setup_compel,
30
  # --- END FIX ---
31
  setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
32
  load_openpose_detector, load_mediapipe_face_detector
 
75
  self.instantid_enabled = instantid_success
76
  self.models_loaded['instantid'] = instantid_success
77
 
78
+ # --- FIX: Image encoder is loaded by pipeline ---
79
+ self.image_encoder = None
80
+ # --- END FIX ---
 
 
81
 
82
  # --- FIX START: Robust ControlNet Loading ---
83
  # Determine which controlnets to use
 
124
  self.models_loaded['lora'] = lora_success
125
 
126
  # Setup IP-Adapter
127
+ if self.instantid_active:
128
+ # The new setup_ip_adapter loads it *into* the pipe.
129
+ _ , ip_adapter_success = setup_ip_adapter(self.pipe)
130
  self.models_loaded['ip_adapter'] = ip_adapter_success
131
+ self.image_proj_model = None # No longer managed here
132
  else:
133
+ print("[INFO] Face preservation: IP-Adapter disabled (InstantID model failed)")
134
  self.models_loaded['ip_adapter'] = False
135
  self.image_proj_model = None
136
 
137
+ # --- START FIX: Setup Compel ---
138
+ self.compel, self.use_compel = setup_compel(self.pipe)
139
  # --- END FIX ---
140
 
141
  # Setup LCM scheduler
 
186
 
187
  print("=== UPGRADE VERIFICATION ===")
188
  try:
189
+ # --- FIX: Check if the correct pipeline is loaded ---
190
+ correct_pipeline = "StableDiffusionXLInstantIDImg2ImgPipeline"
191
+ pipeline_class_name = self.pipe.__class__.__name__
192
+ pipeline_check = correct_pipeline in pipeline_class_name
193
 
194
+ print(f"Pipeline Type: {pipeline_class_name}")
195
+ if pipeline_check:
196
+ print("[SUCCESS] Correct InstantID pipeline is active.")
197
+ else:
198
+ print(f"[WARNING] Incorrect pipeline active. Expected {correct_pipeline}")
199
 
200
+ compel_check = hasattr(self, 'compel') and self.compel is not None
201
+ print(f"Prompt Encoder: {'[OK] Compel' if compel_check else '[WARNING] Compel not loaded'}")
202
+ # --- END FIX ---
203
 
 
 
 
 
 
 
 
204
  except Exception as e:
205
  print(f"[INFO] Verification skipped: {e}")
206
  print("============================\n")
 
642
  guidance_scale = adaptive_params['guidance_scale']
643
  lora_scale = adaptive_params['lora_scale']
644
 
645
+ # --- FIX: Use raw embedding as required by InstantID pipeline ---
646
+ face_embeddings = face.normed_embedding
647
+ face_crop_enhanced = None # Not needed by this pipeline
648
+ # --- END FIX ---
649
 
650
  # Extract face crop
651
  bbox = face.bbox.astype(int)
652
  x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
653
  face_bbox_original = [x1, y1, x2, y2]
654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  # Draw keypoints
656
  face_kps = face.kps
657
  face_kps_image = draw_kps(resized_image, face_kps)
 
661
  facial_attrs = get_facial_attributes(face)
662
 
663
  # Update prompt with detected attributes
664
+ prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD.get(lora_choice, ""))
665
 
666
  # Legacy output for compatibility
667
  age = facial_attrs['age']
 
670
 
671
  gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
672
  print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
673
+ print(f"Face crop size: N/A, enhanced: N/A")
674
  else:
675
  print("✗ InsightFace found no faces")
676
 
 
729
  if adapter_name != "none" and self.loaded_loras.get(adapter_name, False):
730
  try:
731
  self.pipe.set_adapters([adapter_name], adapter_weights=[lora_scale])
732
+ # --- FIX: Fuse LoRA weights for correct interaction with IP-Adapter ---
733
+ self.pipe.fuse_lora(lora_scale=lora_scale, adapter_names=[adapter_name])
734
+ print(f"LORA: Fused adapter '{adapter_name}' with scale: {lora_scale}")
735
  except Exception as e:
736
+ print(f"Could not set/fuse LORA adapter '{adapter_name}': {e}")
737
+ self.pipe.unfuse_lora()
738
  self.pipe.set_adapters([]) # Disable LORAs if setting failed
739
  else:
740
  if adapter_name == "none":
741
  print("LORAs disabled by user choice.")
742
  else:
743
  print(f"LORA '{adapter_name}' not loaded or available, disabling LORAs.")
744
+ # --- FIX: Unfuse any previously fused LoRAs ---
745
+ self.pipe.unfuse_lora()
746
  self.pipe.set_adapters([]) # Disable all LORAs
747
 
748
 
 
766
 
767
  pipe_kwargs["generator"] = generator
768
 
769
+ # --- START FIX: Use Compel ---
770
+ if self.use_compel and self.compel is not None:
771
  try:
772
+ print("Encoding prompts with Compel...")
773
 
774
+ # Encode positive prompt
775
+ conditioning, pooled = self.compel(prompt)
776
+ pipe_kwargs["prompt_embeds"] = conditioning
777
+ pipe_kwargs["pooled_prompt_embeds"] = pooled
778
 
779
+ # Encode negative prompt
780
+ if not negative_prompt or not negative_prompt.strip():
781
+ negative_prompt = "" # Compel must encode something
 
 
782
 
783
+ negative_conditioning, negative_pooled = self.compel(negative_prompt)
784
+ pipe_kwargs["negative_prompt_embeds"] = negative_conditioning
785
+ pipe_kwargs["negative_pooled_prompt_embeds"] = negative_pooled
786
+
787
+ print(f"[OK] Compel encoded - Prompt: {conditioning.shape}")
788
+
789
  except Exception as e:
790
+ print(f"Compel encoding failed, using standard prompts: {e}")
791
  traceback.print_exc()
792
  pipe_kwargs["prompt"] = prompt
793
  pipe_kwargs["negative_prompt"] = negative_prompt
794
  else:
795
+ print("[WARNING] Compel not found, using standard prompt encoding.")
796
  pipe_kwargs["prompt"] = prompt
797
  pipe_kwargs["negative_prompt"] = negative_prompt
798
  # --- END FIX ---
 
825
  conditioning_scales.append(identity_control_scale)
826
  scale_debug_str.append(f"Identity: {identity_control_scale:.2f}")
827
 
828
+ # --- START FIX: Pass raw face embedding to pipeline ---
829
+ if face_embeddings is not None and self.models_loaded.get('ip_adapter', False):
830
+ print(f"Adding InstantID face embeddings (raw)...")
831
 
832
+ # The pipeline expects the raw [1, 512] embedding
833
+ face_emb_tensor = torch.from_numpy(face_embeddings).to(device=self.device, dtype=self.dtype)
834
+ pipe_kwargs["image_embeds"] = face_emb_tensor
835
+
836
+ # Set the IP-Adapter scale (face preservation)
837
+ self.pipe.set_ip_adapter_scale(identity_preservation)
838
+ print(f" - IP-Adapter scale set to: {identity_preservation:.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
 
840
  elif has_detected_faces:
841
  print(" Face detected but IP-Adapter/embeddings unavailable, using keypoints only")
842
+ # --- END FIX ---
843
 
844
  else:
845
  # No face detected - blank map needed to maintain ControlNet list order