from aura_sr import AuraSR
import gradio as gr
import spaces  # ZeroGPU helper package; provides the @spaces.GPU decorator


class ZeroGPUAuraSR(AuraSR):
    @classmethod
    def from_pretrained(cls, model_id: str = "fal-ai/AuraSR", use_safetensors: bool = True):
        import json
        import torch
        from pathlib import Path
        from huggingface_hub import snapshot_download

        # Check if model_id is a local file
        if Path(model_id).is_file():
            local_file = Path(model_id)
            if local_file.suffix == '.safetensors':
                use_safetensors = True
            elif local_file.suffix == '.ckpt':
                use_safetensors = False
            else:
                raise ValueError(
                    f"Unsupported file format: {local_file.suffix}. "
                    f"Please use .safetensors or .ckpt files."
                )

            # For local files, the config must be provided separately
            # as a config.json next to the checkpoint
            config_path = local_file.with_name('config.json')
            if not config_path.exists():
                raise FileNotFoundError(
                    f"Config file not found: {config_path}. "
                    f"When loading from a local file, ensure that 'config.json' "
                    f"is present in the same directory as '{local_file.name}'. "
                    f"If you're trying to load a model from Hugging Face, "
                    f"please provide the model ID instead of a file path."
                )

            config = json.loads(config_path.read_text())
            hf_model_path = local_file.parent
        else:
            hf_model_path = Path(snapshot_download(model_id))
            config = json.loads((hf_model_path / "config.json").read_text())

        model = cls(config)

        # Resolve the checkpoint location: the file itself when loading locally,
        # otherwise the weights file inside the downloaded snapshot
        if use_safetensors:
            try:
                from safetensors.torch import load_file
            except ImportError:
                raise ImportError(
                    "The safetensors library is not installed. "
                    "Please install it with `pip install safetensors` "
                    "or use `use_safetensors=False` to load the model with PyTorch."
                )
            checkpoint_path = model_id if Path(model_id).is_file() else hf_model_path / "model.safetensors"
            checkpoint = load_file(checkpoint_path)
        else:
            checkpoint_path = model_id if Path(model_id).is_file() else hf_model_path / "model.ckpt"
            # Map to CPU so loading works before a GPU is attached (ZeroGPU)
            checkpoint = torch.load(checkpoint_path, map_location='cpu')

        model.upsampler.load_state_dict(checkpoint, strict=True)
        return model
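
    # A minimal sketch of loading from a local checkpoint (hypothetical path);
    # the loader above expects a config.json sitting next to the weights file:
    #
    #     local_model = ZeroGPUAuraSR.from_pretrained("/path/to/model.safetensors")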


aura_sr = ZeroGPUAuraSR.from_pretrained("fal/AuraSR-v2")
aura_sr_v1 = ZeroGPUAuraSR.from_pretrained("fal-ai/AuraSR")


# ZeroGPU Spaces only attach a GPU inside functions decorated with @spaces.GPU;
# without the decorator the upscaler would run on CPU.
@spaces.GPU
def predict(img, model_selection):
    return {'v1': aura_sr_v1, 'v2': aura_sr}[model_selection].upscale_4x(img)


demo = gr.Interface(
    predict,
    # AuraSR's upscale_4x expects a PIL image, so request one explicitly
    inputs=[gr.Image(type='pil'), gr.Dropdown(value='v2', choices=['v1', 'v2'])],
    outputs=gr.Image()
)
demo.launch()
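
# Hypothetical smoke test without the UI (run locally, not on ZeroGPU),
# assuming a test image exists at "input.png":
#
#     from PIL import Image
#     img = Image.open("input.png")   # assumed test image
#     out = aura_sr.upscale_4x(img)   # returns a 4x-upscaled PIL image
#     out.save("output_4x.png")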