Spaces:
Configuration error
Configuration error
Commit
·
c6dd45b
1
Parent(s):
c655f3a
MPS support
Browse files
app.py
CHANGED
|
@@ -17,18 +17,27 @@ from share_btn import community_icon_html, loading_icon_html, share_js
|
|
| 17 |
# load pipelines
|
| 18 |
# sd_model_id = "runwayml/stable-diffusion-v1-5"
|
| 19 |
sd_model_id = "stabilityai/stable-diffusion-2-1-base"
|
| 20 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id, subfolder="scheduler"
|
| 24 |
, algorithm_type="sde-dpmsolver++", solver_order=2)
|
| 25 |
|
| 26 |
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 27 |
-
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=
|
| 28 |
|
| 29 |
## IMAGE CPATIONING ##
|
| 30 |
def caption_image(input_image):
|
| 31 |
-
inputs = blip_processor(images=input_image, return_tensors="pt").to(device,
|
| 32 |
pixel_values = inputs.pixel_values
|
| 33 |
|
| 34 |
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
|
@@ -228,7 +237,7 @@ def randomize_seed_fn(seed, is_random):
|
|
| 228 |
|
| 229 |
def seed_everything(seed):
|
| 230 |
torch.manual_seed(seed)
|
| 231 |
-
torch.cuda.manual_seed(seed)
|
| 232 |
random.seed(seed)
|
| 233 |
np.random.seed(seed)
|
| 234 |
|
|
@@ -902,4 +911,4 @@ with gr.Blocks(css="style.css") as demo:
|
|
| 902 |
)
|
| 903 |
|
| 904 |
demo.queue(default_concurrency_limit=1)
|
| 905 |
-
demo.launch()
|
|
|
|
| 17 |
# load pipelines
|
| 18 |
# sd_model_id = "runwayml/stable-diffusion-v1-5"
|
| 19 |
sd_model_id = "stabilityai/stable-diffusion-2-1-base"
|
| 20 |
+
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
+
|
| 22 |
+
torch_dtype = torch.float16
|
| 23 |
+
if torch.cuda.is_available():
|
| 24 |
+
device = "cuda"
|
| 25 |
+
elif torch.backends.mps.is_available():
|
| 26 |
+
device = "mps"
|
| 27 |
+
torch_dtype = torch.float32
|
| 28 |
+
else:
|
| 29 |
+
device = "cpu"
|
| 30 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype)
|
| 31 |
+
pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(sd_model_id,vae=vae,torch_dtype=torch_dtype, safety_checker=None, requires_safety_checker=False).to(device)
|
| 32 |
pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id, subfolder="scheduler"
|
| 33 |
, algorithm_type="sde-dpmsolver++", solver_order=2)
|
| 34 |
|
| 35 |
blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
| 36 |
+
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base",torch_dtype=torch_dtype).to(device)
|
| 37 |
|
| 38 |
## IMAGE CPATIONING ##
|
| 39 |
def caption_image(input_image):
|
| 40 |
+
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch_dtype)
|
| 41 |
pixel_values = inputs.pixel_values
|
| 42 |
|
| 43 |
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
|
|
|
| 237 |
|
| 238 |
def seed_everything(seed):
|
| 239 |
torch.manual_seed(seed)
|
| 240 |
+
# torch.cuda.manual_seed(seed)
|
| 241 |
random.seed(seed)
|
| 242 |
np.random.seed(seed)
|
| 243 |
|
|
|
|
| 911 |
)
|
| 912 |
|
| 913 |
demo.queue(default_concurrency_limit=1)
|
| 914 |
+
demo.launch()
|