PandaArtStation commited on
Commit
8a2d3f3
·
verified ·
1 Parent(s): b8ce5c8

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +135 -64
models.py CHANGED
@@ -1,10 +1,9 @@
1
  import torch
2
  from diffusers import (
3
- StableDiffusionImg2ImgPipeline,
4
  StableDiffusionInpaintPipeline,
5
- DDIMScheduler,
6
- PNDMScheduler,
7
- EulerDiscreteScheduler
8
  )
9
  from PIL import Image, ImageFilter, ImageEnhance
10
  import numpy as np
@@ -14,24 +13,27 @@ class InteriorDesignerPro:
14
  def __init__(self):
15
  self.device = torch.device("cuda")
16
  self.model_name = "RealVisXL V4.0"
17
-
18
  # Проверка GPU
19
  gpu_name = torch.cuda.get_device_name(0)
20
- self.is_powerful_gpu = any(gpu in gpu_name for gpu in ['A100', 'H100', 'RTX 4090', 'RTX 3090'])
21
-
22
- # Основная модель - RealVis V4 для фотореалистичных интерьеров
23
  print(f"Loading {self.model_name} on {gpu_name}...")
24
- self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
25
  "SG161222/RealVisXL_V4.0",
26
  torch_dtype=torch.float16,
27
- safety_checker=None,
28
- requires_safety_checker=False,
29
- local_files_only=False
30
  ).to(self.device)
31
-
 
 
 
 
32
  # Настройка scheduler для лучшего качества
33
  self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
34
-
35
  # Inpainting модель для удаления объектов
36
  try:
37
  self.inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
@@ -47,118 +49,130 @@ class InteriorDesignerPro:
47
  print(f"Warning: Could not load inpainting model: {e}")
48
  print("Using img2img as fallback for object removal")
49
  self.inpaint_pipe = None
50
-
51
  def apply_style_pro(self, image, style_name, room_type, strength=0.75, quality="balanced"):
52
  """Применение стиля к изображению с учетом качества"""
53
  from design_styles import DESIGN_STYLES
54
-
55
  style = DESIGN_STYLES.get(style_name, DESIGN_STYLES["Современный минимализм"])
56
-
57
  # Настройки качества
58
  quality_settings = {
59
  "fast": {"steps": 20, "guidance": 7.5},
60
  "balanced": {"steps": 35, "guidance": 8.5},
61
  "ultra": {"steps": 50, "guidance": 10}
62
  }
63
-
64
  settings = quality_settings.get(quality, quality_settings["balanced"])
65
-
66
  # Генерация промпта с учетом комнаты
67
  room_specific = style.get("room_specific", {}).get(room_type, "")
68
- full_prompt = f"{style['prompt']}, {room_specific}, {room_type} interior design, professional photo, high quality, 8k"
69
-
70
- # Генерация
71
  result = self.pipe(
72
  prompt=full_prompt,
73
- negative_prompt=style.get("negative", "low quality, blurry"),
 
 
74
  image=image,
75
  strength=strength,
76
  num_inference_steps=settings["steps"],
77
- guidance_scale=settings["guidance"]
 
 
 
78
  ).images[0]
79
-
80
  return result
81
-
82
  def create_variations(self, image, num_variations=4):
83
  """Создание вариаций дизайна"""
84
  variations = []
85
  base_seed = torch.randint(0, 1000000, (1,)).item()
86
-
87
  for i in range(num_variations):
88
  torch.manual_seed(base_seed + i)
89
-
90
  var = self.pipe(
91
- prompt="interior design variation, same style, different details",
 
92
  image=image,
93
  strength=0.4 + (i * 0.05),
94
  num_inference_steps=30,
95
- guidance_scale=7.5
 
 
96
  ).images[0]
97
-
98
  variations.append(var)
99
-
100
  return variations
101
-
102
  def create_hdr_lighting(self, image, intensity=0.3):
103
  """Улучшение освещения в стиле HDR"""
104
  # Конвертируем в numpy
105
  img_array = np.array(image)
106
-
107
  # Применяем CLAHE для улучшения контраста
108
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
109
  l, a, b = cv2.split(lab)
110
-
111
  clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
112
  l_clahe = clahe.apply(l)
113
-
114
  enhanced_lab = cv2.merge([l_clahe, a, b])
115
  enhanced_rgb = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
116
-
117
  # Смешиваем с оригиналом
118
  result = cv2.addWeighted(img_array, 1-intensity, enhanced_rgb, intensity, 0)
119
-
120
  return Image.fromarray(result)
121
-
122
  def enhance_details(self, image):
123
  """Улучшение деталей изображения"""
124
  # Увеличиваем резкость
125
  enhancer = ImageEnhance.Sharpness(image)
