Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| from io import BytesIO | |
| import base64 | |
| from functools import partial | |
| from PIL import Image, ImageOps | |
| import gradio as gr | |
| from makeavid_sd.inference import ( | |
| InferenceUNetPseudo3D, | |
| jnp, | |
| SCHEDULERS | |
| ) | |
# Log the XLA client memory settings (or 'NotSet') so they show up in the
# startup logs.
print(os.environ.get('XLA_PYTHON_CLIENT_PREALLOCATE', 'NotSet'))
print(os.environ.get('XLA_PYTHON_CLIENT_ALLOCATOR', 'NotSet'))

# Parameter combinations that have already been passed to `_model.generate`;
# filled in by `generate` and queried by `check_if_compiled`.
_seen_compilations = set()

# The model is created once at import time and shared by all requests.
_model = InferenceUNetPseudo3D(
        model_path = 'TempoFunk/makeavid-sd-jax',
        dtype = jnp.float16,
        hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
)

# Timestamp (UTC, ISO 8601) printed right after model construction.
import datetime
print(datetime.datetime.now(datetime.timezone.utc).isoformat())

# `failed` is compared with `!= False`, so any value that does not compare
# equal to False is treated as a failure and rendered as the error text.
if _model.failed != False:
    trace = f'```{_model.failed}```'
    with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
        exception = gr.Markdown(trace)
    # NOTE(review): once this launch() returns, the script carries on and
    # builds the regular UI with the failed model -- confirm this is intended.
    demo.launch()

# Example presets: one sub-directory of `_expath` per example, each holding a
# `params.json`, an `output.gif` and optionally an `input.png` hint image.
_examples = []
_expath = 'examples'
# A missing examples directory means "no examples" rather than a crash at
# import time.
_example_names = sorted(os.listdir(_expath)) if os.path.isdir(_expath) else []
for x in _example_names:
    params_path = os.path.join(_expath, x, 'params.json')
    # Skip stray entries (plain files, hidden files, incomplete examples)
    # instead of failing on open().
    if not os.path.isfile(params_path):
        continue
    with open(params_path, 'r', encoding = 'utf-8') as f:
        ex = json.load(f)
    input_path = os.path.join(_expath, x, 'input.png')
    # The hint image is optional; None tells the UI that there is none.
    ex['image_input'] = input_path if os.path.isfile(input_path) else None
    ex['image_output'] = os.path.join(_expath, x, 'output.gif')
    _examples.append(ex)

# Animated image formats offered in the UI and accepted by `generate`.
_output_formats = (
        'webp', 'gif'
)
def _image_to_data_uri(image, image_format, **save_kwargs):
    """Encode *image* as *image_format* and return a base64 ``data:image/...`` URI.

    ``save_kwargs`` are forwarded unchanged to ``image.save``.
    """
    with BytesIO() as buffer:
        image.save(buffer, format = image_format, **save_kwargs)
        encoded = base64.b64encode(buffer.getvalue()).decode()
    return f'data:image/{image_format};base64,{encoded}'


