primerz committed
Commit 8967a30 · verified · 1 Parent(s): 667cce7

Update app.py

Files changed (1)
  1. app.py +41 -22
app.py CHANGED
@@ -13,13 +13,16 @@ from PIL import Image
 import numpy as np
 import cv2
 from transformers import pipeline as transformers_pipeline
+from huggingface_hub import hf_hub_download
 import os
 
-# Device configuration
+# Configuration
+MODEL_REPO = "primerz/pixagram"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if device == "cuda" else torch.float32
 
 print(f"Using device: {device}")
+print(f"Loading models from: {MODEL_REPO}")
 
 class RetroArtConverter:
     def __init__(self):
@@ -42,16 +45,22 @@ class RetroArtConverter:
             torch_dtype=self.dtype
         ).to(self.device)
 
-        # Load custom VAE
-        print("Loading custom VAE (pixelate)...")
-        vae_path = "./models/vae/pixelate.safetensors"
-        if os.path.exists(vae_path):
+        # Load custom VAE from HuggingFace Hub
+        print("Loading custom VAE (pixelate) from HuggingFace Hub...")
+        try:
+            vae_path = hf_hub_download(
+                repo_id=MODEL_REPO,
+                filename="pixelate.safetensors",
+                repo_type="model"
+            )
             self.vae = AutoencoderKL.from_single_file(
                 vae_path,
                 torch_dtype=self.dtype
            ).to(self.device)
-        else:
-            print("Warning: Custom VAE not found, using default SDXL VAE")
+            print("✓ Custom VAE loaded successfully")
+        except Exception as e:
+            print(f"Warning: Could not load custom VAE: {e}")
+            print("Using default SDXL VAE")
             self.vae = AutoencoderKL.from_pretrained(
                 "madebyollin/sdxl-vae-fp16-fix",
                 torch_dtype=self.dtype
@@ -64,11 +73,14 @@ class RetroArtConverter:
             model="Intel/dpt-hybrid-midas"
         )
 
-        # Load SDXL base model with custom checkpoint
-        print("Loading SDXL model (horizon)...")
-        model_path = "./models/checkpoints/horizon.safetensors"
-
-        if os.path.exists(model_path):
+        # Load SDXL checkpoint from HuggingFace Hub
+        print("Loading SDXL checkpoint (horizon) from HuggingFace Hub...")
+        try:
+            model_path = hf_hub_download(
+                repo_id=MODEL_REPO,
+                filename="horizon.safetensors",
+                repo_type="model"
+            )
             self.pipe = StableDiffusionXLControlNetPipeline.from_single_file(
                 model_path,
                 controlnet=self.controlnet_depth,
@@ -76,8 +88,10 @@ class RetroArtConverter:
                 torch_dtype=self.dtype,
                 use_safetensors=True
             ).to(self.device)
-        else:
-            print("Warning: Custom checkpoint not found, using default SDXL")
+            print("✓ Custom checkpoint loaded successfully")
+        except Exception as e:
+            print(f"Warning: Could not load custom checkpoint: {e}")
+            print("Using default SDXL")
             self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
                 "stabilityai/stable-diffusion-xl-base-1.0",
                 controlnet=self.controlnet_depth,
@@ -86,14 +100,19 @@ class RetroArtConverter:
                 use_safetensors=True
             ).to(self.device)
 
-        # Load custom LORA
-        print("Loading LORA (retroart)...")
-        lora_path = "./models/lora/retroart.safetensors"
-        if os.path.exists(lora_path):
+        # Load LORA from HuggingFace Hub
+        print("Loading LORA (retroart) from HuggingFace Hub...")
+        try:
+            lora_path = hf_hub_download(
+                repo_id=MODEL_REPO,
+                filename="retroart.safetensors",
+                repo_type="model"
+            )
             self.pipe.load_lora_weights(lora_path)
-            print("LORA loaded successfully")
-        else:
-            print("Warning: Custom LORA not found")
+            print("LORA loaded successfully")
+        except Exception as e:
+            print(f"Warning: Could not load LORA: {e}")
+            print("Running without LORA")
 
         # Optimize pipeline
         self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
@@ -331,4 +350,4 @@ if __name__ == "__main__":
         server_port=7860,
         share=False,
         show_api=True # Enable API
-    )
+    )
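
For context on the pattern this commit switches to: hf_hub_download fetches a single file from a Hub repo into the local Hugging Face cache and returns the cached path, so the loaders above can keep using from_single_file and load_lora_weights unchanged. Below is a minimal standalone sketch of that download-with-fallback flow; the fetch_weight helper and its messages are illustrative, not part of app.py, and it assumes the three safetensors files named in the diff are present in the primerz/pixagram repo.

from huggingface_hub import hf_hub_download

MODEL_REPO = "primerz/pixagram"  # repo introduced by this commit

def fetch_weight(filename: str):
    """Hypothetical helper: download one weight file from the Hub and return
    its local cache path, or None if the download fails (missing file,
    network error, gated repo)."""
    try:
        return hf_hub_download(
            repo_id=MODEL_REPO,
            filename=filename,
            repo_type="model",
        )
    except Exception as e:  # broad catch mirrors the fallback style in app.py
        print(f"Warning: could not fetch {filename}: {e}")
        return None

if __name__ == "__main__":
    # Files are cached under ~/.cache/huggingface/hub, so repeated Space
    # restarts reuse the download instead of fetching the weights again.
    for name in ("pixelate.safetensors", "horizon.safetensors", "retroart.safetensors"):
        print(name, "->", fetch_weight(name))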