126
  sharp = enhancer.enhance(1.5)
127
-
128
  # Немного увеличиваем контраст
129
  enhancer = ImageEnhance.Contrast(sharp)
130
  contrast = enhancer.enhance(1.1)
131
-
132
  return contrast
133
-
134
  def change_element(self, image, element, value, strength=0.7):
135
  """Изменение отдельного элемента интерьера"""
136
  from design_styles import ROOM_ELEMENTS
137
-
138
  element_info = ROOM_ELEMENTS.get(element, {})
139
  prompt_add = element_info.get("prompt_add", element.lower())
140
-
141
- prompt = f"interior with {value} {prompt_add}, professional photo"
142
  negative = f"old {element}, damaged, ugly"
143
-
144
  result = self.pipe(
145
  prompt=prompt,
 
146
  negative_prompt=negative,
 
147
  image=image,
148
  strength=strength,
149
  num_inference_steps=40,
150
- guidance_scale=8.0
 
 
151
  ).images[0]
152
-
153
- return result
154
 
 
 
155
  def create_style_comparison(self, image, styles, quality="fast"):
156
  """Создание сравнения стилей"""
157
  results = []
158
-
159
  # Настройки для быстрой генерации
160
  steps = 20 if quality == "fast" else 35
161
-
162
  for style in styles:
163
  styled = self.apply_style_pro(
164
  image,
@@ -168,23 +182,23 @@ class InteriorDesignerPro:
168
  quality=quality
169
  )
170
  results.append((style, styled))
171
-
172
  return results
173
 
174
 
175
  class ObjectRemover:
176
  """Класс для удаления объектов"""
177
-
178
  def __init__(self, inpaint_pipe):
179
  self.pipe = inpaint_pipe
180
  self.device = torch.device("cuda")
181
-
182
  def remove_objects(self, image, mask):
183
  """Удаление объектов с изображения"""
184
  if self.pipe is None:
185
  # Fallback на простое заполнение
186
  return self.simple_inpaint(image, mask)
187
-
188
  # Используем inpainting pipeline
189
  result = self.pipe(
190
  prompt="empty room interior, clean wall, seamless texture",
@@ -195,34 +209,91 @@ class ObjectRemover:
195
  num_inference_steps=50,
196
  guidance_scale=7.5
197
  ).images[0]
198
-
199
  return result
200
-
201
  def simple_inpaint(self, image, mask):
202
  """Простое заполнение через OpenCV"""
203
  img_array = np.array(image)
204
  mask_array = np.array(mask.convert('L'))
205
-
206
  # Инпейнтинг через OpenCV
207
  result = cv2.inpaint(img_array, mask_array, 3, cv2.INPAINT_TELEA)
208
-
209
  return Image.fromarray(result)
210
-
211
  def generate_mask_from_text(self, image, text_description, precision=0.3):
212
  """Генерация маски на основе текстового описания"""
213
  # Простая маска в центре (заглушка)
214
  # В реальности тут должен быть CLIP или SAM
215
  width, height = image.size
216
  mask = Image.new('L', (width, height), 0)
217
-
218
  # Создаем маску в центре
219
  center_x, center_y = width // 2, height // 2
220
  radius = int(min(width, height) * precision)
221
-
222
  # Рисуем круг
223
- import ImageDraw
224
  draw = ImageDraw.Draw(mask)
225
  draw.ellipse([center_x - radius, center_y - radius,
226
  center_x + radius, center_y + radius], fill=255)
227
-
228
  return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from diffusers import (
3
+ StableDiffusionXLImg2ImgPipeline,
4
  StableDiffusionInpaintPipeline,
5
+ EulerDiscreteScheduler,
6
+ DPMSolverMultistepScheduler
 
7
  )
8
  from PIL import Image, ImageFilter, ImageEnhance
9
  import numpy as np
 
13
  def __init__(self):
14
  self.device = torch.device("cuda")
15
  self.model_name = "RealVisXL V4.0"
16
+
17
  # Проверка GPU
18
  gpu_name = torch.cuda.get_device_name(0)
19
+ self.is_powerful_gpu = any(gpu in gpu_name for gpu in ['A100', 'H100', 'RTX 4090', 'RTX 3090', 'T4'])
20
+
21
+ # Основная модель - RealVisXL V4 для фотореалистичных интерьеров
22
  print(f"Loading {self.model_name} on {gpu_name}...")
23
+ self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
24
  "SG161222/RealVisXL_V4.0",
25
  torch_dtype=torch.float16,
26
+ use_safetensors=True,
27
+ variant="fp16"
 
28
  ).to(self.device)
