T2V, Video Pix2Pix and Pose-Guided Gen

Files changed:
- README.md             +5   -7
- app.py                +73  -0
- app_pix2pix_video.py  +70  -0
- app_pose.py           +62  -0
- app_text_to_video.py  +44  -0
- config.py             +1   -0
- gradio_utils.py       +77  -0
- model.py              +296 -0
- requirements.txt      +34  -0
- share.py              +8   -0
- style.css             +3   -0
- utils.py              +187 -0
    	
README.md
CHANGED

@@ -1,12 +1,10 @@
 ---
-title: Text2Video
-emoji:
-colorFrom:
-colorTo:
+title: Text2Video-Zero
+emoji: π
+colorFrom: green
+colorTo: blue
 sdk: gradio
 sdk_version: 3.23.0
 app_file: app.py
 pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
    	
app.py
ADDED

@@ -0,0 +1,73 @@
+import gradio as gr
+import torch
+
+from model import Model, ModelType
+
+# from app_canny import create_demo as create_demo_canny
+from app_pose import create_demo as create_demo_pose
+from app_text_to_video import create_demo as create_demo_text_to_video
+from app_pix2pix_video import create_demo as create_demo_pix2pix_video
+# from app_canny_db import create_demo as create_demo_canny_db
+
+
+model = Model(device='cuda', dtype=torch.float16)
+
+with gr.Blocks(css='style.css') as demo:
+    gr.HTML(
+        """
+        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
+        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
+            Text2Video-Zero
+        </h1>
+        <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
+        We propose <b>Text2Video-Zero, the first zero-shot text-to-video synthesis framework</b>, which also natively supports Video Instruct Pix2Pix, Pose Conditional, Edge Conditional,
+        and Edge Conditional and DreamBooth Specialized applications.
+        </h2>
+        <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+        Levon Khachatryan, Andranik Movsisyan, Vahram Tadevosyan, Roberto Henschel, Atlas Wang, Shant Navasardyan
+        and <a href="https://www.humphreyshi.com/home">Humphrey Shi</a>
+        [<a href="" style="color:blue;">arXiv</a>]
+        [<a href="" style="color:blue;">GitHub</a>]
+        </h3>
+        </div>
+        """)
+
+    with gr.Tab('Zero-Shot Text2Video'):
+        # pass
+        create_demo_text_to_video(model)
+    with gr.Tab('Video Instruct Pix2Pix'):
+        # pass
+        create_demo_pix2pix_video(model)
+    with gr.Tab('Pose Conditional'):
+        # pass
+        create_demo_pose(model)
+    with gr.Tab('Edge Conditional'):
+        pass
+        # create_demo_canny(model)
+    with gr.Tab('Edge Conditional and Dreambooth Specialized'):
+        pass
+        # create_demo_canny_db(model)
+
+    gr.HTML(
+        """
+        <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
+        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+        <b>Version: v1.0</b>
+        </h3>
+        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+        <b>Caution</b>:
+        We would like to raise users' awareness of this demo's potential issues and concerns.
+        Like previous large foundation models, Text2Video-Zero can be problematic in some cases, partly because it uses pretrained Stable Diffusion and can therefore inherit its imperfections.
+        So far, we keep all features available for research testing, both to show the great potential of the Text2Video-Zero framework and to collect important feedback for improving the model in the future.
+        We welcome researchers and users to report issues via the Hugging Face community discussion feature or by emailing the authors.
+        </h3>
+        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+        <b>Biases and content acknowledgement</b>:
+        Beware that Text2Video-Zero may output content that reinforces or exacerbates societal biases, as well as realistic faces, pornography, and violence.
+        Text2Video-Zero in this demo is meant only for research purposes.
+        </h3>
+        </div>
+        """)
+
+demo.launch(debug=True)
+# demo.queue(api_open=False).launch(file_directories=['temporal'], share=True)
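Not part of the commit: a minimal local-run sketch. Each `create_demo_*` helper returns its own `gr.Blocks`, so a single tab can be launched in isolation for debugging, assuming a CUDA GPU is available (the Space itself builds `Model(device='cuda', dtype=torch.float16)`):

```python
# Hypothetical smoke test, not part of this repo: launch only the text-to-video tab.
import torch

from model import Model
from app_text_to_video import create_demo as create_demo_text_to_video

if __name__ == '__main__':
    model = Model(device='cuda', dtype=torch.float16)   # same setup as app.py
    demo = create_demo_text_to_video(model)             # returns a gr.Blocks
    # Mirrors the commented-out queued launch at the end of app.py.
    demo.queue(api_open=False).launch(share=False)
```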
    	
app_pix2pix_video.py
ADDED

@@ -0,0 +1,70 @@
+import gradio as gr
+from model import Model
+
+
+def create_demo(model: Model):
+    examples = [
+        ['__assets__/pix2pix video/camel.mp4', 'make it Van Gogh Starry Night style'],
+        ['__assets__/pix2pix video/mini-cooper.mp4', 'make it Picasso style'],
+        ['__assets__/pix2pix video/snowboard.mp4', 'replace man with robot'],
+        ['__assets__/pix2pix video/white-swan.mp4', 'replace swan with mallard'],
+    ]
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## Video Instruct Pix2Pix')
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Video(label="Input Video", source='upload', type='numpy', format="mp4", visible=True).style(height="auto")
+            with gr.Column():
+                prompt = gr.Textbox(label='Prompt')
+                run_button = gr.Button(label='Run')
+                with gr.Accordion('Advanced options', open=False):
+                    image_resolution = gr.Slider(label='Image Resolution',
+                                                 minimum=256,
+                                                 maximum=1024,
+                                                 value=512,
+                                                 step=64)
+                    seed = gr.Slider(label='Seed',
+                                     minimum=0,
+                                     maximum=65536,
+                                     value=0,
+                                     step=1)
+                    start_t = gr.Slider(label='Starting time in seconds',
+                                        minimum=0,
+                                        maximum=10,
+                                        value=0,
+                                        step=1)
+                    end_t = gr.Slider(label='End time in seconds (-1 corresponds to uploaded video duration)',
+                                      minimum=0,
+                                      maximum=10,
+                                      value=-1,
+                                      step=1)
+                    out_fps = gr.Slider(label='Output video fps (-1 corresponds to uploaded video fps)',
+                                        minimum=1,
+                                        maximum=30,
+                                        value=-1,
+                                        step=1)
+            with gr.Column():
+                result = gr.Video(label='Output',
+                                  show_label=True)
+        inputs = [
+            input_image,
+            prompt,
+            image_resolution,
+            seed,
+            start_t,
+            end_t,
+            out_fps
+        ]
+
+        gr.Examples(examples=examples,
+                    inputs=inputs,
+                    outputs=result,
+                    # cache_examples=os.getenv('SYSTEM') == 'spaces',
+                    run_on_click=False,
+                    )
+
+        run_button.click(fn=model.process_pix2pix,
+                         inputs=inputs,
+                         outputs=result)
+    return demo
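Not part of the commit: `run_button.click` passes the components in `inputs` positionally, so their order must line up with `Model.process_pix2pix(video, prompt, resolution, seed, start_t, end_t, out_fps)` defined in model.py below. An equivalent direct call, with illustrative values only:

```python
# Illustrative only: the same arguments the Gradio click handler would forward.
out_path = model.process_pix2pix(
    '__assets__/pix2pix video/camel.mp4',   # input_image: the uploaded video
    'make it Van Gogh Starry Night style',  # prompt
    512,                                    # image_resolution
    0,                                      # seed
    0,                                      # start_t, in seconds
    -1,                                     # end_t: -1 = full clip duration
    -1,                                     # out_fps: -1 = keep the input fps
)
```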
    	
app_pose.py
ADDED

@@ -0,0 +1,62 @@
+import gradio as gr
+import os
+
+from model import Model
+
+examples = [
+    ['Motion 1', "A Robot is dancing in Sahara desert"],
+    ['Motion 2', "A Robot is dancing in Sahara desert"],
+    ['Motion 3', "A Robot is dancing in Sahara desert"],
+    ['Motion 4', "A Robot is dancing in Sahara desert"],
+    ['Motion 5', "A Robot is dancing in Sahara desert"],
+]
+
+def create_demo(model: Model):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## Text and Pose Conditional Video Generation')
+
+        with gr.Row():
+            gr.Markdown('### You must select one pose sequence shown below, or use the examples')
+            with gr.Column():
+                gallery_pose_sequence = gr.Gallery(label="Pose Sequence", value=[('__assets__/poses_skeleton_gifs/dance1.gif', "Motion 1"), ('__assets__/poses_skeleton_gifs/dance2.gif', "Motion 2"), ('__assets__/poses_skeleton_gifs/dance3.gif', "Motion 3"), ('__assets__/poses_skeleton_gifs/dance4.gif', "Motion 4"), ('__assets__/poses_skeleton_gifs/dance5.gif', "Motion 5")]).style(grid=[2], height="auto")
+                input_video_path = gr.Textbox(label="Pose Sequence", visible=False, value="Motion 1")
+                gr.Markdown("## Selection")
+                pose_sequence_selector = gr.Markdown('Pose Sequence: **Motion 1**')
+            with gr.Column():
+                prompt = gr.Textbox(label='Prompt')
+                run_button = gr.Button(label='Run')
+            with gr.Column():
+                result = gr.Image(label="Generated Video")
+
+        input_video_path.change(on_video_path_update, None, pose_sequence_selector)
+        gallery_pose_sequence.select(pose_gallery_callback, None, input_video_path)
+        inputs = [
+            input_video_path,
+            # pose_sequence,
+            prompt,
+        ]
+
+        gr.Examples(examples=examples,
+                    inputs=inputs,
+                    outputs=result,
+                    # cache_examples=os.getenv('SYSTEM') == 'spaces',
+                    fn=model.process_controlnet_pose,
+                    run_on_click=False,
+                    )
+        # fn=process,
+        # )
+
+
+        run_button.click(fn=model.process_controlnet_pose,
+                         inputs=inputs,
+                         outputs=result)
+
+    return demo
+
+
+def on_video_path_update(evt: gr.EventData):
+    return f'Pose Sequence: **{evt._data}**'
+
+def pose_gallery_callback(evt: gr.SelectData):
+    return f"Motion {evt.index + 1}"
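Not part of the commit: a sketch of the gallery selection flow. Clicking the third thumbnail fires `gr.SelectData` with `index=2`, `pose_gallery_callback` writes "Motion 3" into the hidden `input_video_path` textbox, and `Model.process_controlnet_pose` later resolves that label through `gradio_utils.motion_to_video_path`:

```python
from gradio_utils import motion_to_video_path

label = f"Motion {2 + 1}"           # what pose_gallery_callback returns for index=2
print(motion_to_video_path(label))  # __assets__/poses_skeleton_gifs/dance3_corr.mp4
```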
    	
app_text_to_video.py
ADDED

@@ -0,0 +1,44 @@
+import gradio as gr
+from model import Model
+
+examples = [
+    "an astronaut waving the arm on the moon",
+    "a sloth surfing on a wakeboard",
+    "an astronaut walking on a street",
+    "a cute cat walking on grass",
+    "a horse is galloping on a street",
+    "an astronaut is skiing down the hill",
+    "a gorilla walking alone down the street",
+    "a gorilla dancing on times square",
+    "A panda dancing like crazy on Times Square",
+]
+
+
+def create_demo(model: Model):
+
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## Text2Video-Zero: Video Generation')
+
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(label='Prompt')
+                run_button = gr.Button(label='Run')
+            with gr.Column():
+                result = gr.Video(label="Generated Video")
+        inputs = [
+            prompt,
+        ]
+
+        gr.Examples(examples=examples,
+                    inputs=inputs,
+                    outputs=result,
+                    cache_examples=False,
+                    # cache_examples=os.getenv('SYSTEM') == 'spaces',
+                    run_on_click=False,
+                    )
+
+        run_button.click(fn=model.process_text2video,
+                         inputs=inputs,
+                         outputs=result)
+    return demo
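Not part of the commit: the run button forwards only the prompt, so every other knob falls back to the defaults of `Model.process_text2video` in model.py (resolution=512, seed=24, num_frames=8, fps=4, t0=881, t1=941, ...). An equivalent direct call:

```python
# Illustrative only: reuses the Model instance created in app.py.
video_path = model.process_text2video("a horse is galloping on a street")
```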
    	
config.py
ADDED

@@ -0,0 +1 @@
+save_memory = False
    	
gradio_utils.py
ADDED

@@ -0,0 +1,77 @@
+# App Canny utils
+def edge_path_to_video_path(edge_path):
+    video_path = edge_path
+
+    vid_name = edge_path.split("/")[-1]
+    if vid_name == "butterfly.mp4":
+        video_path = "__assets__/canny_videos_mp4/butterfly.mp4"
+    elif vid_name == "deer.mp4":
+        video_path = "__assets__/canny_videos_mp4/deer.mp4"
+    elif vid_name == "fox.mp4":
+        video_path = "__assets__/canny_videos_mp4/fox.mp4"
+    elif vid_name == "girl_dancing.mp4":
+        video_path = "__assets__/canny_videos_mp4/girl_dancing.mp4"
+    elif vid_name == "girl_turning.mp4":
+        video_path = "__assets__/canny_videos_mp4/girl_turning.mp4"
+    elif vid_name == "halloween.mp4":
+        video_path = "__assets__/canny_videos_mp4/halloween.mp4"
+    elif vid_name == "santa.mp4":
+        video_path = "__assets__/canny_videos_mp4/santa.mp4"
+    return video_path
+
+
+# App Pose utils
+def motion_to_video_path(motion):
+    videos = [
+        "__assets__/poses_skeleton_gifs/dance1_corr.mp4",
+        "__assets__/poses_skeleton_gifs/dance2_corr.mp4",
+        "__assets__/poses_skeleton_gifs/dance3_corr.mp4",
+        "__assets__/poses_skeleton_gifs/dance4_corr.mp4",
+        "__assets__/poses_skeleton_gifs/dance5_corr.mp4"
+    ]
+    id = int(motion.split(" ")[1]) - 1
+    return videos[id]
+
+
+# App Canny Dreambooth utils
+def get_video_from_canny_selection(canny_selection):
+    if canny_selection == "woman1":
+        input_video_path = "__assets__/db_files/woman1.mp4"
+
+    elif canny_selection == "woman2":
+        input_video_path = "__assets__/db_files/woman2.mp4"
+
+    elif canny_selection == "man1":
+        input_video_path = "__assets__/db_files/man1.mp4"
+
+    elif canny_selection == "woman3":
+        input_video_path = "__assets__/db_files/woman3.mp4"
+    else:
+        raise Exception
+
+    return input_video_path
+
+
+def get_model_from_db_selection(db_selection):
+    if db_selection == "Anime DB":
+        input_video_path = 'PAIR/controlnet-canny-anime'
+    elif db_selection == "Avatar DB":
+        input_video_path = 'PAIR/controlnet-canny-avatar'
+    elif db_selection == "GTA-5 DB":
+        input_video_path = 'PAIR/controlnet-canny-gta5'
+    elif db_selection == "Arcane DB":
+        input_video_path = 'PAIR/controlnet-canny-arcane'
+    else:
+        raise Exception
+    return input_video_path
+
+
+def get_db_name_from_id(id):
+    db_names = ["Anime DB", "Arcane DB", "GTA-5 DB", "Avatar DB"]
+    return db_names[id]
+
+
+def get_canny_name_from_id(id):
+    canny_names = ["woman1", "woman2", "man1", "woman3"]
+    return canny_names[id]
+
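Not part of the commit: a usage sketch for the selection helpers above, showing how a gallery index round-trips to a DreamBooth ControlNet model id:

```python
from gradio_utils import get_db_name_from_id, get_model_from_db_selection

name = get_db_name_from_id(2)                 # "GTA-5 DB"
model_id = get_model_from_db_selection(name)  # 'PAIR/controlnet-canny-gta5'
```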
    	
model.py
ADDED

@@ -0,0 +1,296 @@
+from enum import Enum
+import gc
+import numpy as np
+
+import torch
+import decord
+from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
+from diffusers.schedulers import EulerAncestralDiscreteScheduler, DDIMScheduler
+from text_to_video.text_to_video_pipeline import TextToVideoPipeline
+
+import utils
+import gradio_utils
+
+decord.bridge.set_bridge('torch')
+
+
+class ModelType(Enum):
+    Pix2Pix_Video = 1,
+    Text2Video = 2,
+    ControlNetCanny = 3,
+    ControlNetCannyDB = 4,
+    ControlNetPose = 5,
+
+
+class Model:
+    def __init__(self, device, dtype, **kwargs):
+        self.device = device
+        self.dtype = dtype
+        self.generator = torch.Generator(device=device)
+        self.pipe_dict = {
+            ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline,
+            ModelType.Text2Video: TextToVideoPipeline,
+            ModelType.ControlNetCanny: StableDiffusionControlNetPipeline,
+            ModelType.ControlNetCannyDB: StableDiffusionControlNetPipeline,
+            ModelType.ControlNetPose: StableDiffusionControlNetPipeline,
+        }
+        self.controlnet_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
+        self.pix2pix_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=3)
+        self.text2video_attn_proc = utils.CrossFrameAttnProcessor(unet_chunk_size=2)
+
+        self.pipe = None
+        self.model_type = None
+
+        self.states = {}
+
+    def set_model(self, model_type: ModelType, model_id: str, **kwargs):
+        if self.pipe is not None:
+            del self.pipe
+        torch.cuda.empty_cache()
+        gc.collect()
+        safety_checker = kwargs.pop('safety_checker', None)
+        self.pipe = self.pipe_dict[model_type].from_pretrained(model_id, safety_checker=safety_checker, **kwargs).to(self.device).to(self.dtype)
+        self.model_type = model_type
+
+    def inference_chunk(self, frame_ids, **kwargs):
+        if self.pipe is None:
+            return
+        image = kwargs.pop('image')
+        prompt = np.array(kwargs.pop('prompt'))
+        negative_prompt = np.array(kwargs.pop('negative_prompt', ''))
+        latents = None
+        if 'latents' in kwargs:
+            latents = kwargs.pop('latents')[frame_ids]
+        return self.pipe(image=image[frame_ids],
+                         prompt=prompt[frame_ids].tolist(),
+                         negative_prompt=negative_prompt[frame_ids].tolist(),
+                         latents=latents,
+                         generator=self.generator,
+                         **kwargs)
+
+    def inference(self, split_to_chunks=False, chunk_size=8, **kwargs):
+        if self.pipe is None:
+            return
+        seed = kwargs.pop('seed', 0)
+        kwargs.pop('generator', '')
+        # self.generator.manual_seed(seed)
+        if split_to_chunks:
+            assert 'image' in kwargs
+            assert 'prompt' in kwargs
+            image = kwargs.pop('image')
+            prompt = kwargs.pop('prompt')
+            negative_prompt = kwargs.pop('negative_prompt', '')
+            f = image.shape[0]
+            chunk_ids = np.arange(0, f, chunk_size - 1)
+            result = []
+            for i in range(len(chunk_ids)):
+                ch_start = chunk_ids[i]
+                ch_end = f if i == len(chunk_ids) - 1 else chunk_ids[i + 1]
+                frame_ids = [0] + list(range(ch_start, ch_end))
+                self.generator.manual_seed(seed)
+                print(f'Processing chunk {i + 1} / {len(chunk_ids)}')
+                result.append(self.inference_chunk(frame_ids=frame_ids,
+                                                   image=image,
+                                                   prompt=[prompt] * f,
+                                                   negative_prompt=[negative_prompt] * f,
+                                                   **kwargs).images[1:])
+            result = np.concatenate(result)
+            return result
+        else:
+            return self.pipe(generator=self.generator, **kwargs).videos[0]
+
+    def process_controlnet_canny(self,
+                                 video_path,
+                                 prompt,
+                                 num_inference_steps=20,
+                                 controlnet_conditioning_scale=1.0,
+                                 guidance_scale=9.0,
+                                 seed=42,
+                                 eta=0.0,
+                                 low_threshold=100,
+                                 high_threshold=200,
+                                 resolution=512):
+        video_path = gradio_utils.edge_path_to_video_path(video_path)
+        if self.model_type != ModelType.ControlNetCanny:
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+            self.set_model(ModelType.ControlNetCanny, model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet)
+            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
+            self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
+            self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc)
+
+        # TODO: Check scheduler
+        added_prompt = 'best quality, extremely detailed'
+        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
+
+        video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False)
+        control = utils.pre_process_canny(video, low_threshold, high_threshold).to(self.device).to(self.dtype)
+        f, _, h, w = video.shape
+        self.generator.manual_seed(seed)
+        latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator)
+        latents = latents.repeat(f, 1, 1, 1)
+        result = self.inference(image=control,
+                                prompt=prompt + ', ' + added_prompt,
+                                height=h,
+                                width=w,
+                                negative_prompt=negative_prompts,
+                                num_inference_steps=num_inference_steps,
+                                guidance_scale=guidance_scale,
+                                controlnet_conditioning_scale=controlnet_conditioning_scale,
+                                eta=eta,
+                                latents=latents,
+                                seed=seed,
+                                output_type='numpy',
+                                split_to_chunks=True,
+                                chunk_size=8,
+                                )
+        return utils.create_video(result, fps)
+
+    def process_controlnet_pose(self,
+                                video_path,
+                                prompt,
+                                num_inference_steps=20,
+                                controlnet_conditioning_scale=1.0,
+                                guidance_scale=9.0,
+                                seed=42,
+                                eta=0.0,
+                                resolution=512):
+        video_path = gradio_utils.motion_to_video_path(video_path)
+        if self.model_type != ModelType.ControlNetPose:
+            controlnet = ControlNetModel.from_pretrained("fusing/stable-diffusion-v1-5-controlnet-openpose")
+            self.set_model(ModelType.ControlNetPose, model_id="runwayml/stable-diffusion-v1-5", controlnet=controlnet)
+            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
+            self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
+            self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc)
+
+        added_prompt = 'best quality, extremely detailed, HD, ultra-realistic, 8K, HQ, masterpiece, trending on artstation, art, smooth'
+        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
+
+        video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False, output_fps=4)
+        control = utils.pre_process_pose(video, apply_pose_detect=False).to(self.device).to(self.dtype)
+        f, _, h, w = video.shape
+        self.generator.manual_seed(seed)
+        latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator)
+        latents = latents.repeat(f, 1, 1, 1)
+        result = self.inference(image=control,
+                                prompt=prompt + ', ' + added_prompt,
+                                height=h,
+                                width=w,
+                                negative_prompt=negative_prompts,
+                                num_inference_steps=num_inference_steps,
+                                guidance_scale=guidance_scale,
+                                controlnet_conditioning_scale=controlnet_conditioning_scale,
+                                eta=eta,
+                                latents=latents,
+                                seed=seed,
+                                output_type='numpy',
+                                split_to_chunks=True,
+                                chunk_size=8,
+                                )
+        return utils.create_gif(result, fps)
+
+    def process_controlnet_canny_db(self,
+                                    db_path,
+                                    video_path,
+                                    prompt,
+                                    num_inference_steps=20,
+                                    controlnet_conditioning_scale=1.0,
+                                    guidance_scale=9.0,
+                                    seed=42,
+                                    eta=0.0,
+                                    low_threshold=100,
+                                    high_threshold=200,
+                                    resolution=512):
+        db_path = gradio_utils.get_model_from_db_selection(db_path)
+        video_path = gradio_utils.get_video_from_canny_selection(video_path)
+        # Load db and controlnet weights
+        if 'db_path' not in self.states or db_path != self.states['db_path']:
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
+            self.set_model(ModelType.ControlNetCannyDB, model_id=db_path, controlnet=controlnet)
+            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
+            self.pipe.unet.set_attn_processor(processor=self.controlnet_attn_proc)
+            self.pipe.controlnet.set_attn_processor(processor=self.controlnet_attn_proc)
+            self.states['db_path'] = db_path
+
+        added_prompt = 'best quality, extremely detailed'
+        negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
+
+        video, fps = utils.prepare_video(video_path, resolution, self.device, self.dtype, False)
+        control = utils.pre_process_canny(video, low_threshold, high_threshold).to(self.device).to(self.dtype)
+        f, _, h, w = video.shape
+        self.generator.manual_seed(seed)
+        latents = torch.randn((1, 4, h//8, w//8), dtype=self.dtype, device=self.device, generator=self.generator)
+        latents = latents.repeat(f, 1, 1, 1)
+        result = self.inference(image=control,
+                                prompt=prompt + ', ' + added_prompt,
+                                height=h,
+                                width=w,
+                                negative_prompt=negative_prompts,
+                                num_inference_steps=num_inference_steps,
+                                guidance_scale=guidance_scale,
+                                controlnet_conditioning_scale=controlnet_conditioning_scale,
+                                eta=eta,
+                                latents=latents,
+                                seed=seed,
+                                output_type='numpy',
+                                split_to_chunks=True,
+                                chunk_size=8,
+                                )
+        return utils.create_gif(result, fps)
+
+    def process_pix2pix(self, video, prompt, resolution=512, seed=0, start_t=0, end_t=-1, out_fps=-1):
+        if self.model_type != ModelType.Pix2Pix_Video:
+            self.set_model(ModelType.Pix2Pix_Video, model_id="timbrooks/instruct-pix2pix")
+            self.pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(self.pipe.scheduler.config)
+            self.pipe.unet.set_attn_processor(processor=self.pix2pix_attn_proc)
+        video, fps = utils.prepare_video(video, resolution, self.device, self.dtype, True, start_t, end_t, out_fps)
+        self.generator.manual_seed(seed)
+        result = self.inference(image=video,
+                                prompt=prompt,
+                                seed=seed,
+                                output_type='numpy',
+                                num_inference_steps=50,
+                                image_guidance_scale=1.5,
+                                split_to_chunks=True,
+                                chunk_size=8,
+                                )
+        return utils.create_video(result, fps)
+
+    def process_text2video(self, prompt, resolution=512, seed=24, num_frames=8, fps=4, t0=881, t1=941,
+                           use_cf_attn=True, use_motion_field=True, use_foreground_motion_field=False,
+                           smooth_bg=False, smooth_bg_strength=0.4, motion_field_strength=12):
+
+        if self.model_type != ModelType.Text2Video:
+            unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+            self.set_model(ModelType.Text2Video, model_id="runwayml/stable-diffusion-v1-5", unet=unet)
+            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
+            self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
+            self.generator.manual_seed(seed)
         | 
| 268 | 
            +
                    
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
         | 
| 271 | 
            +
                    self.generator.manual_seed(seed)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    prompt = prompt.rstrip()
         | 
| 274 | 
            +
                    if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1]  == "."):
         | 
| 275 | 
            +
                        prompt = prompt.rstrip()[:-1]
         | 
| 276 | 
            +
                    prompt = prompt.rstrip()
         | 
| 277 | 
            +
                    prompt = prompt + ", "+added_prompt
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                    result = self.inference(prompt=[prompt],
         | 
| 280 | 
            +
                                            video_length=num_frames,
         | 
| 281 | 
            +
                                            height=resolution,
         | 
| 282 | 
            +
                                            width=resolution,
         | 
| 283 | 
            +
                                            num_inference_steps=50,
         | 
| 284 | 
            +
                                            guidance_scale=7.5,
         | 
| 285 | 
            +
                                            guidance_stop_step=1.0,
         | 
| 286 | 
            +
                                            t0=t0,
         | 
| 287 | 
            +
                                            t1=t1,
         | 
| 288 | 
            +
                                            use_foreground_motion_field=use_foreground_motion_field,
         | 
| 289 | 
            +
                                            motion_field_strength=motion_field_strength,
         | 
| 290 | 
            +
                                            use_motion_field=use_motion_field,
         | 
| 291 | 
            +
                                            smooth_bg=smooth_bg,
         | 
| 292 | 
            +
                                            smooth_bg_strength=smooth_bg_strength,
         | 
| 293 | 
            +
                                            seed=seed,
         | 
| 294 | 
            +
                                            output_type='numpy',
         | 
| 295 | 
            +
                                            )
         | 
| 296 | 
            +
                    return utils.create_video(result, fps)
         | 
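
The three methods above are the entry points the Gradio app calls into. A minimal sketch of driving them directly from Python, assuming the Model class earlier in model.py is constructed with the same device and dtype these methods read from self.device and self.dtype (the paths and prompts below are illustrative, not part of the repo):

    import torch
    from model import Model

    model = Model(device='cuda', dtype=torch.float16)

    # Zero-shot text-to-video: 8 frames at 512x512, written out via utils.create_video
    t2v_path = model.process_text2video('a horse galloping on a beach',
                                        resolution=512, num_frames=8, fps=4, seed=24)

    # Video Instruct-Pix2Pix: re-render an existing clip according to a text instruction
    edit_path = model.process_pix2pix('__assets__/example.mp4',
                                      'make it look like an oil painting',
                                      resolution=512, seed=0)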
    	
requirements.txt
ADDED
@@ -0,0 +1,34 @@
accelerate==0.16.0
addict==2.4.0
albumentations==1.3.0
basicsr==1.4.2
decord==0.6.0
diffusers==0.14.0
einops==0.6.0
gradio==3.23.0
kornia==0.6
imageio==2.9.0
imageio-ffmpeg==0.4.2
invisible-watermark>=0.1.5
moviepy==1.0.3
numpy==1.24.1
omegaconf==2.3.0
open_clip_torch==2.16.0
opencv_python==4.7.0.68
opencv-contrib-python==4.3.0.36
Pillow==9.4.0
pytorch_lightning==1.5.0
prettytable==3.6.0
scikit_image==0.19.3
scipy==1.10.1
tensorboardX==2.6
tqdm==4.64.1
timm==0.6.12
transformers==4.26.0
test-tube>=0.7.5
webdataset==0.2.5
yapf==0.32.0
safetensors==0.2.7
huggingface-hub==0.13.0
torch==1.13.1
torchvision==0.14.1
    	
share.py
ADDED
@@ -0,0 +1,8 @@
import config
from cldm.hack import disable_verbosity, enable_sliced_attention


disable_verbosity()

if config.save_memory:
    enable_sliced_attention()
    	
style.css
ADDED
@@ -0,0 +1,3 @@
h1 {
  text-align: center;
}
    	
utils.py
ADDED
@@ -0,0 +1,187 @@
import os
import numpy as np
import torch
import torchvision
from torchvision.transforms import Resize
import imageio
from einops import rearrange
import cv2
from PIL import Image  # add_watermark below relies on Image.open
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.openpose import OpenposeDetector
import decord
decord.bridge.set_bridge('torch')

apply_canny = CannyDetector()
apply_openpose = OpenposeDetector()


def add_watermark(image, im_size, watermark_path="__assets__/pair_watermark.png",
                  wmsize=16, bbuf=5, opacity=0.9):
    '''
    Creates a watermark on the saved inference image.
    We request that you do not remove this to properly assign credit to
    Shi-Lab's work.
    '''
    watermark = Image.open(watermark_path).resize((wmsize, wmsize))
    loc = im_size - wmsize - bbuf
    image[:, :, loc:-bbuf, loc:-bbuf] = watermark
    return image


def pre_process_canny(input_video, low_threshold=100, high_threshold=200):
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        detected_map = apply_canny(img, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)
        detected_maps.append(detected_map[None])
    detected_maps = np.concatenate(detected_maps)
    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def pre_process_pose(input_video, apply_pose_detect: bool = True):
    detected_maps = []
    for frame in input_video:
        img = rearrange(frame, 'c h w -> h w c').cpu().numpy().astype(np.uint8)
        img = HWC3(img)
        if apply_pose_detect:
            detected_map, _ = apply_openpose(img)
        else:
            detected_map = img
        detected_map = HWC3(detected_map)
        H, W, C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        detected_maps.append(detected_map[None])
    detected_maps = np.concatenate(detected_maps)
    control = torch.from_numpy(detected_maps.copy()).float() / 255.0
    return rearrange(control, 'f h w c -> f c h w')


def create_video(frames, fps, rescale=False, path=None):
    if path is None:
        dir = "temporal"
        os.makedirs(dir, exist_ok=True)
        path = os.path.join(dir, 'movie.mp4')

    outputs = []
    for i, x in enumerate(frames):
        x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        x = add_watermark(x, im_size=512)
        outputs.append(x)
        # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)

    imageio.mimsave(path, outputs, fps=fps)
    return path


def create_gif(frames, fps, rescale=False):
    dir = "temporal"
    os.makedirs(dir, exist_ok=True)
    path = os.path.join(dir, 'canny_db.gif')

    outputs = []
    for i, x in enumerate(frames):
        x = torchvision.utils.make_grid(torch.Tensor(x), nrow=4)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        x = add_watermark(x, im_size=512)
        outputs.append(x)
        # imageio.imsave(os.path.join(dir, os.path.splitext(name)[0] + f'_{i}.jpg'), x)

    imageio.mimsave(path, outputs, fps=fps)
    return path


def prepare_video(video_path: str, resolution: int, device, dtype, normalize=True, start_t: float = 0, end_t: float = -1, output_fps: int = -1):
    vr = decord.VideoReader(video_path)
    video = vr.get_batch(range(0, len(vr))).asnumpy()
    initial_fps = vr.get_avg_fps()
    if output_fps == -1:
        output_fps = int(initial_fps)
    if end_t == -1:
        end_t = len(vr) / initial_fps
    else:
        end_t = min(len(vr) / initial_fps, end_t)
    assert 0 <= start_t < end_t
    assert output_fps > 0
    f, h, w, c = video.shape
    start_f_ind = int(start_t * initial_fps)
    end_f_ind = int(end_t * initial_fps)
    num_f = int((end_t - start_t) * output_fps)
    sample_idx = np.linspace(start_f_ind, end_f_ind, num_f, endpoint=False).astype(int)
    video = video[sample_idx]
    video = rearrange(video, "f h w c -> f c h w")
    video = torch.Tensor(video).to(device).to(dtype)
    if h > w:
        w = int(w * resolution / h)
        w = w - w % 8
        h = resolution - resolution % 8
        video = Resize((h, w))(video)
    else:
        h = int(h * resolution / w)
        h = h - h % 8
        w = resolution - resolution % 8
        video = Resize((h, w))(video)
    if normalize:
        video = video / 127.5 - 1.0
    return video, output_fps
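
# Note on prepare_video above: `resolution` is applied to the longer side and both sides are
# rounded down to a multiple of 8 (e.g. a 1280x720 clip with resolution=512 comes out 512x288).
# Frames are resampled to `output_fps`, and with normalize=True pixel values are mapped from
# [0, 255] to [-1, 1]; the Pix2Pix path in model.py passes True, the canny path passes False.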
            +
             | 
| 133 | 
            +
             | 
| 134 | 
            +
            def post_process_gif(list_of_results, image_resolution):
         | 
| 135 | 
            +
                output_file = "/tmp/ddxk.gif"
         | 
| 136 | 
            +
                imageio.mimsave(output_file, list_of_results, fps=4)
         | 
| 137 | 
            +
                return output_file
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            class CrossFrameAttnProcessor:
         | 
| 141 | 
            +
                def __init__(self, unet_chunk_size=2):
         | 
| 142 | 
            +
                    self.unet_chunk_size = unet_chunk_size
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                def __call__(
         | 
| 145 | 
            +
                        self,
         | 
| 146 | 
            +
                        attn,
         | 
| 147 | 
            +
                        hidden_states,
         | 
| 148 | 
            +
                        encoder_hidden_states=None,
         | 
| 149 | 
            +
                        attention_mask=None):
         | 
| 150 | 
            +
                    batch_size, sequence_length, _ = hidden_states.shape
         | 
| 151 | 
            +
                    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
         | 
| 152 | 
            +
                    query = attn.to_q(hidden_states)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    is_cross_attention = encoder_hidden_states is not None
         | 
| 155 | 
            +
                    if encoder_hidden_states is None:
         | 
| 156 | 
            +
                        encoder_hidden_states = hidden_states
         | 
| 157 | 
            +
                    elif attn.cross_attention_norm:
         | 
| 158 | 
            +
                        encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
         | 
| 159 | 
            +
                    key = attn.to_k(encoder_hidden_states)
         | 
| 160 | 
            +
                    value = attn.to_v(encoder_hidden_states)
         | 
| 161 | 
            +
                    # Sparse Attention
         | 
| 162 | 
            +
                    if not is_cross_attention:
         | 
| 163 | 
            +
                        video_length = key.size()[0] // self.unet_chunk_size
         | 
| 164 | 
            +
                        # former_frame_index = torch.arange(video_length) - 1
         | 
| 165 | 
            +
                        # former_frame_index[0] = 0
         | 
| 166 | 
            +
                        former_frame_index = [0] * video_length
         | 
| 167 | 
            +
                        key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
         | 
| 168 | 
            +
                        key = key[:, former_frame_index]
         | 
| 169 | 
            +
                        key = rearrange(key, "b f d c -> (b f) d c")
         | 
| 170 | 
            +
                        value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
         | 
| 171 | 
            +
                        value = value[:, former_frame_index]
         | 
| 172 | 
            +
                        value = rearrange(value, "b f d c -> (b f) d c")
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    query = attn.head_to_batch_dim(query)
         | 
| 175 | 
            +
                    key = attn.head_to_batch_dim(key)
         | 
| 176 | 
            +
                    value = attn.head_to_batch_dim(value)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    attention_probs = attn.get_attention_scores(query, key, attention_mask)
         | 
| 179 | 
            +
                    hidden_states = torch.bmm(attention_probs, value)
         | 
| 180 | 
            +
                    hidden_states = attn.batch_to_head_dim(hidden_states)
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                    # linear proj
         | 
| 183 | 
            +
                    hidden_states = attn.to_out[0](hidden_states)
         | 
| 184 | 
            +
                    # dropout
         | 
| 185 | 
            +
                    hidden_states = attn.to_out[1](hidden_states)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    return hidden_states
         | 
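
CrossFrameAttnProcessor is what makes the frame-by-frame Stable Diffusion pipelines in model.py produce temporally consistent video. A minimal sketch of attaching it to a stock diffusers pipeline, mirroring the set_attn_processor calls in model.py (the pipeline class here is an assumption; the model id follows the one used above):

    import torch
    from diffusers import StableDiffusionPipeline, DDIMScheduler
    from utils import CrossFrameAttnProcessor

    pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5',
                                                   torch_dtype=torch.float16).to('cuda')
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # unet_chunk_size=2 accounts for classifier-free guidance doubling the batch into an
    # unconditional and a conditional half; frames are grouped per half before their
    # self-attention keys/values are redirected to frame 0 of that half.
    pipe.unet.set_attn_processor(processor=CrossFrameAttnProcessor(unet_chunk_size=2))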

