Yinhong Liu commited on
Commit
eab8699
·
1 Parent(s): 4dbbaff
Files changed (1) hide show
  1. app.py +6 -8
app.py CHANGED
@@ -32,16 +32,14 @@ MODEL_OPTIONS = {
32
 
33
  def load_model(model_choice):
34
  model_repo_id = MODEL_OPTIONS[model_choice]
35
- # pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
36
- if model_choice == "Sana":
37
  pipe = SanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
38
- elif model_choice == "SD3":
39
- pipe = StableDiffusion3Pipeline.from_pretrained(
40
- model_repo_id, torch_dtype=torch_dtype
41
- )
42
- else:
43
  pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
44
-
 
45
  pipe = pipe.to(device)
46
  return pipe
47
 
 
32
 
33
  def load_model(model_choice):
34
  model_repo_id = MODEL_OPTIONS[model_choice]
35
+ if "Sana" in model_choice:
 
36
  pipe = SanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
37
+ elif "SD3" in model_choice:
38
+ pipe = StableDiffusion3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
39
+ elif "Flux" in model_choice:
 
 
40
  pipe = FluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
41
+ else:
42
+ raise ValueError(f"Unknown model type for: {model_choice}")
43
  pipe = pipe.to(device)
44
  return pipe
45