29
+
30
+ # Включаем оптимизации для SDXL
31
+ self.pipe.enable_model_cpu_offload()
32
+ self.pipe.enable_vae_slicing()
33
+
34
  # Настройка scheduler для лучшего качества
35
  self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config)
36
+
37
  # Inpainting модель для удаления объектов
38
  try:
39
  self.inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
 
49
  print(f"Warning: Could not load inpainting model: {e}")
50
  print("Using img2img as fallback for object removal")
51
  self.inpaint_pipe = None
52
+
53
  def apply_style_pro(self, image, style_name, room_type, strength=0.75, quality="balanced"):
54
  """Применение стиля к изображению с учетом качества"""
55
  from design_styles import DESIGN_STYLES
56
+
57
  style = DESIGN_STYLES.get(style_name, DESIGN_STYLES["Современный минимализм"])
58
+
59
  # Настройки качества
60
  quality_settings = {
61
  "fast": {"steps": 20, "guidance": 7.5},
62
  "balanced": {"steps": 35, "guidance": 8.5},
63
  "ultra": {"steps": 50, "guidance": 10}
64
  }
65
+
66
  settings = quality_settings.get(quality, quality_settings["balanced"])
67
+
68
  # Генерация промпта с учетом комнаты
69
  room_specific = style.get("room_specific", {}).get(room_type, "")
70
+ full_prompt = f"{style['prompt']}, {room_specific}, {room_type} interior design, professional photo, high quality, 8k, photorealistic"
71
+
72
+ # Генерация с параметрами для SDXL
73
  result = self.pipe(
74
  prompt=full_prompt,
75
+ prompt_2=full_prompt, # SDXL требует второй промпт
76
+ negative_prompt=style.get("negative", "low quality, blurry, deformed"),
77
+ negative_prompt_2=style.get("negative", "low quality, blurry, deformed"),
78
  image=image,
79
  strength=strength,
80
  num_inference_steps=settings["steps"],
81
+ guidance_scale=settings["guidance"],
82
+ # SDXL специфичные параметры
83
+ original_size=(1024, 1024),
84
+ target_size=(1024, 1024)
85
  ).images[0]
86
+
87
  return result
88
+
89
  def create_variations(self, image, num_variations=4):
90
  """Создание вариаций дизайна"""
91
  variations = []
92
  base_seed = torch.randint(0, 1000000, (1,)).item()
93
+
94
  for i in range(num_variations):
95
  torch.manual_seed(base_seed + i)
96
+
97
  var = self.pipe(
98
+ prompt="interior design variation, same style, different details, photorealistic",
99
+ prompt_2="interior design variation, same style, different details, photorealistic",
100
  image=image,
101
  strength=0.4 + (i * 0.05),
102
  num_inference_steps=30,
103
+ guidance_scale=7.5,
104
+ original_size=(1024, 1024),
105
+ target_size=(1024, 1024)
106
  ).images[0]
107
+
108
  variations.append(var)
109
+
110
  return variations
111
+
112
  def create_hdr_lighting(self, image, intensity=0.3):
113
  """Улучшение освещения в стиле HDR"""
114
  # Конвертируем в numpy
115
  img_array = np.array(image)
116
+
117
  # Применяем CLAHE для улучшения контраста
118
  lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2LAB)
119
  l, a, b = cv2.split(lab)
120
+
121
  clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8))
122
  l_clahe = clahe.apply(l)
123
+
124
  enhanced_lab = cv2.merge([l_clahe, a, b])
125
  enhanced_rgb = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)
126
+
127
  # Смешиваем с оригиналом
128
  result = cv2.addWeighted(img_array, 1-intensity, enhanced_rgb, intensity, 0)
129
+
130
  return Image.fromarray(result)
131
+
132
  def enhance_details(self, image):
133
  """Улучшение деталей изображения"""
134
  # Увеличиваем резкость
135
  enhancer = ImageEnhance.Sharpness(image)
136
  sharp = enhancer.enhance(1.5)
137
+
138
  # Немного увеличиваем контраст
139
  enhancer = ImageEnhance.Contrast(sharp)
140
  contrast = enhancer.enhance(1.1)
141
+
142
  return contrast
143
+
144
  def change_element(self, image, element, value, strength=0.7):
145
  """Изменение отдельного элемента интерьера"""
146
  from design_styles import ROOM_ELEMENTS
147
+
148
  element_info = ROOM_ELEMENTS.get(element, {})
149
  prompt_add = element_info.get("prompt_add", element.lower())
150
+
151
+ prompt = f"interior with {value} {prompt_add}, professional photo, photorealistic"
152
  negative = f"old {element}, damaged, ugly"
