EuuIia commited on
Commit
eee1c29
·
verified ·
1 Parent(s): 53be96b

Update api/ltx_server.py

Browse files
Files changed (1) hide show
  1. api/ltx_server.py +13 -4
api/ltx_server.py CHANGED
@@ -770,7 +770,6 @@ class VideoService:
770
  print("".join(traceback.format_exception(type(e), e, e.__traceback__)))
771
  raise
772
 
773
-
774
  # ltx_server.py
775
 
776
  def generate(
@@ -837,6 +836,17 @@ class VideoService:
837
  conditioning_items.append(ConditioningItem(end_tensor, last_frame_index, float(end_image_weight)))
838
  print(f"[DEBUG] Conditioning items: {len(conditioning_items)}")
839
 
 
 
 
 
 
 
 
 
 
 
 
840
  call_kwargs = {
841
  "prompt": prompt,
842
  "negative_prompt": negative_prompt,
@@ -857,7 +867,7 @@ class VideoService:
857
  "mixed_precision": (self.config.get("precision") == "mixed_precision"),
858
  "offload_to_cpu": False,
859
  "enhance_prompt": False,
860
- "skip_layer_strategy": SkipLayerStrategy[self.config.get("stg_mode", "AttentionValues")],
861
  }
862
  print(f"[DEBUG] output_type={call_kwargs['output_type']} skip_layer_strategy={call_kwargs['skip_layer_strategy']}")
863
 
@@ -998,7 +1008,7 @@ class VideoService:
998
  else:
999
  lat_b1, lat_b2 = None, None
1000
 
1001
- latents_parts = [p for p in [lat_a1, lat_a2, lat_b1, lat_b2] if p is not None]
1002
  if not latents_parts:
1003
  latents_parts = [latents_cpu]
1004
 
@@ -1075,6 +1085,5 @@ class VideoService:
1075
  except Exception as e:
1076
  print(f"[DEBUG] finalize() no finally falhou: {e}")
1077
 
1078
-
1079
  print("Criando instância do VideoService. O carregamento do modelo começará agora...")
1080
  video_generation_service = VideoService()
 
770
  print("".join(traceback.format_exception(type(e), e, e.__traceback__)))
771
  raise
772
 
 
773
  # ltx_server.py
774
 
775
  def generate(
 
836
  conditioning_items.append(ConditioningItem(end_tensor, last_frame_index, float(end_image_weight)))
837
  print(f"[DEBUG] Conditioning items: {len(conditioning_items)}")
838
 
839
+ # --- LÓGICA DE CONVERSÃO DO STG_MODE ---
840
+ stg_mode_str = self.config.get("stg_mode", "attention_values")
841
+ stg_mode_map = {
842
+ "attention_values": "AttentionValues",
843
+ "attention_skip": "AttentionSkip",
844
+ "residual": "Residual",
845
+ "transformer_block": "TransformerBlock"
846
+ }
847
+ stg_mode_enum_key = stg_mode_map.get(stg_mode_str.lower(), "AttentionValues")
848
+ # --- FIM DA LÓGICA DE CONVERSÃO ---
849
+
850
  call_kwargs = {
851
  "prompt": prompt,
852
  "negative_prompt": negative_prompt,
 
867
  "mixed_precision": (self.config.get("precision") == "mixed_precision"),
868
  "offload_to_cpu": False,
869
  "enhance_prompt": False,
870
+ "skip_layer_strategy": SkipLayerStrategy[stg_mode_enum_key],
871
  }
872
  print(f"[DEBUG] output_type={call_kwargs['output_type']} skip_layer_strategy={call_kwargs['skip_layer_strategy']}")
873
 
 
1008
  else:
1009
  lat_b1, lat_b2 = None, None
1010
 
1011
+ latents_parts = [p for p in [lat_a1, lat_a2, lat_b1, lat_b2] if p is not None and p.shape[2] > 1]
1012
  if not latents_parts:
1013
  latents_parts = [latents_cpu]
1014
 
 
1085
  except Exception as e:
1086
  print(f"[DEBUG] finalize() no finally falhou: {e}")
1087
 
 
1088
  print("Criando instância do VideoService. O carregamento do modelo começará agora...")
1089
  video_generation_service = VideoService()