# NOTE: the parameters are deliberately left without type hints -- according
# to the original author, gradio misbehaves when the handler is annotated.
def generate(
        prompt = 'An elderly man having a great time in the park.',
        neg_prompt = '',
        hint_image = None,
        inference_steps = 20,
        cfg = 15.0,
        cfg_image = 9.0,
        seed = 0,
        fps = 12,
        num_frames = 24,
        height = 512,
        width = 512,
        scheduler_type = 'dpm',
        output_format = 'gif'
) -> tuple:
    """Run the model once and return the result as base64 data URIs.

    All numeric inputs are clamped to the ranges offered by the UI:
    ``num_frames`` to 2..24, ``inference_steps`` to 2..60, ``height`` and
    ``width`` to 256..576 (then rounded down to a multiple of 64), ``fps`` to
    1..1000, ``cfg`` and ``cfg_image`` to at least 1.0, and ``seed`` to
    0..2**32-2 (negative seeds are mirrored to their absolute value).
    Unknown scheduler names fall back to ``'dpm'`` and unknown output formats
    to ``'gif'``.

    Args:
        prompt: Text prompt; repeated once per model device.
        neg_prompt: Negative text prompt, passed through unchanged.
        hint_image: Optional PIL image; converted to RGB and fitted to
            ``(width, height)`` when necessary.
        inference_steps: Number of inference steps.
        cfg: Guidance scale passed to the model as ``cfg``.
        cfg_image: Guidance scale passed to the model as ``cfg_image``.
        seed: Random seed.
        fps: Playback rate of the encoded animation.
        num_frames: Number of frames requested from the model.
        height: Output height in pixels.
        width: Output width in pixels.
        scheduler_type: Key into ``SCHEDULERS`` (case-insensitive).
        output_format: One of ``_output_formats`` (case-insensitive).

    Returns:
        A 3-tuple of data-URI strings: the animation in ``output_format``,
        the last frame as PNG, and the first frame as PNG.
    """
    num_frames = min(24, max(2, int(num_frames)))
    inference_steps = min(60, max(2, int(inference_steps)))
    height = min(576, max(256, int(height)))
    width = min(576, max(256, int(width)))
    # Round down to a multiple of 64.
    height = (height // 64) * 64
    width = (width // 64) * 64
    cfg = max(cfg, 1.0)
    cfg_image = max(cfg_image, 1.0)
    fps = min(1000, max(1, int(fps)))
    # Take the magnitude first and clamp afterwards; clamping before negating
    # let large negative seeds escape the upper bound. An empty seed field
    # (None) is treated as 0.
    seed = min(2**32 - 2, abs(int(seed or 0)))
    if hint_image is not None:
        if hint_image.mode != 'RGB':
            hint_image = hint_image.convert('RGB')
        if hint_image.size != (width, height):
            hint_image = ImageOps.fit(hint_image, (width, height), method = Image.Resampling.LANCZOS)
    # A cleared dropdown yields None; fall back to the defaults in that case.
    scheduler_type = (scheduler_type or 'dpm').lower()
    if scheduler_type not in SCHEDULERS:
        scheduler_type = 'dpm'
    output_format = (output_format or 'gif').lower()
    if output_format not in _output_formats:
        output_format = 'gif'
    mask_image = None
    images = _model.generate(
            prompt = [prompt] * _model.device_count,
            neg_prompt = neg_prompt,
            hint_image = hint_image,
            mask_image = mask_image,
            inference_steps = inference_steps,
            cfg = cfg,
            cfg_image = cfg_image,
            height = height,
            width = width,
            num_frames = num_frames,
            seed = seed,
            scheduler_type = scheduler_type
    )
    # Remember this parameter combination. The scheduler is part of the key
    # because `check_if_compiled` looks up the same six values.
    _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames, scheduler_type))
    # NOTE(review): the animation starts at images[1]; images[0] is only
    # returned separately as the first frame -- presumably it is the hint
    # image, confirm against `_model.generate`.
    data = _image_to_data_uri(
            images[1],
            output_format,
            save_all = True,
            append_images = images[2:],
            loop = 0,
            duration = round(1000 / fps),
            allow_mixed = True,
            optimize = True
    )
    last_data = _image_to_data_uri(images[-1], 'png', optimize = True)
    first_data = _image_to_data_uri(images[0], 'png', optimize = True)
    return data, last_data, first_data