153
+
154
  result = self.pipe(
155
  prompt=prompt,
156
+ prompt_2=prompt,
157
  negative_prompt=negative,
158
+ negative_prompt_2=negative,
159
  image=image,
160
  strength=strength,
161
  num_inference_steps=40,
162
+ guidance_scale=8.0,
163
+ original_size=(1024, 1024),
164
+ target_size=(1024, 1024)
165
  ).images[0]
 
 
166
 
167
+ return result
168
+
169
  def create_style_comparison(self, image, styles, quality="fast"):
170
  """Создание сравнения стилей"""
171
  results = []
172
+
173
  # Настройки для быстрой генерации
174
  steps = 20 if quality == "fast" else 35
175
+
176
  for style in styles:
177
  styled = self.apply_style_pro(
178
  image,
 
182
  quality=quality
183
  )
184
  results.append((style, styled))
185
+
186
  return results
187
 
188
 
189
  class ObjectRemover:
190
  """Класс для удаления объектов"""
191
+
192
  def __init__(self, inpaint_pipe):
193
  self.pipe = inpaint_pipe
194
  self.device = torch.device("cuda")
195
+
196
  def remove_objects(self, image, mask):
197
  """Удаление объектов с изображения"""
198
  if self.pipe is None:
199
  # Fallback на простое заполнение
200
  return self.simple_inpaint(image, mask)
201
+
202
  # Используем inpainting pipeline
203
  result = self.pipe(
204
  prompt="empty room interior, clean wall, seamless texture",
 
209
  num_inference_steps=50,
210
  guidance_scale=7.5
211
  ).images[0]
212
+
213
  return result
214
+
215
  def simple_inpaint(self, image, mask):
216
  """Простое заполнение через OpenCV"""
217
  img_array = np.array(image)
218
  mask_array = np.array(mask.convert('L'))
219
+
220
  # Инпейнтинг через OpenCV
221
  result = cv2.inpaint(img_array, mask_array, 3, cv2.INPAINT_TELEA)
222
+
223
  return Image.fromarray(result)
224
+
225
  def generate_mask_from_text(self, image, text_description, precision=0.3):
226
  """Генерация маски на основе текстового описания"""
227
  # Простая маска в центре (заглушка)
228
  # В реальности тут должен быть CLIP или SAM
229
  width, height = image.size
230
  mask = Image.new('L', (width, height), 0)
231
+
232
  # Создаем маску в центре
233
  center_x, center_y = width // 2, height // 2
234
  radius = int(min(width, height) * precision)
235
+
236
  # Рисуем круг
237
+ from PIL import ImageDraw
238
  draw = ImageDraw.Draw(mask)
239
  draw.ellipse([center_x - radius, center_y - radius,
240
  center_x + radius, center_y + radius], fill=255)
241
+
242
  return mask
243
+
244
+
245
+ # Добавляем метод _create_comparison_grid к классу при импорте
246
+ def _create_comparison_grid(self, images_with_labels):
247
+ """Создает сетку из изображений с подписями"""
248
+ if not images_with_labels:
249
+ return None
250
+
251
+ images = [img for _, img in images_with_labels]
252
+ labels = [label for label, _ in images_with_labels]
253
+
254
+ # Определяем размер сетки
255
+ n = len(images)
256
+ cols = min(3, n) # Максимум 3 колонки
257
+ rows = (n + cols - 1) // cols
258
+
259
+ # Размер одного изображения
260
+ img_width, img_height = images[0].size
261
+ padding = 20
262
+ label_height = 40
263
+
264
+ # Создаем холст
265
+ grid_width = cols * img_width + (cols + 1) * padding
266
+ grid_height = rows * (img_height + label_height) + (rows + 1) * padding
267
+
268
+ grid = Image.new('RGB', (grid_width, grid_height), 'white')
269
+
270
+ # Добавляем изображения
271
+ from PIL import ImageDraw, ImageFont
272
+ draw = ImageDraw.Draw(grid)
273
+
274
+ try:
275
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
276
+ except:
277
+ font = None
278
+
279
+ for idx, (img, label) in enumerate(zip(images, labels)):
280
+ row = idx // cols
281
+ col = idx % cols
282
+
283
+ x = col * img_width + (col + 1) * padding
284
+ y = row * (img_height + label_height) + (row + 1) * padding
285
+
286
+ # Вставляем изображение
287
+ grid.paste(img, (x, y))
288
+
289
+ # Добавляем подпись
290
+ text_x = x + img_width // 2
291
+ text_y = y + img_height + 5
292
+
293
+ draw.text((text_x, text_y), label, fill='black', font=font, anchor='mt')
294
+
295
+ return grid
296
+
297
+ # Патчим класс
298
+ InteriorDesignerPro._create_comparison_grid = _create_comparison_grid
299
+ models.py