def check_if_compiled(hint_image, inference_steps, height, width, num_frames, scheduler_type, message):
    """Return '' when this parameter combination was seen before, else *message*.

    ``height`` and ``width`` are converted to int and rounded down to a
    multiple of 64, ``inference_steps`` is converted to int; ``num_frames``
    and ``scheduler_type`` are used as given. The hint image only contributes
    whether it is absent.

    NOTE(review): the key built here must have the same shape as the tuples
    stored in `_seen_compilations` -- verify against the code that fills it.
    """
    snapped_height = (int(height) // 64) * 64
    snapped_width = (int(width) // 64) * 64
    steps = int(inference_steps)
    lookup_key = (hint_image is None, steps, snapped_height, snapped_width, num_frames, scheduler_type)
    return '' if lookup_key in _seen_compilations else message
# Main UI: an intro row, then a row with the input controls on the left and
# the outputs on the right, followed by the example presets and the wiring of
# the submit button to `generate`.
with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
    # Row variant used for the controls/outputs row below.
    variant = 'panel'
    # --- Intro: project description (left) and usage notes (right) ---
    with gr.Row():
        with gr.Column():
            intro1 = gr.Markdown("""
# Make-A-Video Stable Diffusion JAX
We have extended a pretrained latent-diffusion inpainting image generation model with **temporal convolutions and attention**.
We guide the video generation with a hint image by taking advantage of the extra 5 input channels of the inpainting model.
In this demo the hint image can be given by the user, otherwise it is generated by an generative image model.
The temporal layers are a port of [Make-A-Video PyTorch](https://github.com/lucidrains/make-a-video-pytorch) to [JAX](https://github.com/google/jax) utilizing [FLAX](https://github.com/google/flax).
The convolution is pseudo 3D and seperately convolves accross the spatial dimension in 2D and over the temporal dimension in 1D.
Temporal attention is purely self attention and also separately attends to time.
Only the new temporal layers have been fine tuned on a dataset of videos themed around dance.
The model has been trained for 80 epochs on a dataset of 18,000 Videos with 120 frames each, randomly selecting a 24 frame range from each sample.
Model: [TempoFunk/makeavid-sd-jax](https://huggingface.co/TempoFunk/makeavid-sd-jax)
Datasets: [TempoFunk/tempofunk-sdance](https://huggingface.co/datasets/TempoFunk/tempofunk-sdance), [TempoFunk/small](https://huggingface.co/datasets/TempoFunk/small)
Model implementation and training code can be found at <https://github.com/lopho/makeavid-sd-tpu> (WIP)
""")
        with gr.Column():
            intro3 = gr.Markdown("""
**Please be patient. The model might have to compile with current parameters.**
This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
The compilation will be cached and later runs with the same parameters
will be much faster.
Changes to the following parameters require the model to compile
- Number of frames
- Width & Height
- Inference steps
- Input image vs. no input image
- Noise scheduler type
If you encounter any issues, please report them here: [Space discussions](https://huggingface.co/spaces/TempoFunk/makeavid-sd-jax/discussions) (or DM [@lopho](https://twitter.com/lopho))
<small>Leave a ❤️ like if you like. Consider it a dopamine donation at no cost.</small>
""")
    # --- Controls (left column) and outputs (right column) ---
    with gr.Row(variant = variant):
        with gr.Column():
            with gr.Row():
                # Cancel button is currently disabled.
                #cancel_button = gr.Button(value = 'Cancel')
                submit_button = gr.Button(value = 'Make A Video', variant = 'primary')
            prompt_input = gr.Textbox(
                    label = 'Prompt',
                    value = 'They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.',
                    interactive = True
            )
            neg_prompt_input = gr.Textbox(
                    label = 'Negative prompt (optional)',
                    value = 'monochrome, saturated',
                    interactive = True
            )
            cfg_input = gr.Slider(
                    label = 'Guidance scale video',
                    minimum = 1.0,
                    maximum = 20.0,
                    step = 0.1,
                    value = 15.0,
                    interactive = True
            )
            # NOTE(review): the UI default here is 15.0 while `generate`
            # defaults `cfg_image` to 9.0 -- confirm which one is intended.
            cfg_image_input = gr.Slider(
                    label = 'Guidance scale hint (no effect with input image)',
                    minimum = 1.0,
                    maximum = 20.0,
                    step = 0.1,
                    value = 15.0,
                    interactive = True
            )
            seed_input = gr.Number(
                    label = 'Random seed',
                    value = 0,
                    interactive = True,
                    precision = 0
            )
            # Delivered to `generate` as a PIL image (type = 'pil').
            image_input = gr.Image(
                    label = 'Hint image (optional)',
                    interactive = True,
                    image_mode = 'RGB',
                    type = 'pil',
                    optional = True,
                    source = 'upload'
            )
            # The slider ranges below mirror the clamping done in `generate`.
            inference_steps_input = gr.Slider(
                    label = 'Steps',
                    minimum = 2,
                    maximum = 60,
                    value = 20,
                    step = 1,
                    interactive = True
            )
            num_frames_input = gr.Slider(
                    label = 'Number of frames to generate',
                    minimum = 2,
                    maximum = 24,
                    step = 1,
                    value = 24,
                    interactive = True
            )
            width_input = gr.Slider(
                    label = 'Width',
                    minimum = 256,
                    maximum = 576,
                    step = 64,
                    value = 512,
                    interactive = True
            )
            height_input = gr.Slider(
                    label = 'Height',
                    minimum = 256,
                    maximum = 576,
                    step = 64,
                    value = 512,
                    interactive = True
            )
            scheduler_input = gr.Dropdown(
                    label = 'Noise scheduler',
                    choices = list(SCHEDULERS.keys()),
                    value = 'dpm',
                    interactive = True
            )
            with gr.Row():
                fps_input = gr.Slider(
                        label = 'Output FPS',
                        minimum = 1,
                        maximum = 1000,
                        step = 1,
                        value = 12,
                        interactive = True
                )
                output_format = gr.Dropdown(
                        label = 'Output format',
                        choices = _output_formats,
                        value = 'gif',
                        interactive = True
                )
        with gr.Column():
            # Placeholder for the (disabled) "needs compilation" notice.
            #will_trigger = gr.Markdown('')
            patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
            # Shows 'example.gif' until the first generation finishes.
            image_output = gr.Image(
                    label = 'Output',
                    value = 'example.gif',
                    interactive = False
            )
            tips = gr.Markdown('🤫 *Secret tip*: try using the last frame as input for the next generation.')
            with gr.Row():
                last_frame_output = gr.Image(
                        label = 'Last frame',
                        interactive = False
                )
                first_frame_output = gr.Image(
                        label = 'Initial frame',
                        interactive = False
                )
    # --- Example presets ---
    # One row per example loaded into `_examples`; the column order must match
    # the `inputs` list passed to gr.Examples below.
    examples_lst = []
    for x in _examples:
        examples_lst.append([
                x['image_output'],
                x['prompt'],
                x['neg_prompt'],
                x['image_input'],
                x['cfg'],
                x['cfg_image'],
                x['seed'],
                x['fps'],
                x['steps'],
                x['scheduler'],
                x['num_frames'],
                x['height'],
                x['width'],
                x['format']
        ])
    # Selecting an example fills the output image and every input control.
    examples = gr.Examples(
            examples = examples_lst,
            inputs = [
                    image_output,
                    prompt_input,
                    neg_prompt_input,
                    image_input,
                    cfg_input,
                    cfg_image_input,
                    seed_input,
                    fps_input,
                    inference_steps_input,
                    scheduler_input,
                    num_frames_input,
                    height_input,
                    width_input,
                    output_format
            ],
            postprocess = False
    )
    # Disabled: live "needs compilation" hint driven by `check_if_compiled`
    # whenever one of the compile-relevant inputs changes.
    #trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input, scheduler_input ]
    #trigger_check_fun = partial(check_if_compiled, message = 'Current parameters need compilation.')
    #height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    #width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    #num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    #image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    #inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    #scheduler_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # --- Wiring ---
    # The `inputs` order matches the positional parameter order of `generate`,
    # and the three outputs match its returned (animation, last frame, first
    # frame) tuple. `generate` already returns data URIs; postprocess = False
    # presumably hands them to the image components as-is -- TODO confirm.
    submit_button.click(
            fn = generate,
            inputs = [
                    prompt_input,
                    neg_prompt_input,
                    image_input,
                    inference_steps_input,
                    cfg_input,
                    cfg_image_input,
                    seed_input,
                    fps_input,
                    num_frames_input,
                    height_input,
                    width_input,
                    scheduler_input,
                    output_format
            ],
            outputs = [ image_output, last_frame_output, first_frame_output ],
            postprocess = False
    )
    # Disabled together with the cancel button above.
    #cancel_button.click(fn = lambda: None, cancels = ev)

# Enable the request queue with the limits given here, then start the app
# with the API page shown.
demo.queue(concurrency_count = 1, max_size = 8, api_open = True)
demo.launch(show_api = True)