Spaces · Running on Zero
James Zhou committed
Commit · 9867d34
1 Parent(s): 860b27a

[init]

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
 - LICENSE +77 -0
 - app.py +814 -0
 - assets/data_pipeline.png +3 -0
 - assets/model_arch.png +3 -0
 - assets/pan_chart.png +3 -0
 - configs/hunyuanvideo-foley-xxl.yaml +49 -0
 - examples/1_result.mp4 +3 -0
 - examples/1_video.mp4 +3 -0
 - examples/2_result.mp4 +3 -0
 - examples/2_video.mp4 +3 -0
 - examples/3_result.mp4 +3 -0
 - examples/3_video.mp4 +3 -0
 - examples/4_result.mp4 +3 -0
 - examples/4_video.mp4 +3 -0
 - examples/5_result.mp4 +3 -0
 - examples/5_video.mp4 +3 -0
 - examples/6_result.mp4 +3 -0
 - examples/6_video.mp4 +3 -0
 - examples/7_result.mp4 +3 -0
 - examples/7_video.mp4 +3 -0
 - examples/8_result.mp4 +3 -0
 - examples/8_video.mp4 +3 -0
 - hunyuanvideo_foley/__init__.py +0 -0
 - hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc +0 -0
 - hunyuanvideo_foley/constants.py +57 -0
 - hunyuanvideo_foley/models/__init__.py +0 -0
 - hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc +0 -0
 - hunyuanvideo_foley/models/dac_vae/__init__.py +16 -0
 - hunyuanvideo_foley/models/dac_vae/__main__.py +36 -0
 - hunyuanvideo_foley/models/dac_vae/model/__init__.py +4 -0
 - hunyuanvideo_foley/models/dac_vae/model/base.py +301 -0
 - hunyuanvideo_foley/models/dac_vae/model/dac.py +410 -0
 - hunyuanvideo_foley/models/dac_vae/model/discriminator.py +228 -0
 - hunyuanvideo_foley/models/dac_vae/nn/__init__.py +3 -0
 - hunyuanvideo_foley/models/dac_vae/nn/layers.py +33 -0
 - hunyuanvideo_foley/models/dac_vae/nn/loss.py +368 -0
 - hunyuanvideo_foley/models/dac_vae/nn/quantize.py +262 -0
 - hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py +91 -0
 - hunyuanvideo_foley/models/dac_vae/utils/__init__.py +121 -0
 - hunyuanvideo_foley/models/dac_vae/utils/decode.py +95 -0
 - hunyuanvideo_foley/models/dac_vae/utils/encode.py +94 -0
 - hunyuanvideo_foley/models/hifi_foley.py +794 -0
 - hunyuanvideo_foley/models/nn/__init__.py +0 -0
 - hunyuanvideo_foley/models/nn/activation_layers.py +44 -0
 - hunyuanvideo_foley/models/nn/attn_layers.py +546 -0
 - hunyuanvideo_foley/models/nn/embed_layers.py +136 -0
 - hunyuanvideo_foley/models/nn/mlp_layers.py +149 -0
 - hunyuanvideo_foley/models/nn/modulate_layers.py +49 -0
 - hunyuanvideo_foley/models/nn/norm_layers.py +70 -0
 
    	
.gitattributes CHANGED

@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
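The two added patterns route the commit's new binary assets (the PNG charts under assets/ and the MP4 clips under examples/) through Git LFS, matching the existing entries above. Assuming Git LFS is set up for the repository, these are the lines that git lfs track "*.png" and git lfs track "*.mp4" would append to .gitattributes.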
    	
LICENSE ADDED

@@ -0,0 +1,77 @@
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
Tencent HunyuanVideo-Foley Release Date: August 28, 2025
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
1. DEFINITIONS.
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo-Foley released at [https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley].
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
n. “including” shall mean including but not limited to.
2. GRANT OF RIGHTS.
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
3. DISTRIBUTION.
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
b. You must cause any modified files to carry prominent notices stating that You changed the files;
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
4. ADDITIONAL COMMERCIAL TERMS.
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
5. RULES OF USE.
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
6. INTELLECTUAL PROPERTY.
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
8. SURVIVAL AND TERMINATION.
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
9. GOVERNING LAW AND JURISDICTION.
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.

EXHIBIT A
ACCEPTABLE USE POLICY

Tencent reserves the right to update this Acceptable Use Policy from time to time.
Last modified: November 5, 2024

Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
1. Outside the Territory;
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
3. To harm Yourself or others;
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
5. To override or circumvent the safety guardrails and safeguards We have put in place;
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
9. To intentionally defame, disparage or otherwise harass others;
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
11. To generate or disseminate personal identifiable information with the purpose of harming others;
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
13. To impersonate another individual without consent, authorization, or legal right;
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
19. For military purposes;
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
    	
app.py ADDED

@@ -0,0 +1,814 @@
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import tempfile
         
     | 
| 3 | 
         
            +
            import gradio as gr
         
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            import torchaudio
         
     | 
| 6 | 
         
            +
            from loguru import logger
         
     | 
| 7 | 
         
            +
            from typing import Optional, Tuple
         
     | 
| 8 | 
         
            +
            import random
         
     | 
| 9 | 
         
            +
            import numpy as np
         
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            from hunyuanvideo_foley.utils.model_utils import load_model
         
     | 
| 12 | 
         
            +
            from hunyuanvideo_foley.utils.feature_utils import feature_process
         
     | 
| 13 | 
         
            +
            from hunyuanvideo_foley.utils.model_utils import denoise_process
         
     | 
| 14 | 
         
            +
            from hunyuanvideo_foley.utils.media_utils import merge_audio_video
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            # Global variables for model storage
         
     | 
| 17 | 
         
            +
            model_dict = None
         
     | 
| 18 | 
         
            +
            cfg = None
         
     | 
| 19 | 
         
            +
            device = None
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
            +
            # need to modify the model path
         
     | 
| 22 | 
         
            +
            MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
         
     | 
| 23 | 
         
            +
            CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
            def setup_device(device_str: str = "auto", gpu_id: int = 0) -> torch.device:
         
     | 
| 26 | 
         
            +
                """Setup computing device"""
         
     | 
| 27 | 
         
            +
                if device_str == "auto":
         
     | 
| 28 | 
         
            +
                    if torch.cuda.is_available():
         
     | 
| 29 | 
         
            +
                        device = torch.device(f"cuda:{gpu_id}")
         
     | 
| 30 | 
         
            +
                        logger.info(f"Using CUDA device: {device}")
         
     | 
| 31 | 
         
            +
                    elif torch.backends.mps.is_available():
         
     | 
| 32 | 
         
            +
                        device = torch.device("mps")
         
     | 
| 33 | 
         
            +
                        logger.info("Using MPS device")
         
     | 
| 34 | 
         
            +
                    else:
         
     | 
| 35 | 
         
            +
                        device = torch.device("cpu")
         
     | 
| 36 | 
         
            +
                        logger.info("Using CPU device")
         
     | 
| 37 | 
         
            +
                else:
         
     | 
| 38 | 
         
            +
                    if device_str == "cuda":
         
     | 
| 39 | 
         
            +
                        device = torch.device(f"cuda:{gpu_id}")
         
     | 
| 40 | 
         
            +
                    else:
         
     | 
| 41 | 
         
            +
                        device = torch.device(device_str)
         
     | 
| 42 | 
         
            +
                    logger.info(f"Using specified device: {device}")
         
     | 
| 43 | 
         
            +
                
         
     | 
| 44 | 
         
            +
                return device
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
            def auto_load_models() -> str:
         
     | 
| 47 | 
         
            +
                """Automatically load preset models"""
         
     | 
| 48 | 
         
            +
                global model_dict, cfg, device
         
     | 
| 49 | 
         
            +
                
         
     | 
| 50 | 
         
            +
                try:
         
     | 
| 51 | 
         
            +
                    if not os.path.exists(MODEL_PATH):
         
     | 
| 52 | 
         
            +
                        return f"❌ Model file not found: {MODEL_PATH}"
         
     | 
| 53 | 
         
            +
                    if not os.path.exists(CONFIG_PATH):
         
     | 
| 54 | 
         
            +
                        return f"❌ Config file not found: {CONFIG_PATH}"
         
     | 
| 55 | 
         
            +
                    
         
     | 
| 56 | 
         
            +
                    # Use GPU by default
         
     | 
| 57 | 
         
            +
                    device = setup_device("auto", 0)
         
     | 
| 58 | 
         
            +
                    
         
     | 
| 59 | 
         
            +
                    # Load model
         
     | 
| 60 | 
         
            +
                    logger.info("Auto-loading model...")
         
     | 
| 61 | 
         
            +
                    logger.info(f"Model path: {MODEL_PATH}")
         
     | 
| 62 | 
         
            +
                    logger.info(f"Config path: {CONFIG_PATH}")
         
     | 
| 63 | 
         
            +
                    
         
     | 
| 64 | 
         
            +
                    model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device)
         
     | 
| 65 | 
         
            +
                    
         
     | 
| 66 | 
         
            +
                    logger.info("✅ Model loaded successfully!")
         
     | 
| 67 | 
         
            +
                    return "✅ Model loaded successfully!"
         
     | 
| 68 | 
         
            +
                    
         
     | 
| 69 | 
         
            +
                except Exception as e:
         
     | 
| 70 | 
         
            +
                    logger.error(f"Model loading failed: {str(e)}")
         
     | 
| 71 | 
         
            +
                    return f"❌ Model loading failed: {str(e)}"
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
            def infer_single_video(
         
     | 
| 74 | 
         
            +
                video_file, 
         
     | 
| 75 | 
         
            +
                text_prompt: str, 
         
     | 
| 76 | 
         
            +
                guidance_scale: float = 4.5, 
         
     | 
| 77 | 
         
            +
                num_inference_steps: int = 50,
         
     | 
| 78 | 
         
            +
                sample_nums: int = 1
         
     | 
| 79 | 
         
            +
            ) -> Tuple[list, str]:
         
     | 
| 80 | 
         
            +
                """Single video inference"""
         
     | 
| 81 | 
         
            +
                global model_dict, cfg, device
         
     | 
| 82 | 
         
            +
                
         
     | 
| 83 | 
         
            +
                if model_dict is None or cfg is None:
         
     | 
| 84 | 
         
            +
                    return [], "❌ Please load the model first!"
         
     | 
| 85 | 
         
            +
                
         
     | 
| 86 | 
         
            +
                if video_file is None:
         
     | 
| 87 | 
         
            +
                    return [], "❌ Please upload a video file!"
         
     | 
| 88 | 
         
            +
                
         
     | 
| 89 | 
         
            +
                # Allow empty text prompt, use empty string if no prompt provided
         
     | 
| 90 | 
         
            +
                if text_prompt is None:
         
     | 
| 91 | 
         
            +
                    text_prompt = ""
         
     | 
| 92 | 
         
            +
                text_prompt = text_prompt.strip()
         
     | 
| 93 | 
         
            +
                
         
     | 
| 94 | 
         
            +
                try:
         
     | 
| 95 | 
         
            +
                    logger.info(f"Processing video: {video_file}")
         
     | 
| 96 | 
         
            +
                    logger.info(f"Text prompt: {text_prompt}")
         
     | 
| 97 | 
         
            +
                    
         
     | 
| 98 | 
         
            +
                    # Feature processing
         
     | 
| 99 | 
         
            +
                    visual_feats, text_feats, audio_len_in_s = feature_process(
         
     | 
| 100 | 
         
            +
                        video_file,
         
     | 
| 101 | 
         
            +
                        text_prompt,
         
     | 
| 102 | 
         
            +
                        model_dict,
         
     | 
| 103 | 
         
            +
                        cfg
         
     | 
| 104 | 
         
            +
                    )
         
     | 
| 105 | 
         
            +
                    
         
     | 
| 106 | 
         
            +
                    # Denoising process to generate multiple audio samples
         
     | 
| 107 | 
         
            +
                    # Note: The model now generates sample_nums audio samples per inference
         
     | 
| 108 | 
         
            +
                    # The denoise_process function returns audio with shape [batch_size, channels, samples]
         
     | 
| 109 | 
         
            +
                    logger.info(f"Generating {sample_nums} audio samples...")
         
     | 
| 110 | 
         
            +
                    audio, sample_rate = denoise_process(
         
     | 
| 111 | 
         
            +
                        visual_feats,
         
     | 
| 112 | 
         
            +
                        text_feats,
         
     | 
| 113 | 
         
            +
                        audio_len_in_s,
         
     | 
| 114 | 
         
            +
                        model_dict,
         
     | 
| 115 | 
         
            +
                        cfg,
         
     | 
| 116 | 
         
            +
                        guidance_scale=guidance_scale,
         
     | 
| 117 | 
         
            +
                        num_inference_steps=num_inference_steps,
         
     | 
| 118 | 
         
            +
                        batch_size=sample_nums
         
     | 
| 119 | 
         
            +
                    )
         
     | 
| 120 | 
         
            +
                    
         
     | 
| 121 | 
         
            +
                    # Create temporary files to save results
         
     | 
| 122 | 
         
            +
                    temp_dir = tempfile.mkdtemp()
         
     | 
| 123 | 
         
            +
                    video_outputs = []
         
     | 
| 124 | 
         
            +
                    
         
     | 
| 125 | 
         
            +
                    # Process each generated audio sample
         
     | 
| 126 | 
         
            +
                    for i in range(sample_nums):
         
     | 
| 127 | 
         
            +
                        # Save audio file
         
     | 
| 128 | 
         
            +
                        audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
         
     | 
| 129 | 
         
            +
                        torchaudio.save(audio_output, audio[i], sample_rate)
         
     | 
| 130 | 
         
            +
                        
         
     | 
| 131 | 
         
            +
                        # Merge video and audio
         
     | 
| 132 | 
         
            +
                        video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
         
     | 
| 133 | 
         
            +
                        merge_audio_video(audio_output, video_file, video_output)
         
     | 
| 134 | 
         
            +
                        video_outputs.append(video_output)
         
     | 
| 135 | 
         
            +
                    
         
     | 
| 136 | 
         
            +
                    logger.info(f"Inference completed! Generated {sample_nums} samples.")
         
     | 
| 137 | 
         
            +
                    return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully!"
         
     | 
| 138 | 
         
            +
                    
         
     | 
| 139 | 
         
            +
                except Exception as e:
         
     | 
| 140 | 
         
            +
                    logger.error(f"Inference failed: {str(e)}")
         
     | 
| 141 | 
         
            +
                    return [], f"❌ Inference failed: {str(e)}"
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
            def update_video_outputs(video_list, status_msg):
         
     | 
| 144 | 
         
            +
                """Update video outputs based on the number of generated samples"""
         
     | 
| 145 | 
         
            +
                # Initialize all outputs as None
         
     | 
| 146 | 
         
            +
                outputs = [None] * 6
         
     | 
| 147 | 
         
            +
                
         
     | 
| 148 | 
         
            +
                # Set values based on generated videos
         
     | 
| 149 | 
         
            +
                for i, video_path in enumerate(video_list[:6]):  # Max 6 samples
         
     | 
| 150 | 
         
            +
                    outputs[i] = video_path
         
     | 
| 151 | 
         
            +
                
         
     | 
| 152 | 
         
            +
                # Return all outputs plus status message
         
     | 
| 153 | 
         
            +
                return tuple(outputs + [status_msg])
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
def create_gradio_interface():
    """Create Gradio interface"""

    # Custom CSS for beautiful interface with better contrast
    css = """
    .gradio-container {
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
        background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
        min-height: 100vh;
    }

    .main-header {
        text-align: center;
        padding: 2rem 0;
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border-radius: 20px;
        margin-bottom: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.15);
    }

    .main-header h1 {
        color: white;
        font-size: 3rem;
        font-weight: 700;
        margin-bottom: 0.5rem;
        text-shadow: 0 2px 10px rgba(0,0,0,0.3);
    }

    .main-header p {
        color: rgba(255, 255, 255, 0.95);
        font-size: 1.2rem;
        font-weight: 300;
    }

    .status-card {
        background: white;
        border-radius: 15px;
        padding: 1rem;
        margin-bottom: 1.5rem;
        border: 1px solid #e1e5e9;
        box-shadow: 0 4px 20px rgba(0,0,0,0.08);
    }

    .status-card label {
        color: #2d3748 !important;
        font-weight: 600 !important;
    }

    .usage-guide h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 0.5rem !important;
    }

    .usage-guide p {
        color: #4a5568 !important;
        font-size: 1rem !important;
        line-height: 1.6 !important;
        margin: 0.5rem 0 !important;
    }

    .usage-guide strong {
        color: #1a202c !important;
        font-weight: 700 !important;
    }

    .usage-guide em {
        color: #1a202c !important;
        font-weight: 700 !important;
        font-style: normal !important;
    }

    .main-interface {
        margin-bottom: 2rem;
    }

    .input-section {
        background: white;
        border-radius: 20px;
        padding: 2rem;
        margin-right: 1rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        border: 1px solid #e1e5e9;
    }

    .input-section h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 1rem !important;
    }

    .input-section label {
        color: #4a5568 !important;
        font-weight: 500 !important;
    }

    .output-section {
        background: white;
        border-radius: 20px;
        padding: 2rem;
        margin-left: 1rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        border: 1px solid #e1e5e9;
    }

    .output-section h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 1rem !important;
    }

    .output-section label {
        color: #4a5568 !important;
        font-weight: 500 !important;
    }

    .examples-section h3 {
        color: #2d3748 !important;
        font-weight: 600 !important;
        margin-bottom: 1.5rem !important;
    }

    .generate-btn {
        background: linear-gradient(45deg, #667eea, #764ba2) !important;
        border: none !important;
        color: white !important;
        font-weight: 600 !important;
        font-size: 1.1rem !important;
        padding: 12px 30px !important;
        border-radius: 25px !important;
        box-shadow: 0 4px 15px rgba(102, 126, 234, 0.4) !important;
        transition: all 0.3s ease !important;
    }

    .generate-btn:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 8px 25px rgba(102, 126, 234, 0.6) !important;
    }


    .examples-section {
        background: white;
        border-radius: 20px;
        padding: 2rem;
        margin-top: 2rem;
        box-shadow: 0 8px 32px rgba(0,0,0,0.1);
        border: 1px solid #e1e5e9;
    }

    .examples-section p {
        color: #4a5568 !important;
        margin-bottom: 1rem !important;
    }

    .example-row {
        background: #f8fafc;
        border: 1px solid #e2e8f0;
        border-radius: 15px;
        padding: 1.5rem;
        margin: 1rem 0;
        transition: all 0.3s ease;
        align-items: center;
    }

    .example-row:hover {
        border-color: #667eea;
        transform: translateY(-2px);
        box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
    }

    .example-row .markdown {
        color: #2d3748 !important;
    }

    .example-row .markdown p {
        color: #2d3748 !important;
        margin: 0.5rem 0 !important;
        line-height: 1.5 !important;
    }

    .example-row .markdown strong {
        color: #1a202c !important;
        font-weight: 600 !important;
    }

    /* Example grid layout styles */
    .example-grid-row {
        margin: 1rem 0;
        gap: 1rem;
    }

    .example-item {
        background: #f8fafc;
        border: 1px solid #e2e8f0;
        border-radius: 15px;
        padding: 1rem;
        transition: all 0.3s ease;
        margin: 0.25rem;
        max-width: 250px;
        margin-left: auto;
        margin-right: auto;
    }

    .example-item:hover {
        border-color: #667eea;
        transform: translateY(-2px);
        box-shadow: 0 4px 20px rgba(102, 126, 234, 0.15);
    }

    .example-caption {
        margin: 0.5rem 0 !important;
        min-height: 2.8rem !important;
        display: flex !important;
        align-items: flex-start !important;
    }

    .example-caption p {
        color: #2d3748 !important;
        font-size: 0.9rem !important;
        line-height: 1.4 !important;
        margin: 0.5rem 0 !important;
    }

    /* Multi-video gallery styles */
    .additional-samples {
        margin-top: 1rem;
        gap: 0.5rem;
    }

    .additional-samples .gradio-video {
        border-radius: 10px;
        overflow: hidden;
    }

    /* Video gallery responsive layout */
    .video-gallery {
        display: grid;
        gap: 1rem;
        margin-top: 1rem;
    }

    .video-gallery.single {
        grid-template-columns: 1fr;
    }

    .video-gallery.dual {
        grid-template-columns: 1fr 1fr;
    }

    .video-gallery.multi {
        grid-template-columns: repeat(2, 1fr);
        grid-template-rows: auto auto auto;
    }

    .footer-text {
        color: #718096 !important;
        text-align: center;
        padding: 2rem;
        font-size: 0.9rem;
    }

    /* Video component styling for consistent size */
    .input-section video,
    .output-section video,
    .example-row video {
        width: 100% !important;
        height: 300px !important;
        object-fit: contain !important;
        border-radius: 10px !important;
        background-color: #000 !important;
    }

    .example-row video {
        height: 150px !important;
    }

    /* Fix for additional samples video display */
    .additional-samples video {
        height: 150px !important;
        object-fit: contain !important;
        border-radius: 10px !important;
        background-color: #000 !important;
    }

    .additional-samples .gradio-video {
        border-radius: 10px !important;
        overflow: hidden !important;
        background-color: #000 !important;
    }

    .additional-samples .gradio-video > div {
        background-color: #000 !important;
        border-radius: 10px !important;
    }

    /* Video container styling */
    .input-section .video-container,
    .output-section .video-container,
    .example-row .video-container {
        background-color: #000 !important;
        border-radius: 10px !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }

    /* Ensure proper alignment */
    .example-row {
        display: flex !important;
        align-items: stretch !important;
    }

    .example-row > div {
        display: flex !important;
        flex-direction: column !important;
        justify-content: center !important;
    }

    /* Video wrapper for better control */
    .video-wrapper {
        position: relative !important;
        width: 100% !important;
        background: #000 !important;
        border-radius: 10px !important;
        overflow: hidden !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }
    """

    with gr.Blocks(css=css, title="HunyuanVideo-Foley") as app:

        # Main header
        with gr.Column(elem_classes=["main-header"]):
            gr.HTML("""
            <h1>🎵 HunyuanVideo-Foley</h1>
            <p>Text-Video-to-Audio Synthesis: Generate realistic audio from video and text descriptions</p>
            """)

        # Usage Guide
        with gr.Column(elem_classes=["status-card"]):
            gr.Markdown("""
            ### 📋 Quick Start Guide
            **1.** Upload your video file\t**2.** Add optional text description\t**3.** Adjust sample numbers (1-6)\t**4.** Click Generate Audio

            💡 For quick start, you can load the prepared examples by clicking the button.
            """, elem_classes=["usage-guide"])

        # Main inference interface - Input and Results side by side
        with gr.Row(elem_classes=["main-interface"]):
            # Input section
            with gr.Column(scale=1, elem_classes=["input-section"]):
                gr.Markdown("### 📹 Video Input")

                video_input = gr.Video(
                    label="Upload Video",
                    info="Supported formats: MP4, AVI, MOV, etc.",
                    height=300
                )

                text_input = gr.Textbox(
                    label="🎯 Audio Description (English)",
                    placeholder="A person walks on frozen ice",
                    lines=3,
                    info="Describe the audio you want to generate (optional)"
                )

                with gr.Row():
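                    # Rough guidance (typical guided-sampling behaviour, not tuned numbers
                    # for this model): a higher CFG scale tends to follow the text prompt
                    # more literally, and more steps trade speed for fidelity.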
                    guidance_scale = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=4.5,
                        step=0.1,
                        label="🎚️ CFG Scale",
                    )

                    inference_steps = gr.Slider(
                        minimum=10,
                        maximum=100,
                        value=50,
                        step=5,
                        label="⚡ Steps",
                    )

                    sample_nums = gr.Slider(
                        minimum=1,
                        maximum=6,
                        value=1,
                        step=1,
                        label="🎲 Sample Nums",
                    )

                generate_btn = gr.Button(
                    "🎵 Generate Audio",
                    variant="primary",
                    elem_classes=["generate-btn"]
                )

            # Results section
            with gr.Column(scale=1, elem_classes=["output-section"]):
                gr.Markdown("### 🎥 Generated Results")

                # Multi-video gallery for displaying multiple generated samples
                with gr.Column():
                    # Primary video (Sample 1)
                    video_output_1 = gr.Video(
                        label="Sample 1",
                        height=250,
                        visible=True
                    )

                    # Additional videos (Samples 2-6) - initially hidden
                    with gr.Row(elem_classes=["additional-samples"]):
                        with gr.Column(scale=1):
                            video_output_2 = gr.Video(
                                label="Sample 2",
                                height=150,
                                visible=False
                            )
                            video_output_3 = gr.Video(
                                label="Sample 3",
                                height=150,
                                visible=False
                            )
                        with gr.Column(scale=1):
                            video_output_4 = gr.Video(
                                label="Sample 4",
                                height=150,
                                visible=False
                            )
                            video_output_5 = gr.Video(
                                label="Sample 5",
                                height=150,
                                visible=False
                            )

                    # Sample 6 - full width
                    video_output_6 = gr.Video(
                        label="Sample 6",
                        height=150,
                        visible=False
                    )

                result_text = gr.Textbox(
                    label="Status",
                    interactive=False,
                    lines=2
                )

        # Examples section at the bottom
        with gr.Column(elem_classes=["examples-section"]):
            gr.Markdown("### 🌟 Examples")
            gr.Markdown("Click on any example to load it into the interface above")

            # Define your custom examples here - 8 examples total
            examples_data = [
                # Example 1
                {
                    "caption": "A person walks on frozen ice",
                    "video_path": "examples/1_video.mp4",
                    "result_path": "examples/1_result.mp4"
                },
                # Example 2
                {
                    "caption": "With a faint sound as their hands parted, the two embraced, a soft 'mm' escaping between them.",
                    "video_path": "examples/2_video.mp4",
                    "result_path": "examples/2_result.mp4"
                },
                # Example 3
                {
                    "caption": "The sound of the number 3's bouncing footsteps is as light and clear as glass marbles hitting the ground. Each step carries a magical sound.",
                    "video_path": "examples/3_video.mp4",
                    "result_path": "examples/3_result.mp4"
                },
                # Example 4
                {
                    "caption": "gentle gurgling of the stream's current, and music plays in the background which is a beautiful and serene piano solo with a hint of classical charm, evoking a sense of peace and serenity in people's hearts.",
                    "video_path": "examples/4_video.mp4",
                    "result_path": "examples/4_result.mp4"
                },
                # Example 5 - Add your new examples here
                {
                    "caption": "snow crunching under the snowboard's edge.",
                    "video_path": "examples/5_video.mp4",
                    "result_path": "examples/5_result.mp4"
                },
                # Example 6
                {
                    "caption": "The crackling of the fire, the whooshing of the flames, and the occasional crisp popping of charred leaves filled the forest.",
                    "video_path": "examples/6_video.mp4",
                    "result_path": "examples/6_result.mp4"
                },
                # Example 7
                {
                    "caption": "humming of the scooter engine accelerates slowly.",
                    "video_path": "examples/7_video.mp4",
                    "result_path": "examples/7_result.mp4"
                },
                # Example 8
                {
                    "caption": "splash of water and loud thud as person hits the surface.",
                    "video_path": "examples/8_video.mp4",
                    "result_path": "examples/8_result.mp4"
                }
            ]

            # Create example grid - 4 examples per row, 2 rows total
            example_buttons = []
            for row in range(2):  # 2 rows
                with gr.Row(elem_classes=["example-grid-row"]):
                    for col in range(4):  # 4 columns
                        idx = row * 4 + col
                        if idx < len(examples_data):
                            example = examples_data[idx]

                            with gr.Column(scale=1, elem_classes=["example-item"]):
                                # Video thumbnail
                                if os.path.exists(example['video_path']):
                                    example_video = gr.Video(
                                        value=example['video_path'],
                                        label=f"Example {idx+1}",
                                        interactive=False,
                                        show_label=True,
                                        height=180
                                    )
                                else:
                                    example_video = gr.HTML(f"""
                                    <div style="background: #f0f0f0; padding: 15px; text-align: center; border-radius: 8px; height: 180px; display: flex; align-items: center; justify-content: center;">
                                        <div>
                                            <p style="color: #666; margin: 0; font-size: 12px;">📹 Video not found</p>
                                            <small style="color: #999; font-size: 10px;">{example['video_path']}</small>
                                        </div>
                                    </div>
                                    """)

                                # Caption (truncated for grid layout)
                                caption_preview = example['caption'][:60] + "..." if len(example['caption']) > 60 else example['caption']
                                gr.Markdown(f"{caption_preview}", elem_classes=["example-caption"])

                                # Load button
                                example_btn = gr.Button(
                                    f"Load Example {idx+1}",
                                    variant="secondary",
                                    size="sm"
                                )
                                example_buttons.append((example_btn, example))

        # Event handlers
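        # process_inference forwards the UI inputs to infer_single_video (presumably
        # defined earlier in app.py) and reshapes the returned list of clips into one
        # value per output slot via update_video_outputs.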
        def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums):
            # Generate videos
            video_list, status_msg = infer_single_video(
                video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums)
            )
            # Update outputs with proper visibility
            return update_video_outputs(video_list, status_msg)

        # Add dynamic visibility control based on sample_nums
        def update_visibility(sample_nums):
            sample_nums = int(sample_nums)
            return [
                gr.update(visible=True),  # Sample 1 always visible
                gr.update(visible=sample_nums >= 2),  # Sample 2
                gr.update(visible=sample_nums >= 3),  # Sample 3
                gr.update(visible=sample_nums >= 4),  # Sample 4
                gr.update(visible=sample_nums >= 5),  # Sample 5
                gr.update(visible=sample_nums >= 6),  # Sample 6
            ]

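        # gr.update(visible=...) only toggles visibility; component values are left
        # untouched, so already-generated clips are not cleared when the slider moves.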
        # Update visibility when sample_nums changes
        sample_nums.change(
            fn=update_visibility,
            inputs=[sample_nums],
            outputs=[video_output_1, video_output_2, video_output_3, video_output_4, video_output_5, video_output_6]
        )

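        # The outputs list below must stay in the same order as the tuple returned by
        # update_video_outputs: six video slots followed by the status text.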
        generate_btn.click(
            fn=process_inference,
            inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums],
            outputs=[
                video_output_1,  # Sample 1 value
                video_output_2,  # Sample 2 value
                video_output_3,  # Sample 3 value
                video_output_4,  # Sample 4 value
                video_output_5,  # Sample 5 value
                video_output_6,  # Sample 6 value
                result_text
            ]
        )

        # Add click handlers for example buttons
        for btn, example in example_buttons:
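            # A small factory is used instead of capturing `example` directly because
            # Python closures bind loop variables late; create_example_handler(ex)
            # freezes the current example for each button's click handler.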
            def create_example_handler(ex):
                def handler():
                    # Check if files exist, if not, return placeholder message
                    if os.path.exists(ex['video_path']):
                        video_file = ex['video_path']
                    else:
                        video_file = None

                    if os.path.exists(ex['result_path']):
                        result_video = ex['result_path']
                    else:
                        result_video = None

                    status_msg = f"✅ Loaded example with caption: {ex['caption'][:50]}..."
                    if not video_file:
                        status_msg += f"\n⚠️ Video file not found: {ex['video_path']}"
                    if not result_video:
                        status_msg += f"\n⚠️ Result video not found: {ex['result_path']}"

                    return video_file, ex['caption'], result_video, status_msg
                return handler

            btn.click(
                fn=create_example_handler(example),
                outputs=[video_input, text_input, video_output_1, result_text]
            )

        # Footer
        gr.HTML("""
        <div class="footer-text">
            <p>🚀 Powered by HunyuanVideo-Foley | Generate high-quality audio from video and text descriptions</p>
        </div>
        """)

    return app

            def set_manual_seed(global_seed):
         
     | 
| 786 | 
         
            +
                random.seed(global_seed)
         
     | 
| 787 | 
         
            +
                np.random.seed(global_seed)
         
     | 
| 788 | 
         
            +
                torch.manual_seed(global_seed)
         
     | 
| 789 | 
         
            +
             
     | 
| 790 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 791 | 
         
            +
                set_manual_seed(1)
         
     | 
| 792 | 
         
            +
                # Setup logging
         
     | 
| 793 | 
         
            +
                logger.remove()
         
     | 
| 794 | 
         
            +
                logger.add(lambda msg: print(msg, end=''), level="INFO")
         
     | 
| 795 | 
         
            +
                
         
     | 
| 796 | 
         
            +
                # Auto-load model
         
     | 
| 797 | 
         
            +
                logger.info("Starting application and loading model...")
         
     | 
| 798 | 
         
            +
                model_load_result = auto_load_models()
         
     | 
| 799 | 
         
            +
                logger.info(model_load_result)
         
     | 
| 800 | 
         
            +
                
         
     | 
| 801 | 
         
            +
                # Create and launch Gradio app
         
     | 
| 802 | 
         
            +
                app = create_gradio_interface()
         
     | 
| 803 | 
         
            +
                
         
     | 
| 804 | 
         
            +
                # Log completion status
         
     | 
| 805 | 
         
            +
                if "successfully" in model_load_result:
         
     | 
| 806 | 
         
            +
                    logger.info("Application ready, model loaded")
         
     | 
| 807 | 
         
            +
                
         
     | 
| 808 | 
         
            +
                app.launch(
         
     | 
| 809 | 
         
            +
                    server_name="0.0.0.0",
         
     | 
| 810 | 
         
            +
                    server_port=8080,
         
     | 
| 811 | 
         
            +
                    share=False,
         
     | 
| 812 | 
         
            +
                    debug=False,
         
     | 
| 813 | 
         
            +
                    show_error=True
         
     | 
| 814 | 
         
            +
                )
         
     | 
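Note on the example-button wiring above: create_example_handler(example) is a closure factory. Binding the example through a function parameter is what ties each button to its own example; a bare inner function (or lambda) defined in the loop would late-bind the loop variable, and every button would load the last example. A minimal, self-contained sketch (illustrative only, not part of app.py) showing the difference:

# Standalone demo of the late-binding pitfall and the factory fix.
def make_handlers_late_binding(examples):
    handlers = []
    for ex in examples:
        handlers.append(lambda: ex)   # every closure shares the same loop variable
    return handlers

def make_handlers_with_factory(examples):
    def factory(ex):                  # like create_example_handler: binds ex per call
        return lambda: ex
    return [factory(ex) for ex in examples]

if __name__ == "__main__":
    examples = ["ex1", "ex2", "ex3"]
    print([h() for h in make_handlers_late_binding(examples)])  # ['ex3', 'ex3', 'ex3']
    print([h() for h in make_handlers_with_factory(examples)])  # ['ex1', 'ex2', 'ex3']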
    	
assets/data_pipeline.png
ADDED
    Binary image (stored with Git LFS)

assets/model_arch.png
ADDED
    Binary image (stored with Git LFS)

assets/pan_chart.png
ADDED
    Binary image (stored with Git LFS)
        configs/hunyuanvideo-foley-xxl.yaml
    ADDED
    
@@ -0,0 +1,49 @@
model_config:
  model_name: HunyuanVideo-Foley-XXL
  model_type: 1d
  model_precision: bf16
  model_kwargs:
    depth_triple_blocks: 18
    depth_single_blocks: 36
    hidden_size: 1536
    num_heads: 12
    mlp_ratio: 4
    mlp_act_type: "gelu_tanh"
    qkv_bias: True
    qk_norm: True
    qk_norm_type: "rms"
    attn_mode: "torch"
    embedder_type: "default"
    interleaved_audio_visual_rope: True
    enable_learnable_empty_visual_feat: True
    sync_modulation: False
    add_sync_feat_to_audio: True
    cross_attention: True
    use_attention_mask: False
    condition_projection: "linear"
    sync_feat_dim: 768  # syncformer 768 dim
    condition_dim: 768  # clap 768 text condition dim (clip-text)
    clip_dim: 768  # siglip2 visual dim
    audio_vae_latent_dim: 128
    audio_frame_rate: 50
    patch_size: 1
    rope_dim_list: null
    rope_theta: 10000
    text_length: 77
    clip_length: 64
    sync_length: 192
    use_mmaudio_singleblock: True
    depth_triple_ssl_encoder: null
    depth_single_ssl_encoder: 8
    use_repa_with_audiossl: True

diffusion_config:
  denoise_type: "flow"
  flow_path_type: "linear"
  flow_predict_type: "velocity"
  flow_reverse: True
  flow_solver: "euler"
  sample_flow_shift: 1.0
  sample_use_flux_shift: False
  flux_base_shift: 0.5
  flux_max_shift: 1.15
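The YAML above drives model construction and sampling. As a rough, hedged sketch of how such a file could be read (the helper name load_model_config and the use of PyYAML are assumptions for illustration, not the repository's actual loader):

# Hypothetical loader sketch; assumes PyYAML is installed.
import yaml

def load_model_config(path="configs/hunyuanvideo-foley-xxl.yaml"):
    with open(path, "r") as f:
        cfg = yaml.safe_load(f)
    return cfg

if __name__ == "__main__":
    cfg = load_model_config()
    print(cfg["model_config"]["model_kwargs"]["hidden_size"])  # 1536
    print(cfg["diffusion_config"]["denoise_type"])             # flow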
    	
examples/1_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f3f49d6130592f479b0aca5f02ba25960140ed8d9d17340ff7f6306b39096a8
size 11357340

examples/1_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54fc2de0b52f6969157b9caff212ffffddd4d34a75efb47ef7e7f8352d0a38db
size 11181543

examples/2_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cf4f324b158f6a6926e77bbd0791610d79f0c3a600a571ed8ec61b0b7e645e46
size 1720732

examples/2_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:512b185682ec8e60407dff65718443d7a28c75c79aec1e733b5abf7433af41a7
size 1636945

examples/3_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df90300335b7ab1fb2fc4c020976837c6b3781f796e211bcbbaa30d34353d3e5
size 1738462

examples/3_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e79f80cf939fcb507e3fe218f61146d4fd3949a84e802b0f5b67bb2e981931a7
size 1652180

examples/4_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c7d2b5b63f6756f719d53e8087f772aa6bb25f31fbcd9f1cbae9e075fc841a2c
size 45242387

examples/4_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f94cfd97634f3df085672ce2a91805697320507a716360bf85ba6eabb5a4b6f0
size 45066257

examples/5_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf1db36336822b54b4e0e0aa8c98334fc2b97b9288271ddc9d42a45417b1f1d9
size 40423834

examples/5_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:746954bde2d5e693beecd8e3661bcd66ff0e55a8143f4f3b37f0b6d3873a8fff
size 40248335

examples/6_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7778c26a677c04e93dc722cee44b896c5281bea8328a10800a08b98865419cd0
size 4005580

examples/6_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:866b2cff3441ddd686e551181c48ad7ca718626a489cf64198626b42bd732366
size 3872852

examples/7_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:491114177a4ceeb50a6edfa5fca14fc6ce4fdb61ee7dfb0e13983236c42ee10d
size 32307884

examples/7_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7bc8a31e867245f6a6a6fbfa9778ac4c12e816184dc70324dc92d4496a36f62b
size 32131367

examples/8_result.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17053c5f8a373d656be2d2030619a71a5fd55db6365f8d54234402121e6030ce
size 29544164

examples/8_video.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:45a1f35974ad6e2d86304828fdeb230d9b008aae2f10cff8c87d71a8dcc6491e
size 29367637
    	
        hunyuanvideo_foley/__init__.py
    ADDED
    
    File without changes

hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc
ADDED
    Binary file (173 Bytes)
    	
        hunyuanvideo_foley/constants.py
    ADDED
    
@@ -0,0 +1,57 @@
"""Constants used throughout the HunyuanVideo-Foley project."""

from typing import Dict, List

# Model configuration
DEFAULT_AUDIO_SAMPLE_RATE = 48000
DEFAULT_VIDEO_FPS = 25
DEFAULT_AUDIO_CHANNELS = 2

# Video processing
MAX_VIDEO_DURATION_SECONDS = 15.0
MIN_VIDEO_DURATION_SECONDS = 1.0

# Audio processing
AUDIO_VAE_LATENT_DIM = 128
AUDIO_FRAME_RATE = 75  # frames per second in latent space

# Visual features
FPS_VISUAL: Dict[str, int] = {
    "siglip2": 8,
    "synchformer": 25
}

# Model paths (can be overridden by environment variables)
DEFAULT_MODEL_PATH = "./pretrained_models/"
DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"

# Inference parameters
DEFAULT_GUIDANCE_SCALE = 4.5
DEFAULT_NUM_INFERENCE_STEPS = 50
MIN_GUIDANCE_SCALE = 1.0
MAX_GUIDANCE_SCALE = 10.0
MIN_INFERENCE_STEPS = 10
MAX_INFERENCE_STEPS = 100

# Text processing
MAX_TEXT_LENGTH = 100
DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"

# File extensions
SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]

# Quality settings
AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
    "high": ["-b:a", "192k"],
    "medium": ["-b:a", "128k"],
    "low": ["-b:a", "96k"]
}

# Error messages
ERROR_MESSAGES: Dict[str, str] = {
    "model_not_loaded": "Model is not loaded. Please load the model first.",
    "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
    "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
    "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
}
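A short illustrative sketch of how calling code could use these constants to validate an input video; validate_video_path below is a hypothetical helper for illustration, not part of this commit:

# Hypothetical helper built only on the constants above.
import os

from hunyuanvideo_foley.constants import ERROR_MESSAGES, SUPPORTED_VIDEO_EXTENSIONS

def validate_video_path(path: str) -> None:
    ext = os.path.splitext(path)[1].lower()
    if ext not in SUPPORTED_VIDEO_EXTENSIONS:
        raise ValueError(
            ERROR_MESSAGES["invalid_video_format"].format(
                formats=", ".join(SUPPORTED_VIDEO_EXTENSIONS)
            )
        )

validate_video_path("clip.mp4")  # passes silently for a supported extension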
    	
        hunyuanvideo_foley/models/__init__.py
    ADDED
    
    File without changes

hunyuanvideo_foley/models/__pycache__/mmaudio_layer.cpython-312.pyc
ADDED
    Binary file (12.1 kB)
    	
        hunyuanvideo_foley/models/dac_vae/__init__.py
    ADDED
    
@@ -0,0 +1,16 @@
__version__ = "1.0.0"

# preserved here for legacy reasons
__model_version__ = "latest"

import audiotools

audiotools.ml.BaseModel.INTERN += ["dac.**"]
audiotools.ml.BaseModel.EXTERN += ["einops"]


from . import nn
from . import model
from . import utils
from .model import DAC
from .model import DACFile
    	
        hunyuanvideo_foley/models/dac_vae/__main__.py
    ADDED
    
@@ -0,0 +1,36 @@
import sys

import argbind

from .utils import download
from .utils.decode import decode
from .utils.encode import encode

STAGES = ["encode", "decode", "download"]


def run(stage: str):
    """Run stages.

    Parameters
    ----------
    stage : str
        Stage to run
    """
    if stage not in STAGES:
        raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
    stage_fn = globals()[stage]

    if stage == "download":
        stage_fn()
        return

    stage_fn()


if __name__ == "__main__":
    group = sys.argv.pop(1)
    args = argbind.parse_args(group=group)

    with argbind.scope(args):
        run(group)
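Given the dispatch above, the module appears intended to be invoked with the stage name as the first positional argument (for example, python -m hunyuanvideo_foley.models.dac_vae encode, with any remaining flags handled by argbind). A minimal standalone sketch of the same first-argument dispatch, with the argbind parsing omitted:

# Standalone sketch of "first argument selects the stage" dispatch.
import sys

STAGES = ["encode", "decode", "download"]

def dispatch(argv):
    stage = argv[1] if len(argv) > 1 else ""
    if stage not in STAGES:
        raise SystemExit(f"Unknown command: {stage}. Allowed commands are {STAGES}")
    print(f"would run stage: {stage}")

if __name__ == "__main__":
    dispatch(sys.argv)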
    	
        hunyuanvideo_foley/models/dac_vae/model/__init__.py
    ADDED
    
@@ -0,0 +1,4 @@
from .base import CodecMixin
from .base import DACFile
from .dac import DAC
from .discriminator import Discriminator
    	
        hunyuanvideo_foley/models/dac_vae/model/base.py
    ADDED
    
@@ -0,0 +1,301 @@
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
import torch
import tqdm
from audiotools import AudioSignal
from torch import nn

SUPPORTED_VERSIONS = ["1.0.0"]


@dataclass
class DACFile:
    codes: torch.Tensor

    # Metadata
    chunk_length: int
    original_length: int
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

    def save(self, path):
        artifacts = {
            "codes": self.codes.numpy().astype(np.uint16),
            "metadata": {
                "input_db": self.input_db.numpy().astype(np.float32),
                "original_length": self.original_length,
                "sample_rate": self.sample_rate,
                "chunk_length": self.chunk_length,
                "channels": self.channels,
                "padding": self.padding,
                "dac_version": SUPPORTED_VERSIONS[-1],
            },
        }
        path = Path(path).with_suffix(".dac")
        with open(path, "wb") as f:
            np.save(f, artifacts)
        return path

    @classmethod
    def load(cls, path):
        artifacts = np.load(path, allow_pickle=True)[()]
        codes = torch.from_numpy(artifacts["codes"].astype(int))
        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        return cls(codes=codes, **artifacts["metadata"])


class CodecMixin:
    @property
    def padding(self):
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        # Any number works here, delay is invariant to input length
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        L = input_length
        # Calculate output length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to reconstruct
        win_duration : float, optional
            window duration in seconds, by default 1.0
        verbose : bool, optional
            by default False
        normalize_db : float, optional
            normalize db, by default -16

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        original_padding = self.padding
        original_device = audio_signal.device

        audio_signal = audio_signal.clone()
        audio_signal = audio_signal.to_mono()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # For very long audio (>= 10 hours), use the ffmpeg versions
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Unchunked compression (used if signal length < win duration)
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference
            self.padding = False
            # Zero-pad signal on either side by the delay
            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round n_samples to nearest hop length multiple
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )

        if n_quantizers is not None:
            codes = codes[:, :n_quantizers, :]

        self.padding = original_padding
        return dac_file

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        original_padding = self.padding
        self.padding = obj.padding

        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        for i in range_fn(0, codes.shape[-1], chunk_length):
         
     | 
| 270 | 
         
            +
                        c = codes[..., i : i + chunk_length].to(self.device)
         
     | 
| 271 | 
         
            +
                        z = self.quantizer.from_codes(c)[0]
         
     | 
| 272 | 
         
            +
                        r = self.decode(z)
         
     | 
| 273 | 
         
            +
                        recons.append(r.to(original_device))
         
     | 
| 274 | 
         
            +
             
     | 
| 275 | 
         
            +
                    recons = torch.cat(recons, dim=-1)
         
     | 
| 276 | 
         
            +
                    recons = AudioSignal(recons, self.sample_rate)
         
     | 
| 277 | 
         
            +
             
     | 
| 278 | 
         
            +
                    resample_fn = recons.resample
         
     | 
| 279 | 
         
            +
                    loudness_fn = recons.loudness
         
     | 
| 280 | 
         
            +
             
     | 
| 281 | 
         
            +
                    # If audio is > 10 minutes long, use the ffmpeg versions
         
     | 
| 282 | 
         
            +
                    if recons.signal_duration >= 10 * 60 * 60:
         
     | 
| 283 | 
         
            +
                        resample_fn = recons.ffmpeg_resample
         
     | 
| 284 | 
         
            +
                        loudness_fn = recons.ffmpeg_loudness
         
     | 
| 285 | 
         
            +
             
     | 
| 286 | 
         
            +
                    if obj.input_db is not None:
         
     | 
| 287 | 
         
            +
                        recons.normalize(obj.input_db)
         
     | 
| 288 | 
         
            +
             
     | 
| 289 | 
         
            +
                    resample_fn(obj.sample_rate)
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                    if obj.original_length is not None:
         
     | 
| 292 | 
         
            +
                        recons = recons[..., : obj.original_length]
         
     | 
| 293 | 
         
            +
                        loudness_fn()
         
     | 
| 294 | 
         
            +
                        recons.audio_data = recons.audio_data.reshape(
         
     | 
| 295 | 
         
            +
                            -1, obj.channels, obj.original_length
         
     | 
| 296 | 
         
            +
                        )
         
     | 
| 297 | 
         
            +
                    else:
         
     | 
| 298 | 
         
            +
                        loudness_fn()
         
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
                    self.padding = original_padding
         
     | 
| 301 | 
         
            +
                    return recons
         
     | 
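For orientation, here is a minimal sketch of how the compress/decompress pair above is typically driven. It is illustrative rather than part of the committed file: `model` stands for an already-instantiated DAC codec, the file paths are placeholders, and `DACFile.save` is assumed from the upstream descript-audio-codec API that this module mirrors.

from audiotools import AudioSignal

signal = AudioSignal("input.wav")                      # placeholder input path
dac_file = model.compress(signal, win_duration=5.0)    # chunked inference for inputs longer than 5 s
dac_file.save("compressed.dac")                        # assumed DACFile serialization helper
recons = model.decompress("compressed.dac", verbose=True)
recons.write("reconstructed.wav")                      # AudioSignal writes the waveform to disk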
    	
hunyuanvideo_foley/models/dac_vae/model/dac.py
ADDED
@@ -0,0 +1,410 @@
import math
from typing import List
from typing import Union

import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from torch import nn

from .base import CodecMixin
from ..nn.layers import Snake1d
from ..nn.layers import WNConv1d
from ..nn.layers import WNConvTranspose1d
from ..nn.quantize import ResidualVectorQuantize
from ..nn.vae_utils import DiagonalGaussianDistribution


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 64,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)


class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class DAC(BaseModel, CodecMixin):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
        continuous: bool = False,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate
        self.continuous = continuous

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        if not continuous:
            self.n_codebooks = n_codebooks
            self.codebook_size = codebook_size
            self.codebook_dim = codebook_dim
            self.quantizer = ResidualVectorQuantize(
                input_dim=latent_dim,
                n_codebooks=n_codebooks,
                codebook_size=codebook_size,
                codebook_dim=codebook_dim,
                quantizer_dropout=quantizer_dropout,
            )
        else:
            self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
            self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.sample_rate = sample_rate
        self.apply(init_weights)

        self.delay = self.get_delay()

    @property
    def dtype(self):
        """Get the dtype of the model parameters."""
        # Return the dtype of the first parameter found
        for param in self.parameters():
            return param.dtype
        return torch.float32  # fallback

    @property
    def device(self):
        """Get the device of the model parameters."""
        # Return the device of the first parameter found
        for param in self.parameters():
            return param.device
        return torch.device('cpu')  # fallback

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data

    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "length" : int
                Number of samples in input audio
        """
        z = self.encoder(audio_data)  # [B x D x T]
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        else:
            z = self.quant_conv(z)  # [B x 2D x T]
            z = DiagonalGaussianDistribution(z)
            codes, latents, commitment_loss, codebook_loss = None, None, 0, 0

        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized (or sampled) continuous representation of input

        Returns
        -------
        Tensor[B x 1 x length]
            Decoded audio data.
        """
        if not self.continuous:
            audio = self.decoder(z)
        else:
            z = self.post_quant_conv(z)
            audio = self.decoder(z)

        return audio

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "length" : int
                Number of samples in input audio
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        if not self.continuous:
            z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

            x = self.decode(z)
            return {
                "audio": x[..., :length],
                "z": z,
                "codes": codes,
                "latents": latents,
                "vq/commitment_loss": commitment_loss,
                "vq/codebook_loss": codebook_loss,
            }
        else:
            posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
            z = posterior.sample()
            x = self.decode(z)

            kl_loss = posterior.kl()
            kl_loss = kl_loss.mean()

            return {
                "audio": x[..., :length],
                "z": z,
                "kl_loss": kl_loss,
            }


if __name__ == "__main__":
    import numpy as np
    from functools import partial

    model = DAC().to("cpu")

    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Make a forward pass
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Create gradient variable
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1

    # Make a backward pass
    out.backward(grad)

    # Check non-zero values
    gradmap = x.grad.squeeze(0)
    gradmap = (gradmap != 0).sum(0)  # sum across features
    rf = (gradmap != 0).sum()

    print(f"Receptive field: {rf.item()}")

    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)
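The continuous=True branch above swaps the residual vector quantizer for a KL-regularized VAE bottleneck (quant_conv, DiagonalGaussianDistribution, post_quant_conv). A small illustrative sketch of exercising that path; the latent_dim value is an arbitrary example, not the configuration shipped with this Space, and the weights here are randomly initialized:

import torch

vae = DAC(latent_dim=128, continuous=True)   # VAE-style bottleneck instead of RVQ
wav = torch.randn(1, 1, 44100)               # one second of audio at 44.1 kHz
out = vae(wav, sample_rate=44100)            # encode -> sample posterior -> decode
print(out["audio"].shape)                    # torch.Size([1, 1, 44100]), trimmed to the input length
print(float(out["kl_loss"]))                 # scalar KL term used as the VAE regularizer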
    	
hunyuanvideo_foley/models/dac_vae/model/discriminator.py
ADDED
@@ -0,0 +1,228 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from audiotools import AudioSignal
from audiotools import ml
from audiotools import STFTParams
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv1d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


def WNConv2d(*args, **kwargs):
    act = kwargs.pop("act", True)
    conv = weight_norm(nn.Conv2d(*args, **kwargs))
    if not act:
        return conv
    return nn.Sequential(conv, nn.LeakyReLU(0.1))


class MPD(nn.Module):
    def __init__(self, period):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList(
            [
                WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
                WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
            ]
        )
        self.conv_post = WNConv2d(
            1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
        )

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
        return x

    def forward(self, x):
        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b c l p", p=self.period)

        for layer in self.convs:
            x = layer(x)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return fmap


class MSD(nn.Module):
    def __init__(self, rate: int = 1, sample_rate: int = 44100):
        super().__init__()
        self.convs = nn.ModuleList(
            [
                WNConv1d(1, 16, 15, 1, padding=7),
                WNConv1d(16, 64, 41, 4, groups=4, padding=20),
                WNConv1d(64, 256, 41, 4, groups=16, padding=20),
                WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
                WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
                WNConv1d(1024, 1024, 5, 1, padding=2),
            ]
        )
        self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
        self.sample_rate = sample_rate
        self.rate = rate

    def forward(self, x):
        x = AudioSignal(x, self.sample_rate)
        x.resample(self.sample_rate // self.rate)
        x = x.audio_data

        fmap = []

        for l in self.convs:
            x = l(x)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]


class MRD(nn.Module):
    def __init__(
        self,
        window_length: int,
        hop_factor: float = 0.25,
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Complex multi-band spectrogram discriminator.
        Parameters
        ----------
        window_length : int
            Window length of STFT.
        hop_factor : float, optional
            Hop factor of the STFT, defaults to ``0.25 * window_length``.
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run discriminator over.
        """
        super().__init__()

        self.window_length = window_length
        self.hop_factor = hop_factor
        self.sample_rate = sample_rate
        self.stft_params = STFTParams(
            window_length=window_length,
            hop_length=int(window_length * hop_factor),
            match_stride=True,
        )

        n_fft = window_length // 2 + 1
        bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
        self.bands = bands

        ch = 32
        convs = lambda: nn.ModuleList(
            [
                WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
                WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
            ]
        )
        self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
        self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)

    def spectrogram(self, x):
        x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
        x = torch.view_as_real(x.stft())
        x = rearrange(x, "b 1 f t c -> (b 1) c t f")
        # Split into bands
        x_bands = [x[..., b[0] : b[1]] for b in self.bands]
        return x_bands

    def forward(self, x):
        x_bands = self.spectrogram(x)
        fmap = []

        x = []
        for band, stack in zip(x_bands, self.band_convs):
            for layer in stack:
                band = layer(band)
                fmap.append(band)
            x.append(band)

        x = torch.cat(x, dim=-1)
        x = self.conv_post(x)
        fmap.append(x)

        return fmap


class Discriminator(ml.BaseModel):
    def __init__(
        self,
        rates: list = [],
        periods: list = [2, 3, 5, 7, 11],
        fft_sizes: list = [2048, 1024, 512],
        sample_rate: int = 44100,
        bands: list = BANDS,
    ):
        """Discriminator that combines multiple discriminators.

        Parameters
        ----------
        rates : list, optional
            sampling rates (in Hz) to run MSD at, by default []
            If empty, MSD is not used.
        periods : list, optional
            periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
        fft_sizes : list, optional
            Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
        sample_rate : int, optional
            Sampling rate of audio in Hz, by default 44100
        bands : list, optional
            Bands to run MRD at, by default `BANDS`
        """
        super().__init__()
        discs = []
        discs += [MPD(p) for p in periods]
        discs += [MSD(r, sample_rate=sample_rate) for r in rates]
        discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
        self.discriminators = nn.ModuleList(discs)

    def preprocess(self, y):
        # Remove DC offset
        y = y - y.mean(dim=-1, keepdims=True)
        # Peak normalize the volume of input audio
        y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
        return y

    def forward(self, x):
        x = self.preprocess(x)
        fmaps = [d(x) for d in self.discriminators]
        return fmaps


if __name__ == "__main__":
    disc = Discriminator()
    x = torch.zeros(1, 1, 44100)
    results = disc(x)
    for i, result in enumerate(results):
        print(f"disc{i}")
        for i, r in enumerate(result):
            print(r.shape, r.mean(), r.min(), r.max())
        print()
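For orientation: `Discriminator.forward` returns one list of feature maps per sub-discriminator (MPD, MSD, MRD), and the last entry of each list is that sub-discriminator's logit map. The adversarial and feature-matching terms that consume these lists are not defined in this file; the snippet below is only a minimal sketch of the usual GAN-vocoder pattern (a least-squares adversarial term on the logit map plus L1 feature matching on intermediate activations). The helper name `sketch_gan_losses` and its weighting are assumptions for illustration, not the project's implementation; in practice such terms would be combined with the spectral reconstruction losses defined in `nn/loss.py` below.

# Illustrative sketch only -- NOT part of the repository.
import torch
import torch.nn.functional as F


def sketch_gan_losses(fmaps_real, fmaps_fake):
    """fmaps_real / fmaps_fake: one list of feature maps per sub-discriminator."""
    adv_loss = 0.0
    fm_loss = 0.0
    for real_maps, fake_maps in zip(fmaps_real, fmaps_fake):
        # Last entry of each list is the discriminator logit map (least-squares GAN term).
        adv_loss = adv_loss + torch.mean((fake_maps[-1] - 1.0) ** 2)
        # Intermediate activations are matched with an L1 distance (feature matching).
        for r, f in zip(real_maps[:-1], fake_maps[:-1]):
            fm_loss = fm_loss + F.l1_loss(f, r.detach())
    return adv_loss, fm_loss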
    	
hunyuanvideo_foley/models/dac_vae/nn/__init__.py
ADDED
@@ -0,0 +1,3 @@
from . import layers
from . import loss
from . import quantize
    	
hunyuanvideo_foley/models/dac_vae/nn/layers.py
ADDED
@@ -0,0 +1,33 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
     | 
    	
        hunyuanvideo_foley/models/dac_vae/nn/loss.py
    ADDED
    
    | 
         @@ -0,0 +1,368 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import typing
         
     | 
| 2 | 
         
            +
            from typing import List
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 6 | 
         
            +
            from audiotools import AudioSignal
         
     | 
| 7 | 
         
            +
            from audiotools import STFTParams
         
     | 
| 8 | 
         
            +
            from torch import nn
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            class L1Loss(nn.L1Loss):
         
     | 
| 12 | 
         
            +
                """L1 Loss between AudioSignals. Defaults
         
     | 
| 13 | 
         
            +
                to comparing ``audio_data``, but any
         
     | 
| 14 | 
         
            +
                attribute of an AudioSignal can be used.
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
                Parameters
         
     | 
| 17 | 
         
            +
                ----------
         
     | 
| 18 | 
         
            +
                attribute : str, optional
         
     | 
| 19 | 
         
            +
                    Attribute of signal to compare, defaults to ``audio_data``.
         
     | 
| 20 | 
         
            +
                weight : float, optional
         
     | 
| 21 | 
         
            +
                    Weight of this loss, defaults to 1.0.
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
         
     | 
| 24 | 
         
            +
                """
         
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
                def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
         
     | 
| 27 | 
         
            +
                    self.attribute = attribute
         
     | 
| 28 | 
         
            +
                    self.weight = weight
         
     | 
| 29 | 
         
            +
                    super().__init__(**kwargs)
         
     | 
| 30 | 
         
            +
             
     | 
| 31 | 
         
            +
                def forward(self, x: AudioSignal, y: AudioSignal):
         
     | 
| 32 | 
         
            +
                    """
         
     | 
| 33 | 
         
            +
                    Parameters
         
     | 
| 34 | 
         
            +
                    ----------
         
     | 
| 35 | 
         
            +
                    x : AudioSignal
         
     | 
| 36 | 
         
            +
                        Estimate AudioSignal
         
     | 
| 37 | 
         
            +
                    y : AudioSignal
         
     | 
| 38 | 
         
            +
                        Reference AudioSignal
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                    Returns
         
     | 
| 41 | 
         
            +
                    -------
         
     | 
| 42 | 
         
            +
                    torch.Tensor
         
     | 
| 43 | 
         
            +
                        L1 loss between AudioSignal attributes.
         
     | 
| 44 | 
         
            +
                    """
         
     | 
| 45 | 
         
            +
                    if isinstance(x, AudioSignal):
         
     | 
| 46 | 
         
            +
                        x = getattr(x, self.attribute)
         
     | 
| 47 | 
         
            +
                        y = getattr(y, self.attribute)
         
     | 
| 48 | 
         
            +
                    return super().forward(x, y)
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
            class SISDRLoss(nn.Module):
         
     | 
| 52 | 
         
            +
                """
         
     | 
| 53 | 
         
            +
                Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
         
     | 
| 54 | 
         
            +
                of estimated and reference audio signals or aligned features.
         
     | 
| 55 | 
         
            +
             
     | 
| 56 | 
         
            +
                Parameters
         
     | 
| 57 | 
         
            +
                ----------
         
     | 
| 58 | 
         
            +
                scaling : int, optional
         
     | 
| 59 | 
         
            +
                    Whether to use scale-invariant (True) or
         
     | 
| 60 | 
         
            +
                    signal-to-noise ratio (False), by default True
         
     | 
| 61 | 
         
            +
                reduction : str, optional
         
     | 
| 62 | 
         
            +
                    How to reduce across the batch (either 'mean',
         
     | 
| 63 | 
         
            +
                    'sum', or none).], by default ' mean'
         
     | 
| 64 | 
         
            +
                zero_mean : int, optional
         
     | 
| 65 | 
         
            +
                    Zero mean the references and estimates before
         
     | 
| 66 | 
         
            +
                    computing the loss, by default True
         
     | 
| 67 | 
         
            +
                clip_min : int, optional
         
     | 
| 68 | 
         
            +
                    The minimum possible loss value. Helps network
         
     | 
| 69 | 
         
            +
                    to not focus on making already good examples better, by default None
         
     | 
| 70 | 
         
            +
                weight : float, optional
         
     | 
| 71 | 
         
            +
                    Weight of this loss, defaults to 1.0.
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
         
     | 
| 74 | 
         
            +
                """
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                def __init__(
         
     | 
| 77 | 
         
            +
                    self,
         
     | 
| 78 | 
         
            +
                    scaling: int = True,
         
     | 
| 79 | 
         
            +
                    reduction: str = "mean",
         
     | 
| 80 | 
         
            +
                    zero_mean: int = True,
         
     | 
| 81 | 
         
            +
                    clip_min: int = None,
         
     | 
| 82 | 
         
            +
                    weight: float = 1.0,
         
     | 
| 83 | 
         
            +
                ):
         
     | 
| 84 | 
         
            +
                    self.scaling = scaling
         
     | 
| 85 | 
         
            +
                    self.reduction = reduction
         
     | 
| 86 | 
         
            +
                    self.zero_mean = zero_mean
         
     | 
| 87 | 
         
            +
                    self.clip_min = clip_min
         
     | 
| 88 | 
         
            +
                    self.weight = weight
         
     | 
| 89 | 
         
            +
                    super().__init__()
         
     | 
| 90 | 
         
            +
             
     | 
| 91 | 
         
            +
                def forward(self, x: AudioSignal, y: AudioSignal):
         
     | 
| 92 | 
         
            +
                    eps = 1e-8
         
     | 
| 93 | 
         
            +
                    # nb, nc, nt
         
     | 
| 94 | 
         
            +
                    if isinstance(x, AudioSignal):
         
     | 
| 95 | 
         
            +
                        references = x.audio_data
         
     | 
| 96 | 
         
            +
                        estimates = y.audio_data
         
     | 
| 97 | 
         
            +
                    else:
         
     | 
| 98 | 
         
            +
                        references = x
         
     | 
| 99 | 
         
            +
                        estimates = y
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                    nb = references.shape[0]
         
     | 
| 102 | 
         
            +
                    references = references.reshape(nb, 1, -1).permute(0, 2, 1)
         
     | 
| 103 | 
         
            +
                    estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
         
     | 
| 104 | 
         
            +
             
     | 
| 105 | 
         
            +
                    # samples now on axis 1
         
     | 
| 106 | 
         
            +
                    if self.zero_mean:
         
     | 
| 107 | 
         
            +
                        mean_reference = references.mean(dim=1, keepdim=True)
         
     | 
| 108 | 
         
            +
                        mean_estimate = estimates.mean(dim=1, keepdim=True)
         
     | 
| 109 | 
         
            +
                    else:
         
     | 
| 110 | 
         
            +
                        mean_reference = 0
         
     | 
| 111 | 
         
            +
                        mean_estimate = 0
         
     | 
| 112 | 
         
            +
             
     | 
| 113 | 
         
            +
                    _references = references - mean_reference
         
     | 
| 114 | 
         
            +
                    _estimates = estimates - mean_estimate
         
     | 
| 115 | 
         
            +
             
     | 
| 116 | 
         
            +
                    references_projection = (_references**2).sum(dim=-2) + eps
         
     | 
| 117 | 
         
            +
                    references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
         
     | 
| 118 | 
         
            +
             
     | 
| 119 | 
         
            +
                    scale = (
         
     | 
| 120 | 
         
            +
                        (references_on_estimates / references_projection).unsqueeze(1)
         
     | 
| 121 | 
         
            +
                        if self.scaling
         
     | 
| 122 | 
         
            +
                        else 1
         
     | 
| 123 | 
         
            +
                    )
         
     | 
| 124 | 
         
            +
             
     | 
| 125 | 
         
            +
                    e_true = scale * _references
         
     | 
| 126 | 
         
            +
                    e_res = _estimates - e_true
         
     | 
| 127 | 
         
            +
             
     | 
| 128 | 
         
            +
                    signal = (e_true**2).sum(dim=1)
         
     | 
| 129 | 
         
            +
                    noise = (e_res**2).sum(dim=1)
         
     | 
| 130 | 
         
            +
                    sdr = -10 * torch.log10(signal / noise + eps)
         
     | 
| 131 | 
         
            +
             
     | 
| 132 | 
         
            +
                    if self.clip_min is not None:
         
     | 
| 133 | 
         
            +
                        sdr = torch.clamp(sdr, min=self.clip_min)
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                    if self.reduction == "mean":
         
     | 
| 136 | 
         
            +
                        sdr = sdr.mean()
         
     | 
| 137 | 
         
            +
                    elif self.reduction == "sum":
         
     | 
| 138 | 
         
            +
                        sdr = sdr.sum()
         
     | 
| 139 | 
         
            +
                    return sdr
         
     | 
| 140 | 
         
            +
             
     | 
| 141 | 
         
            +
             
     | 
| 142 | 
         
            +
            class MultiScaleSTFTLoss(nn.Module):
         
     | 
| 143 | 
         
            +
                """Computes the multi-scale STFT loss from [1].
         
     | 
| 144 | 
         
            +
             
     | 
| 145 | 
         
            +
                Parameters
         
     | 
| 146 | 
         
            +
                ----------
         
     | 
| 147 | 
         
            +
                window_lengths : List[int], optional
         
     | 
| 148 | 
         
            +
                    Length of each window of each STFT, by default [2048, 512]
         
     | 
| 149 | 
         
            +
                loss_fn : typing.Callable, optional
         
     | 
| 150 | 
         
            +
                    How to compare each loss, by default nn.L1Loss()
         
     | 
| 151 | 
         
            +
                clamp_eps : float, optional
         
     | 
| 152 | 
         
            +
                    Clamp on the log magnitude, below, by default 1e-5
         
     | 
| 153 | 
         
            +
                mag_weight : float, optional
         
     | 
| 154 | 
         
            +
                    Weight of raw magnitude portion of loss, by default 1.0
         
     | 
| 155 | 
         
            +
                log_weight : float, optional
         
     | 
| 156 | 
         
            +
                    Weight of log magnitude portion of loss, by default 1.0
         
     | 
| 157 | 
         
            +
                pow : float, optional
         
     | 
| 158 | 
         
            +
                    Power to raise magnitude to before taking log, by default 2.0
         
     | 
| 159 | 
         
            +
                weight : float, optional
         
     | 
| 160 | 
         
            +
                    Weight of this loss, by default 1.0
         
     | 
| 161 | 
         
            +
                match_stride : bool, optional
         
     | 
| 162 | 
         
            +
                    Whether to match the stride of convolutional layers, by default False
         
     | 
| 163 | 
         
            +
             
     | 
| 164 | 
         
            +
                References
         
     | 
| 165 | 
         
            +
                ----------
         
     | 
| 166 | 
         
            +
             
     | 
| 167 | 
         
            +
                1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
         
     | 
| 168 | 
         
            +
                    "DDSP: Differentiable Digital Signal Processing."
         
     | 
| 169 | 
         
            +
                    International Conference on Learning Representations. 2019.
         
     | 
| 170 | 
         
            +
             
     | 
| 171 | 
         
            +
                Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
         
     | 
| 172 | 
         
            +
                """
         
     | 
| 173 | 
         
            +
             
     | 
| 174 | 
         
            +
                def __init__(
         
     | 
| 175 | 
         
            +
                    self,
         
     | 
| 176 | 
         
            +
                    window_lengths: List[int] = [2048, 512],
         
     | 
| 177 | 
         
            +
                    loss_fn: typing.Callable = nn.L1Loss(),
         
     | 
| 178 | 
         
            +
                    clamp_eps: float = 1e-5,
         
     | 
| 179 | 
         
            +
                    mag_weight: float = 1.0,
         
     | 
| 180 | 
         
            +
                    log_weight: float = 1.0,
         
     | 
| 181 | 
         
            +
                    pow: float = 2.0,
         
     | 
| 182 | 
         
            +
                    weight: float = 1.0,
         
     | 
| 183 | 
         
            +
                    match_stride: bool = False,
         
     | 
| 184 | 
         
            +
                    window_type: str = None,
         
     | 
| 185 | 
         
            +
                ):
         
     | 
| 186 | 
         
            +
                    super().__init__()
         
     | 
| 187 | 
         
            +
                    self.stft_params = [
         
     | 
| 188 | 
         
            +
                        STFTParams(
         
     | 
| 189 | 
         
            +
                            window_length=w,
         
     | 
| 190 | 
         
            +
                            hop_length=w // 4,
         
     | 
| 191 | 
         
            +
                            match_stride=match_stride,
         
     | 
| 192 | 
         
            +
                            window_type=window_type,
         
     | 
| 193 | 
         
            +
                        )
         
     | 
| 194 | 
         
            +
                        for w in window_lengths
         
     | 
| 195 | 
         
            +
                    ]
         
     | 
| 196 | 
         
            +
                    self.loss_fn = loss_fn
         
     | 
| 197 | 
         
            +
                    self.log_weight = log_weight
         
     | 
| 198 | 
         
            +
                    self.mag_weight = mag_weight
         
     | 
| 199 | 
         
            +
                    self.clamp_eps = clamp_eps
         
     | 
| 200 | 
         
            +
                    self.weight = weight
         
     | 
| 201 | 
         
            +
                    self.pow = pow
         
     | 
| 202 | 
         
            +
             
     | 
| 203 | 
         
            +
                def forward(self, x: AudioSignal, y: AudioSignal):
         
     | 
| 204 | 
         
            +
                    """Computes multi-scale STFT between an estimate and a reference
         
     | 
| 205 | 
         
            +
                    signal.
         
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
                    Parameters
         
     | 
| 208 | 
         
            +
                    ----------
         
     | 
| 209 | 
         
            +
                    x : AudioSignal
         
     | 
| 210 | 
         
            +
                        Estimate signal
         
     | 
| 211 | 
         
            +
                    y : AudioSignal
         
     | 
| 212 | 
         
            +
                        Reference signal
         
     | 
| 213 | 
         
            +
             
     | 
| 214 | 
         
            +
                    Returns
         
     | 
| 215 | 
         
            +
                    -------
         
     | 
| 216 | 
         
            +
                    torch.Tensor
         
     | 
| 217 | 
         
            +
                        Multi-scale STFT loss.
         
     | 
| 218 | 
         
            +
                    """
         
     | 
| 219 | 
         
            +
                    loss = 0.0
         
     | 
| 220 | 
         
            +
                    for s in self.stft_params:
         
     | 
| 221 | 
         
            +
                        x.stft(s.window_length, s.hop_length, s.window_type)
         
     | 
| 222 | 
         
            +
                        y.stft(s.window_length, s.hop_length, s.window_type)
         
     | 
| 223 | 
         
            +
                        loss += self.log_weight * self.loss_fn(
         
     | 
| 224 | 
         
            +
                            x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
         
     | 
| 225 | 
         
            +
                            y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
         
     | 
| 226 | 
         
            +
                        )
         
     | 
| 227 | 
         
            +
                        loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
         
     | 
| 228 | 
         
            +
                    return loss
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
            class MelSpectrogramLoss(nn.Module):
         
     | 
| 232 | 
         
            +
                """Compute distance between mel spectrograms. Can be used
         
     | 
| 233 | 
         
            +
                in a multi-scale way.
         
     | 
| 234 | 
         
            +
             
     | 
| 235 | 
         
            +
                Parameters
         
     | 
| 236 | 
         
            +
                ----------
         
     | 
| 237 | 
         
            +
                n_mels : List[int]
         
     | 
| 238 | 
         
            +
                    Number of mels per STFT, by default [150, 80],
         
     | 
| 239 | 
         
            +
                window_lengths : List[int], optional
         
     | 
| 240 | 
         
            +
                    Length of each window of each STFT, by default [2048, 512]
         
     | 
| 241 | 
         
            +
                loss_fn : typing.Callable, optional
         
     | 
| 242 | 
         
            +
                    How to compare each loss, by default nn.L1Loss()
         
     | 
| 243 | 
         
            +
                clamp_eps : float, optional
         
     | 
| 244 | 
         
            +
                    Clamp on the log magnitude, below, by default 1e-5
         
     | 
| 245 | 
         
            +
                mag_weight : float, optional
         
     | 
| 246 | 
         
            +
                    Weight of raw magnitude portion of loss, by default 1.0
         
     | 
| 247 | 
         
            +
                log_weight : float, optional
         
     | 
| 248 | 
         
            +
                    Weight of log magnitude portion of loss, by default 1.0
         
     | 
| 249 | 
         
            +
                pow : float, optional
         
     | 
| 250 | 
         
            +
                    Power to raise magnitude to before taking log, by default 2.0
         
     | 
| 251 | 
         
            +
                weight : float, optional
         
     | 
| 252 | 
         
            +
                    Weight of this loss, by default 1.0
         
     | 
| 253 | 
         
            +
                match_stride : bool, optional
         
     | 
| 254 | 
         
            +
                    Whether to match the stride of convolutional layers, by default False
         
     | 
| 255 | 
         
            +
             
     | 
| 256 | 
         
            +
                Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
         
     | 
| 257 | 
         
            +
                """
         
     | 
| 258 | 
         
            +
             
     | 
| 259 | 
         
            +
                def __init__(
         
     | 
| 260 | 
         
            +
                    self,
         
     | 
| 261 | 
         
            +
                    n_mels: List[int] = [150, 80],
         
     | 
| 262 | 
         
            +
                    window_lengths: List[int] = [2048, 512],
         
     | 
| 263 | 
         
            +
                    loss_fn: typing.Callable = nn.L1Loss(),
         
     | 
| 264 | 
         
            +
                    clamp_eps: float = 1e-5,
         
     | 
| 265 | 
         
            +
                    mag_weight: float = 1.0,
         
     | 
| 266 | 
         
            +
                    log_weight: float = 1.0,
         
     | 
| 267 | 
         
            +
                    pow: float = 2.0,
         
     | 
| 268 | 
         
            +
                    weight: float = 1.0,
         
     | 
| 269 | 
         
            +
                    match_stride: bool = False,
         
     | 
| 270 | 
         
            +
                    mel_fmin: List[float] = [0.0, 0.0],
         
     | 
| 271 | 
         
            +
                    mel_fmax: List[float] = [None, None],
         
     | 
| 272 | 
         
            +
                    window_type: str = None,
         
     | 
| 273 | 
         
            +
                ):
         
     | 
| 274 | 
         
            +
                    super().__init__()
         
     | 
| 275 | 
         
            +
                    self.stft_params = [
         
     | 
| 276 | 
         
            +
                        STFTParams(
         
     | 
| 277 | 
         
            +
                            window_length=w,
         
     | 
| 278 | 
         
            +
                            hop_length=w // 4,
         
     | 
| 279 | 
         
            +
                            match_stride=match_stride,
         
     | 
| 280 | 
         
            +
                            window_type=window_type,
         
     | 
| 281 | 
         
            +
                        )
         
     | 
| 282 | 
         
            +
                        for w in window_lengths
         
     | 
| 283 | 
         
            +
                    ]
         
     | 
| 284 | 
         
            +
                    self.n_mels = n_mels
         
     | 
| 285 | 
         
            +
                    self.loss_fn = loss_fn
         
     | 
| 286 | 
         
            +
                    self.clamp_eps = clamp_eps
         
     | 
| 287 | 
         
            +
                    self.log_weight = log_weight
         
     | 
| 288 | 
         
            +
                    self.mag_weight = mag_weight
         
     | 
| 289 | 
         
            +
                    self.weight = weight
         
     | 
| 290 | 
         
            +
                    self.mel_fmin = mel_fmin
         
     | 
| 291 | 
         
            +
                    self.mel_fmax = mel_fmax
         
     | 
| 292 | 
         
            +
                    self.pow = pow
         
     | 
| 293 | 
         
            +
             
     | 
| 294 | 
         
            +
                def forward(self, x: AudioSignal, y: AudioSignal):
         
     | 
| 295 | 
         
            +
                    """Computes mel loss between an estimate and a reference
         
     | 
| 296 | 
         
            +
                    signal.
         
     | 
| 297 | 
         
            +
             
     | 
| 298 | 
         
            +
                    Parameters
         
     | 
| 299 | 
         
            +
                    ----------
         
     | 
| 300 | 
         
            +
                    x : AudioSignal
         
     | 
| 301 | 
         
            +
                        Estimate signal
         
     | 
| 302 | 
         
            +
                    y : AudioSignal
         
     | 
| 303 | 
         
            +
                        Reference signal
         
     | 
| 304 | 
         
            +
             
     | 
| 305 | 
         
            +
                    Returns
         
     | 
| 306 | 
         
            +
                    -------
         
     | 
| 307 | 
         
            +
                    torch.Tensor
         
     | 
| 308 | 
         
            +
                        Mel loss.
         
     | 
| 309 | 
         
            +
                    """
         
     | 
| 310 | 
         
            +
                    loss = 0.0
         
     | 
| 311 | 
         
            +
                    for n_mels, fmin, fmax, s in zip(
         
     | 
| 312 | 
         
            +
                        self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
         
     | 
| 313 | 
         
            +
                    ):
         
     | 
| 314 | 
         
            +
                        kwargs = {
         
     | 
| 315 | 
         
            +
                            "window_length": s.window_length,
         
     | 
| 316 | 
         
            +
                            "hop_length": s.hop_length,
         
     | 
| 317 | 
         
            +
                            "window_type": s.window_type,
         
     | 
| 318 | 
         
            +
                        }
         
     | 
| 319 | 
         
            +
                        x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
         
     | 
| 320 | 
         
            +
                        y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
         
     | 
| 321 | 
         
            +
             
     | 
| 322 | 
         
            +
                        loss += self.log_weight * self.loss_fn(
         
     | 
| 323 | 
         
            +
                            x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
         
     | 
| 324 | 
         
            +
                            y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
         
     | 
| 325 | 
         
            +
                        )
         
     | 
| 326 | 
         
            +
                        loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
         
     | 
| 327 | 
         
            +
                    return loss
         
     | 
| 328 | 
         
            +
             
     | 
| 329 | 
         
            +
             
     | 
| 330 | 
         
            +
            class GANLoss(nn.Module):
         
     | 
| 331 | 
         
            +
                """
         
     | 
| 332 | 
         
            +
                Computes a discriminator loss, given a discriminator on
         
     | 
| 333 | 
         
            +
                generated waveforms/spectrograms compared to ground truth
         
     | 
| 334 | 
         
            +
                waveforms/spectrograms. Computes the loss for both the
         
     | 
| 335 | 
         
            +
                discriminator and the generator in separate functions.
         
     | 
| 336 | 
         
            +
                """
         
     | 
| 337 | 
         
            +
             
     | 
| 338 | 
         
            +
                def __init__(self, discriminator):
         
     | 
| 339 | 
         
            +
                    super().__init__()
         
     | 
| 340 | 
         
            +
                    self.discriminator = discriminator
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
                def forward(self, fake, real):
         
     | 
| 343 | 
         
            +
                    d_fake = self.discriminator(fake.audio_data)
         
     | 
| 344 | 
         
            +
                    d_real = self.discriminator(real.audio_data)
         
     | 
| 345 | 
         
            +
                    return d_fake, d_real
         
     | 
| 346 | 
         
            +
             
     | 
| 347 | 
         
            +
                def discriminator_loss(self, fake, real):
         
     | 
| 348 | 
         
            +
                    d_fake, d_real = self.forward(fake.clone().detach(), real)
         
     | 
| 349 | 
         
            +
             
     | 
| 350 | 
         
            +
                    loss_d = 0
         
     | 
| 351 | 
         
            +
                    for x_fake, x_real in zip(d_fake, d_real):
         
     | 
| 352 | 
         
            +
                        loss_d += torch.mean(x_fake[-1] ** 2)
         
     | 
| 353 | 
         
            +
                        loss_d += torch.mean((1 - x_real[-1]) ** 2)
         
     | 
| 354 | 
         
            +
                    return loss_d
         
     | 
| 355 | 
         
            +
             
     | 
| 356 | 
         
            +
                def generator_loss(self, fake, real):
         
     | 
| 357 | 
         
            +
                    d_fake, d_real = self.forward(fake, real)
         
     | 
| 358 | 
         
            +
             
     | 
| 359 | 
         
            +
                    loss_g = 0
         
     | 
| 360 | 
         
            +
                    for x_fake in d_fake:
         
     | 
| 361 | 
         
            +
                        loss_g += torch.mean((1 - x_fake[-1]) ** 2)
         
     | 
| 362 | 
         
            +
             
     | 
| 363 | 
         
            +
                    loss_feature = 0
         
     | 
| 364 | 
         
            +
             
     | 
| 365 | 
         
            +
                    for i in range(len(d_fake)):
         
     | 
| 366 | 
         
            +
                        for j in range(len(d_fake[i]) - 1):
         
     | 
| 367 | 
         
            +
                            loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
         
     | 
| 368 | 
         
            +
                    return loss_g, loss_feature
         
     | 
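As a usage sketch, the two losses above are typically combined in a DAC-style adversarial training step, roughly as follows. The import paths, the `Discriminator` class name, and the 2.0/15.0 loss weights are illustrative assumptions for this sketch, not settings taken from this repository.

# Hedged sketch: wiring MelSpectrogramLoss and GANLoss in a codec training step.
# Imports, the Discriminator class, and the loss weights below are assumptions.
import torch
from audiotools import AudioSignal

from hunyuanvideo_foley.models.dac_vae.model.discriminator import Discriminator
from hunyuanvideo_foley.models.dac_vae.nn.loss import GANLoss, MelSpectrogramLoss

mel_loss = MelSpectrogramLoss()
gan_loss = GANLoss(Discriminator())

recons = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)  # stand-in for decoder output
signal = AudioSignal(torch.randn(1, 1, 44100), sample_rate=44100)  # stand-in for reference audio

# Discriminator update: the generated branch is detached inside discriminator_loss().
loss_d = gan_loss.discriminator_loss(recons, signal)

# Generator update: adversarial + feature-matching + multi-scale mel reconstruction terms.
loss_g, loss_feat = gan_loss.generator_loss(recons, signal)
loss_gen = loss_g + 2.0 * loss_feat + 15.0 * mel_loss(recons, signal)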
    	
hunyuanvideo_foley/models/dac_vae/nn/quantize.py
ADDED
@@ -0,0 +1,262 @@
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from .layers import WNConv1d


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices


class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [
                VectorQuantize(input_dim, codebook_size, codebook_dim[i])
                for i in range(n_codebooks)
            ]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
                when in training mode, and a random number of quantizers is used.
        Returns
        -------
        tuple
            A tuple with the following entries, in order:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation
        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input
        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
            0
        ]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)


if __name__ == "__main__":
    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)
    # forward returns a tuple (z_q, codes, latents, commitment_loss, codebook_loss),
    # so unpack positionally rather than indexing by key.
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
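Beyond the `__main__` smoke test above, a short, hedged usage sketch of the residual quantizer follows: the shapes track the docstrings (B x D x T in, B x N x T codes out), the tuple unpacking matches the `forward` return value, and the batch/time sizes are arbitrary.

# Illustrative round-trip through ResidualVectorQuantize (not part of the repository).
import torch

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024, codebook_dim=8)
rvq.eval()                                           # evaluation mode: all codebooks, no dropout

x = torch.randn(2, 512, 80)                          # (B x D x T) continuous latents
z_q, codes, latents, commit_loss, cb_loss = rvq(x)   # quantize with every codebook
print(codes.shape)                                   # torch.Size([2, 9, 80]) discrete indices

# Rebuild the continuous representation from the stored codes alone,
# e.g. after generating or transmitting the discrete tokens.
z_q_from_codes, z_p, _ = rvq.from_codes(codes)
print(z_q_from_codes.shape)                          # torch.Size([2, 512, 80])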
    	
hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py
ADDED
@@ -0,0 +1,91 @@
import torch
import numpy as np


class AbstractDistribution:
    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):
    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.mean(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2],
                )
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var
                    - 1.0
                    - self.logvar
                    + other.logvar,
                    dim=[1, 2],
                )

    def nll(self, sample, dims=[1, 2]):
        if self.deterministic:
            return torch.Tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]

    return 0.5 * (
        -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
    )
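As a brief sketch of how `DiagonalGaussianDistribution` serves as a VAE bottleneck: the encoder output is split along dim=1 into mean and log-variance, a latent is drawn with the reparameterization trick, and `kl()` supplies the regularization term. The tensor shapes and the KL weight below are illustrative assumptions, not values from this repository.

# Hedged sketch of a VAE bottleneck built on DiagonalGaussianDistribution.
import torch

latent_dim = 64
enc_out = torch.randn(4, 2 * latent_dim, 250)   # (B, 2*D, T): mean and logvar stacked on dim=1

posterior = DiagonalGaussianDistribution(enc_out)
z = posterior.sample()                          # (B, D, T) latent via reparameterization
kl = posterior.kl().mean()                      # mean KL to a standard normal prior

recon_loss = torch.tensor(0.0)                  # placeholder for a decoder reconstruction loss
loss = recon_loss + 1e-2 * kl                   # the 1e-2 KL weight is an assumed value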
    	
hunyuanvideo_foley/models/dac_vae/utils/__init__.py
ADDED
@@ -0,0 +1,121 @@
| 1 | 
         
            +
            from pathlib import Path
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
            import argbind
         
     | 
| 4 | 
         
            +
            from audiotools import ml
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
            from ..model import DAC
         
     | 
| 7 | 
         
            +
            Accelerator = ml.Accelerator
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            __MODEL_LATEST_TAGS__ = {
         
     | 
| 10 | 
         
            +
                ("44khz", "8kbps"): "0.0.1",
         
     | 
| 11 | 
         
            +
                ("24khz", "8kbps"): "0.0.4",
         
     | 
| 12 | 
         
            +
                ("16khz", "8kbps"): "0.0.5",
         
     | 
| 13 | 
         
            +
                ("44khz", "16kbps"): "1.0.0",
         
     | 
| 14 | 
         
            +
            }
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            __MODEL_URLS__ = {
         
     | 
| 17 | 
         
            +
                (
         
     | 
| 18 | 
         
            +
                    "44khz",
         
     | 
| 19 | 
         
            +
                    "0.0.1",
         
     | 
| 20 | 
         
            +
                    "8kbps",
         
     | 
| 21 | 
         
            +
                ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
         
     | 
| 22 | 
         
            +
                (
         
     | 
| 23 | 
         
            +
                    "24khz",
         
     | 
| 24 | 
         
            +
                    "0.0.4",
         
     | 
| 25 | 
         
            +
                    "8kbps",
         
     | 
| 26 | 
         
            +
                ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
         
     | 
| 27 | 
         
            +
                (
         
     | 
| 28 | 
         
            +
                    "16khz",
         
     | 
| 29 | 
         
            +
                    "0.0.5",
         
     | 
| 30 | 
         
            +
                    "8kbps",
         
     | 
| 31 | 
         
            +
                ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
         
     | 
| 32 | 
         
            +
                (
         
     | 
| 33 | 
         
            +
                    "44khz",
         
     | 
| 34 | 
         
            +
                    "1.0.0",
         
     | 
| 35 | 
         
            +
                    "16kbps",
         
     | 
| 36 | 
         
            +
                ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
         
     | 
| 37 | 
         
            +
            }
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
            @argbind.bind(group="download", positional=True, without_prefix=True)
         
     | 
| 41 | 
         
            +
            def download(
         
     | 
| 42 | 
         
            +
                model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
         
     | 
| 43 | 
         
            +
            ):
         
     | 
| 44 | 
         
            +
                """
         
     | 
| 45 | 
         
            +
                Function that downloads the weights file from URL if a local cache is not found.
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                Parameters
         
     | 
| 48 | 
         
            +
                ----------
         
     | 
| 49 | 
         
            +
                model_type : str
         
     | 
| 50 | 
         
            +
                    The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
         
     | 
| 51 | 
         
            +
                model_bitrate: str
         
     | 
| 52 | 
         
            +
                    Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
         
     | 
| 53 | 
         
            +
                    Only 44khz model supports 16kbps.
         
     | 
| 54 | 
         
            +
                tag : str
         
     | 
| 55 | 
         
            +
                    The tag of the model to download. Defaults to "latest".
         
     | 
| 56 | 
         
            +
             
     | 
| 57 | 
         
            +
                Returns
         
     | 
| 58 | 
         
            +
                -------
         
     | 
| 59 | 
         
            +
                Path
         
     | 
| 60 | 
         
            +
                    Directory path required to load model via audiotools.
         
     | 
| 61 | 
         
            +
                """
         
     | 
| 62 | 
         
            +
                model_type = model_type.lower()
         
     | 
| 63 | 
         
            +
                tag = tag.lower()
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                assert model_type in [
         
     | 
| 66 | 
         
            +
                    "44khz",
         
     | 
| 67 | 
         
            +
                    "24khz",
         
     | 
| 68 | 
         
            +
                    "16khz",
         
     | 
| 69 | 
         
            +
                ], "model_type must be one of '44khz', '24khz', or '16khz'"
         
     | 
| 70 | 
         
            +
             
     | 
| 71 | 
         
            +
                assert model_bitrate in [
         
     | 
| 72 | 
         
            +
                    "8kbps",
         
     | 
| 73 | 
         
            +
                    "16kbps",
         
     | 
| 74 | 
         
            +
                ], "model_bitrate must be one of '8kbps', or '16kbps'"
         
     | 
| 75 | 
         
            +
             
     | 
| 76 | 
         
            +
                if tag == "latest":
         
     | 
| 77 | 
         
            +
                    tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
         
     | 
| 78 | 
         
            +
             
     | 
| 79 | 
         
            +
                download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
         
     | 
| 80 | 
         
            +
             
     | 
| 81 | 
         
            +
                if download_link is None:
         
     | 
| 82 | 
         
            +
                    raise ValueError(
         
     | 
| 83 | 
         
            +
                        f"Could not find model with tag {tag} and model type {model_type}"
         
     | 
| 84 | 
         
            +
                    )
         
     | 
| 85 | 
         
            +
             
     | 
| 86 | 
         
            +
                local_path = (
         
     | 
| 87 | 
         
            +
                    Path.home()
         
     | 
| 88 | 
         
            +
                    / ".cache"
         
     | 
| 89 | 
         
            +
                    / "descript"
         
     | 
| 90 | 
         
            +
                    / "dac"
         
     | 
| 91 | 
         
            +
                    / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
         
     | 
| 92 | 
         
            +
                )
         
     | 
| 93 | 
         
            +
                if not local_path.exists():
         
     | 
| 94 | 
         
            +
                    local_path.parent.mkdir(parents=True, exist_ok=True)
         
     | 
| 95 | 
         
            +
             
     | 
| 96 | 
         
            +
                    # Download the model
         
     | 
| 97 | 
         
            +
                    import requests
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
                    response = requests.get(download_link)
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
                    if response.status_code != 200:
         
     | 
| 102 | 
         
            +
                        raise ValueError(
         
     | 
| 103 | 
         
            +
                            f"Could not download model. Received response code {response.status_code}"
         
     | 
| 104 | 
         
            +
                        )
         
     | 
| 105 | 
         
            +
                    local_path.write_bytes(response.content)
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                return local_path
         
     | 
| 108 | 
         
            +
             
     | 
| 109 | 
         
            +
             
     | 
| 110 | 
         
            +
            def load_model(
         
     | 
| 111 | 
         
            +
                model_type: str = "44khz",
         
     | 
| 112 | 
         
            +
                model_bitrate: str = "8kbps",
         
     | 
| 113 | 
         
            +
                tag: str = "latest",
         
     | 
| 114 | 
         
            +
                load_path: str = None,
         
     | 
| 115 | 
         
            +
            ):
         
     | 
| 116 | 
         
            +
                if not load_path:
         
     | 
| 117 | 
         
            +
                    load_path = download(
         
     | 
| 118 | 
         
            +
                        model_type=model_type, model_bitrate=model_bitrate, tag=tag
         
     | 
| 119 | 
         
            +
                    )
         
     | 
| 120 | 
         
            +
                generator = DAC.load(load_path)
         
     | 
| 121 | 
         
            +
                return generator
         
     | 
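For reference, a minimal usage sketch of the loader above (assuming the repository root is on PYTHONPATH and the weight URL is reachable; on first use the checkpoint is cached under ~/.cache/descript/dac):

    import torch
    from hunyuanvideo_foley.models.dac_vae.utils import load_model

    # Downloads (or reuses the cached) 44 kHz / 8 kbps weights and builds the DAC generator
    generator = load_model(model_type="44khz", model_bitrate="8kbps", tag="latest")
    generator.eval()
    generator.to("cuda" if torch.cuda.is_available() else "cpu")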
    	
hunyuanvideo_foley/models/dac_vae/utils/decode.py
ADDED
@@ -0,0 +1,95 @@
import warnings
from pathlib import Path

import argbind
import numpy as np
import torch
from audiotools import AudioSignal
from tqdm import tqdm

from ..model import DACFile
from . import load_model

warnings.filterwarnings("ignore", category=UserWarning)


@argbind.bind(group="decode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def decode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    device: str = "cuda",
    model_type: str = "44khz",
    verbose: bool = False,
):
    """Decode audio from codes.

    Parameters
    ----------
    input : str
        Path to input directory or file
    output : str, optional
        Path to output directory, by default "".
        If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate : str
        Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
    device : str, optional
        Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    verbose : bool, optional
        Forwarded to `generator.decompress`, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()

    # Find all .dac files in input directory
    _input = Path(input)
    input_files = list(_input.glob("**/*.dac"))

    # If input is a .dac file, add it to the list
    if _input.suffix == ".dac":
        input_files.append(_input)

    # Create output directory
    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(len(input_files)), desc="Decoding files"):
        # Load file
        artifact = DACFile.load(input_files[i])

        # Reconstruct audio from codes
        recons = generator.decompress(artifact, verbose=verbose)

        # Compute output path
        relative_path = input_files[i].relative_to(input)
        output_dir = output / relative_path.parent
        if not relative_path.name:
            output_dir = output
            relative_path = input_files[i]
        output_name = relative_path.with_suffix(".wav").name
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to file
        recons.write(output_path)


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        decode()
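A short, hedged example of driving this decoder programmatically (the directory names are illustrative; the function can also be invoked through argbind from the command line, as in the `__main__` block above):

    from hunyuanvideo_foley.models.dac_vae.utils.decode import decode

    # Walks a directory of .dac artifacts and writes .wav reconstructions into the same sub-tree
    decode(
        input="artifacts",          # hypothetical directory of .dac files
        output="reconstructions",   # created if missing
        model_type="44khz",
        model_bitrate="8kbps",
        device="cuda",
    )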
    	
hunyuanvideo_foley/models/dac_vae/utils/encode.py
ADDED
@@ -0,0 +1,94 @@
import math
import warnings
from pathlib import Path

import argbind
import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.core import util
from tqdm import tqdm

from . import load_model

warnings.filterwarnings("ignore", category=UserWarning)


@argbind.bind(group="encode", positional=True, without_prefix=True)
@torch.inference_mode()
@torch.no_grad()
def encode(
    input: str,
    output: str = "",
    weights_path: str = "",
    model_tag: str = "latest",
    model_bitrate: str = "8kbps",
    n_quantizers: int = None,
    device: str = "cuda",
    model_type: str = "44khz",
    win_duration: float = 5.0,
    verbose: bool = False,
):
    """Encode audio files in input path to .dac format.

    Parameters
    ----------
    input : str
        Path to input audio file or directory
    output : str, optional
        Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
    weights_path : str, optional
        Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
        model_tag and model_type.
    model_tag : str, optional
        Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
    model_bitrate : str
        Bitrate of the model. Must be one of "8kbps" or "16kbps". Defaults to "8kbps".
    n_quantizers : int, optional
        Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
    device : str, optional
        Device to use, by default "cuda"
    model_type : str, optional
        The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
    win_duration : float, optional
        Window duration in seconds passed to `generator.compress`, by default 5.0.
    verbose : bool, optional
        Forwarded to `generator.compress`, by default False.
    """
    generator = load_model(
        model_type=model_type,
        model_bitrate=model_bitrate,
        tag=model_tag,
        load_path=weights_path,
    )
    generator.to(device)
    generator.eval()
    kwargs = {"n_quantizers": n_quantizers}

    # Find all audio files in input path
    input = Path(input)
    audio_files = util.find_audio(input)

    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)

    for i in tqdm(range(len(audio_files)), desc="Encoding files"):
        # Load file
        signal = AudioSignal(audio_files[i])

        # Encode audio to .dac format
        artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)

        # Compute output path
        relative_path = audio_files[i].relative_to(input)
        output_dir = output / relative_path.parent
        if not relative_path.name:
            output_dir = output
            relative_path = audio_files[i]
        output_name = relative_path.with_suffix(".dac").name
        output_path = output_dir / output_name
        output_path.parent.mkdir(parents=True, exist_ok=True)

        artifact.save(output_path)


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        encode()
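Putting the two utilities together, a hedged round-trip sketch (directory names are illustrative): encode a folder of audio into .dac artifacts, then reconstruct them back to .wav files:

    from hunyuanvideo_foley.models.dac_vae.utils.encode import encode
    from hunyuanvideo_foley.models.dac_vae.utils.decode import decode

    encode(input="raw_audio", output="artifacts", model_type="44khz", win_duration=5.0, device="cuda")
    decode(input="artifacts", output="round_trip", model_type="44khz", device="cuda")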
    	
hunyuanvideo_foley/models/hifi_foley.py
ADDED
@@ -0,0 +1,794 @@
from typing import List, Tuple, Optional, Union, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config

from .nn.activation_layers import SwiGLU, get_activation_layer
from .nn.attn_layers import apply_rotary_emb, attention
from .nn.embed_layers import TimestepEmbedder, ConditionProjection, PatchEmbed1D
from .nn.mlp_layers import MLP, ConvMLP, FinalLayer1D, ChannelLastConv1d
from .nn.modulate_layers import ModulateDiT, ckpt_wrapper, apply_gate, modulate
from .nn.norm_layers import get_norm_layer
from .nn.posemb_layers import get_nd_rotary_pos_embed

def interleave_two_sequences(x1: torch.Tensor, x2: torch.Tensor):
    # [B, N1, H, C] & [B, N2, H, C]
    B, N1, H, C = x1.shape
    B, N2, H, C = x2.shape
    assert x1.ndim == x2.ndim == 4

    if N1 != N2:
        x2 = x2.view(B, N2, -1).transpose(1, 2)
        x2 = F.interpolate(x2, size=(N1), mode="nearest-exact")
        x2 = x2.transpose(1, 2).view(B, N1, H, C)
    x = torch.stack((x1, x2), dim=2)
    x = x.reshape(B, N1 * 2, H, C)
    return x


def decouple_interleaved_two_sequences(x: torch.Tensor, len1: int, len2: int):
    B, N, H, C = x.shape
    assert N % 2 == 0 and N // 2 == len1

    x = x.reshape(B, -1, 2, H, C)
    x1 = x[:, :, 0]
    x2 = x[:, :, 1]
    if x2.shape[1] != len2:
        x2 = x2.view(B, len1, H * C).transpose(1, 2)
        x2 = F.interpolate(x2, size=(len2), mode="nearest-exact")
        x2 = x2.transpose(1, 2).view(B, len2, H, C)
    return x1, x2

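To make the shape contract of the two helpers above concrete, a small illustrative sketch (hypothetical tensor sizes; assumes the repository root is on the import path): the second sequence is resampled along its token axis to match the first, zipped in, and later recovered at its original length.

    import torch
    from hunyuanvideo_foley.models.hifi_foley import (
        interleave_two_sequences,
        decouple_interleaved_two_sequences,
    )

    audio = torch.randn(2, 128, 8, 64)   # [B, N1, H, C] audio tokens
    visual = torch.randn(2, 32, 8, 64)   # [B, N2, H, C] visual tokens (shorter)

    mixed = interleave_two_sequences(audio, visual)              # [2, 256, 8, 64]
    a, v = decouple_interleaved_two_sequences(mixed, 128, 32)    # [2, 128, 8, 64], [2, 32, 8, 64]
    print(mixed.shape, a.shape, v.shape)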
| 46 | 
         
            +
            class TwoStreamCABlock(nn.Module):
         
     | 
| 47 | 
         
            +
                def __init__(
         
     | 
| 48 | 
         
            +
                    self,
         
     | 
| 49 | 
         
            +
                    hidden_size: int,
         
     | 
| 50 | 
         
            +
                    num_heads: int,
         
     | 
| 51 | 
         
            +
                    mlp_ratio: float,
         
     | 
| 52 | 
         
            +
                    mlp_act_type: str = "gelu_tanh",
         
     | 
| 53 | 
         
            +
                    qk_norm: bool = True,
         
     | 
| 54 | 
         
            +
                    qk_norm_type: str = "rms",
         
     | 
| 55 | 
         
            +
                    qkv_bias: bool = False,
         
     | 
| 56 | 
         
            +
                    attn_mode: str = "torch",
         
     | 
| 57 | 
         
            +
                    reverse: bool = False,
         
     | 
| 58 | 
         
            +
                    interleaved_audio_visual_rope: bool = False,
         
     | 
| 59 | 
         
            +
                    dtype: Optional[torch.dtype] = None,
         
     | 
| 60 | 
         
            +
                    device: Optional[torch.device] = None,
         
     | 
| 61 | 
         
            +
                ):
         
     | 
| 62 | 
         
            +
                    factory_kwargs = {"device": device, "dtype": dtype}
         
     | 
| 63 | 
         
            +
                    super().__init__()
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                    self.deterministic = False
         
     | 
| 66 | 
         
            +
                    self.reverse = reverse
         
     | 
| 67 | 
         
            +
                    self.attn_mode = attn_mode
         
     | 
| 68 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 69 | 
         
            +
                    self.hidden_size = hidden_size
         
     | 
| 70 | 
         
            +
                    head_dim = hidden_size // num_heads
         
     | 
| 71 | 
         
            +
                    mlp_hidden_dim = int(hidden_size * mlp_ratio)
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
                    self.interleaved_audio_visual_rope = interleaved_audio_visual_rope
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
                    # Self attention for audio + visual
         
     | 
| 76 | 
         
            +
                    self.audio_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
         
     | 
| 77 | 
         
            +
                    self.audio_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         
     | 
| 78 | 
         
            +
                    self.audio_self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
         
     | 
| 79 | 
         
            +
                    qk_norm_layer = get_norm_layer(qk_norm_type)
         
     | 
| 80 | 
         
            +
                    self.audio_self_q_norm = (
         
     | 
| 81 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 82 | 
         
            +
                    )
         
     | 
| 83 | 
         
            +
                    self.audio_self_k_norm = (
         
     | 
| 84 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 85 | 
         
            +
                    )
         
     | 
| 86 | 
         
            +
                    self.audio_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         
     | 
| 87 | 
         
            +
             
     | 
| 88 | 
         
            +
                    # visual cond
         
     | 
| 89 | 
         
            +
                    self.v_cond_mod = ModulateDiT(hidden_size, factor=9, act_layer=get_activation_layer("silu"), **factory_kwargs)
         
     | 
| 90 | 
         
            +
                    self.v_cond_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         
     | 
| 91 | 
         
            +
                    self.v_cond_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
         
     | 
| 92 | 
         
            +
                    self.v_cond_attn_q_norm = (
         
     | 
| 93 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 94 | 
         
            +
                    )
         
     | 
| 95 | 
         
            +
                    self.v_cond_attn_k_norm = (
         
     | 
| 96 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 97 | 
         
            +
                    )
         
     | 
| 98 | 
         
            +
                    self.v_cond_self_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
            +
                    self.max_text_len = 100
         
     | 
| 101 | 
         
            +
                    self.rope_dim_list = None
         
     | 
| 102 | 
         
            +
                    
         
     | 
| 103 | 
         
            +
                    # audio and video norm for cross attention with text
         
     | 
| 104 | 
         
            +
                    self.audio_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         
     | 
| 105 | 
         
            +
                    self.v_cond_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                    # Cross attention: (video_audio) as query, text as key/value
         
     | 
| 108 | 
         
            +
                    self.audio_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         
     | 
| 109 | 
         
            +
                    self.v_cond_cross_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         
     | 
| 110 | 
         
            +
                    self.text_cross_kv = nn.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
         
     | 
| 111 | 
         
            +
                    
         
     | 
| 112 | 
         
            +
                    self.audio_cross_q_norm = (
         
     | 
| 113 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 114 | 
         
            +
                    )
         
     | 
| 115 | 
         
            +
                    self.v_cond_cross_q_norm = (
         
     | 
| 116 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 117 | 
         
            +
                    )
         
     | 
| 118 | 
         
            +
                    self.text_cross_k_norm = (
         
     | 
| 119 | 
         
            +
                        qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 120 | 
         
            +
                    )
         
     | 
| 121 | 
         
            +
                    self.audio_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         
     | 
| 122 | 
         
            +
                    self.v_cond_cross_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         
     | 
| 123 | 
         
            +
             
     | 
| 124 | 
         
            +
                    # MLPs
         
     | 
| 125 | 
         
            +
                    self.audio_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         
     | 
| 126 | 
         
            +
                    self.audio_mlp = MLP(
         
     | 
| 127 | 
         
            +
                        hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
         
     | 
| 128 | 
         
            +
                    )
         
     | 
| 129 | 
         
            +
             
     | 
| 130 | 
         
            +
                    self.v_cond_norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         
     | 
| 131 | 
         
            +
                    self.v_cond_mlp = MLP(
         
     | 
| 132 | 
         
            +
                        hidden_size, mlp_hidden_dim, act_layer=get_activation_layer(mlp_act_type), bias=True, **factory_kwargs
         
     | 
| 133 | 
         
            +
                    )
         
     | 
| 134 | 
         
            +
             
     | 
| 135 | 
         
            +
                def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
         
     | 
| 136 | 
         
            +
                    target_ndim = 1  # n-d RoPE
         
     | 
| 137 | 
         
            +
                    rope_sizes = [text_len]
         
     | 
| 138 | 
         
            +
                    
         
     | 
| 139 | 
         
            +
                    if rope_dim_list is None:
         
     | 
| 140 | 
         
            +
                        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         
     | 
| 141 | 
         
            +
                    assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
         
     | 
| 142 | 
         
            +
                    
         
     | 
| 143 | 
         
            +
                    text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
         
     | 
| 144 | 
         
            +
                        rope_dim_list=rope_dim_list,
         
     | 
| 145 | 
         
            +
                        start=rope_sizes,
         
     | 
| 146 | 
         
            +
                        theta=10000,
         
     | 
| 147 | 
         
            +
                        use_real=True,
         
     | 
| 148 | 
         
            +
                        theta_rescale_factor=1.0,
         
     | 
| 149 | 
         
            +
                    )
         
     | 
| 150 | 
         
            +
                    return text_freqs_cos, text_freqs_sin
         
     | 
| 151 | 
         
            +
             
     | 
| 152 | 
         
            +
                def set_attn_mode(self, new_mode):
         
     | 
| 153 | 
         
            +
                    if new_mode != "torch":
         
     | 
| 154 | 
         
            +
                        raise NotImplementedError(f"Only support 'torch' mode, got {new_mode}.")
         
     | 
| 155 | 
         
            +
                    self.attn_mode = new_mode
         
     | 
| 156 | 
         
            +
             
     | 
| 157 | 
         
            +
                def enable_deterministic(self):
         
     | 
| 158 | 
         
            +
                    self.deterministic = True
         
     | 
| 159 | 
         
            +
             
     | 
| 160 | 
         
            +
                def disable_deterministic(self):
         
     | 
| 161 | 
         
            +
                    self.deterministic = False
         
    def forward(
        self,
        audio: torch.Tensor,
        cond: torch.Tensor,
        v_cond: torch.Tensor,
        attn_mask: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple = None,
        v_freqs_cis: tuple = None,
        sync_vec: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Get modulation parameters
        if sync_vec is not None:
            assert sync_vec.ndim == 3
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
            ) = self.audio_mod(sync_vec).chunk(9, dim=-1)
        else:
            (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
             audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
             audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
            ) = self.audio_mod(vec).chunk(9, dim=-1)

        (
            v_cond_mod1_shift,
            v_cond_mod1_scale,
            v_cond_mod1_gate,
            v_cond_mod2_shift,
            v_cond_mod2_scale,
            v_cond_mod2_gate,
            v_cond_mod3_shift,
            v_cond_mod3_scale,
            v_cond_mod3_gate,
        ) = self.v_cond_mod(vec).chunk(9, dim=-1)
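        # The 9-way chunk above yields (shift, scale, gate) triples for the three sub-layers of
        # this block: mod1 -> joint self-attention, mod2 -> cross-attention over text, mod3 -> MLP.
        # Hedged shape sketch (assumed): with vec of shape [B, D] each chunk is [B, D]; with a
        # per-token sync_vec of shape [B, L, D] each chunk is [B, L, D] and modulates per frame.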
         
        # 1. Self Attention for audio + visual
        audio_modulated = self.audio_norm1(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod1_shift, scale=audio_mod1_scale)
        audio_qkv = self.audio_self_attn_qkv(audio_modulated)
        audio_q, audio_k, audio_v = rearrange(audio_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        audio_q = self.audio_self_q_norm(audio_q).to(audio_v)
        audio_k = self.audio_self_k_norm(audio_k).to(audio_v)

        # Prepare visual cond for attention
        v_cond_modulated = self.v_cond_norm1(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod1_shift, scale=v_cond_mod1_scale)
        v_cond_qkv = self.v_cond_attn_qkv(v_cond_modulated)
        v_cond_q, v_cond_k, v_cond_v = rearrange(v_cond_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        v_cond_q = self.v_cond_attn_q_norm(v_cond_q).to(v_cond_v)
        v_cond_k = self.v_cond_attn_k_norm(v_cond_k).to(v_cond_v)

        # Apply RoPE if needed for audio and visual
        if freqs_cis is not None:
            if not self.interleaved_audio_visual_rope:
                audio_qq, audio_kk = apply_rotary_emb(audio_q, audio_k, freqs_cis, head_first=False)
                audio_q, audio_k = audio_qq, audio_kk
            else:
                ori_audio_len = audio_q.shape[1]
                ori_v_con_len = v_cond_q.shape[1]
                interleaved_audio_visual_q = interleave_two_sequences(audio_q, v_cond_q)
                interleaved_audio_visual_k = interleave_two_sequences(audio_k, v_cond_k)
                interleaved_audio_visual_qq, interleaved_audio_visual_kk = apply_rotary_emb(
                    interleaved_audio_visual_q, interleaved_audio_visual_k, freqs_cis, head_first=False
                )
                audio_qq, v_cond_qq = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_qq, ori_audio_len, ori_v_con_len
                )
                audio_kk, v_cond_kk = decouple_interleaved_two_sequences(
                    interleaved_audio_visual_kk, ori_audio_len, ori_v_con_len
                )
                audio_q, audio_k = audio_qq, audio_kk
                v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Apply RoPE to visual if needed and not interleaved
        if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
            v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
            v_cond_q, v_cond_k = v_cond_qq, v_cond_kk

        # Concatenate for self-attention
        q = torch.cat((v_cond_q, audio_q), dim=1)
        k = torch.cat((v_cond_k, audio_k), dim=1)
        v = torch.cat((v_cond_v, audio_v), dim=1)

        # Run self-attention
        attn = attention(q, k, v, mode=self.attn_mode, attn_mask=attn_mask, deterministic=self.deterministic)
        v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply self-attention output to audio and v_cond
        audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
        v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
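        # Hedged note: the joint self-attention sequence is ordered [v_cond; audio], so any
        # attn_mask supplied to this block is assumed to follow that same ordering along both of
        # its sequence axes; the attention output is split back into the two streams right after.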
         
        # 2. Cross Attention: (v_cond, audio) as query, text as key/value
        # audio, v_cond modulation
        audio_modulated = self.audio_norm2(audio)
        audio_modulated = modulate(audio_modulated, shift=audio_mod2_shift, scale=audio_mod2_scale)
        v_cond_modulated = self.v_cond_norm2(v_cond)
        v_cond_modulated = modulate(v_cond_modulated, shift=v_cond_mod2_shift, scale=v_cond_mod2_scale)

        # Prepare audio query
        audio_q = self.audio_cross_q(audio_modulated)
        audio_q = rearrange(audio_q, "B L (H D) -> B L H D", H=self.num_heads)
        audio_q = self.audio_cross_q_norm(audio_q)

        # Prepare v_cond query
        v_cond_q = self.v_cond_cross_q(v_cond_modulated)
        v_cond_q = rearrange(v_cond_q, "B L (H D) -> B L H D", H=self.num_heads)
        v_cond_q = self.v_cond_cross_q_norm(v_cond_q)

        # Prepare text key/value
        text_kv = self.text_cross_kv(cond)
        text_k, text_v = rearrange(text_kv, "B L (K H D) -> K B L H D", K=2, H=self.num_heads)
        text_k = self.text_cross_k_norm(text_k).to(text_v)

        # Apply RoPE to (v_cond, audio) query and text key if needed
        head_dim = self.hidden_size // self.num_heads
        audio_cross_freqs_cos, audio_cross_freqs_sin = self.build_rope_for_text(
            audio_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list
        )
        audio_cross_freqs_cis = (audio_cross_freqs_cos.to(audio_q.device), audio_cross_freqs_sin.to(audio_q.device))
        audio_q = apply_rotary_emb(audio_q, audio_q, audio_cross_freqs_cis, head_first=False)[0]

        v_cond_cross_freqs_cos, v_cond_cross_freqs_sin = self.build_rope_for_text(
            v_cond_q.shape[1], head_dim, rope_dim_list=self.rope_dim_list
        )
        v_cond_cross_freqs_cis = (v_cond_cross_freqs_cos.to(v_cond_q.device), v_cond_cross_freqs_sin.to(v_cond_q.device))
        v_cond_q = apply_rotary_emb(v_cond_q, v_cond_q, v_cond_cross_freqs_cis, head_first=False)[0]

        text_len = text_k.shape[1]

        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
                                                                  rope_dim_list=self.rope_dim_list)
        text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
        text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]

        # Concat v_cond and audio for cross-attention
        v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)

        # Run cross-attention
        cross_attn = attention(v_cond_audio_q, text_k, text_v, mode=self.attn_mode, deterministic=self.deterministic)
        v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)

        # Apply cross-attention output
        audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
        v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
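        # Hedged note: in this cross-attention, the audio queries, visual queries, and text keys
        # are each rotated with a RoPE table built from their own length and starting at position 0,
        # so every stream carries positions relative to its own start rather than one shared index.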
         
        # 3. Apply MLPs
        audio = audio + apply_gate(
            self.audio_mlp(modulate(self.audio_norm3(audio), shift=audio_mod3_shift, scale=audio_mod3_scale)),
            gate=audio_mod3_gate,
        )

        # Apply visual MLP
        v_cond = v_cond + apply_gate(
            self.v_cond_mlp(modulate(self.v_cond_norm3(v_cond), shift=v_cond_mod3_shift, scale=v_cond_mod3_scale)),
            gate=v_cond_mod3_gate,
        )

        return audio, cond, v_cond
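        # Hedged usage sketch (assumed shapes and kwargs, not part of the original file):
        #   block = TwoStreamCABlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0,
        #                            mlp_act_type="gelu_tanh", qk_norm=True, qk_norm_type="rms",
        #                            qkv_bias=True, attn_mode="torch")
        #   audio, cond, v_cond = block(audio, cond, v_cond, attn_mask=None, vec=vec,
        #                               freqs_cis=(cos, sin), v_freqs_cis=(v_cos, v_sin))
        #   # audio: [B, La, 3072], cond: [B, Lt, 3072], v_cond: [B, Lv, 3072];
        #   # cond is passed through unchanged and only feeds the text key/value projections.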
         
class SingleStreamBlock(nn.Module):

    def __init__(self, hidden_size: int,
                 num_heads: int,
                 mlp_ratio: float,
                 qk_norm_type: str = "rms",
                 dtype: Optional[torch.dtype] = None,
                 device: Optional[torch.device] = None,):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads

        self.modulation = ModulateDiT(
            hidden_size=hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.linear_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=True)
        self.linear1 = ChannelLastConv1d(hidden_size, hidden_size, kernel_size=3, padding=1, **factory_kwargs)
        self.linear2 = ConvMLP(hidden_size, hidden_size * mlp_ratio, kernel_size=3, padding=1, **factory_kwargs)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.q_norm = nn.RMSNorm(hidden_size // num_heads)
        self.k_norm = nn.RMSNorm(hidden_size // num_heads)
        self.rearrange = Rearrange("B L (H D K) -> B H L D K", K=3, H=num_heads)

    def forward(self, x: torch.Tensor, cond: torch.Tensor, freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None):
        assert cond.ndim == 3, "Condition should be in shape of [B, T, D]"
        modulation = self.modulation(cond)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
        x_norm1 = self.norm1(x) * (1 + scale_msa) + shift_msa

        qkv = self.linear_qkv(x_norm1)
        q, k, v = self.rearrange(qkv).chunk(3, dim=-1)
        q = q.squeeze(-1)
        k = k.squeeze(-1)
        v = v.squeeze(-1)

        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis, head_first=True)

        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = F.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, 'b h n d -> b n (h d)').contiguous()

        x = x + apply_gate(self.linear1(out), gate=gate_msa)
        x_norm = self.norm2(x) * (1 + scale_mlp) + shift_mlp
        x = x + apply_gate(self.linear2(x_norm), gate=gate_mlp)

        return x
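        # Hedged usage sketch (assumed shapes, not part of the original file):
        #   blk = SingleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0)
        #   y = blk(x, cond, freqs_cis=(cos, sin))
        #   # x: [B, L, 3072] audio tokens; cond: [B, L, 3072] per-token modulation signal;
        #   # y has the same shape as x.  Per-head RMSNorm and RoPE are applied in the
        #   # "B H L D" layout produced by self.rearrange, hence head_first=True above.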
         
class HunyuanVideoFoley(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        model_config,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        model_args = model_config.model_config.model_kwargs
        self.depth_triple_blocks = model_args.get("depth_triple_blocks", 19)
        self.depth_single_blocks = model_args.get("depth_single_blocks", 38)
        # Gradient checkpoint.
        self.gradient_checkpoint = False
        self.gradient_checkpoint_layers = None
        if self.gradient_checkpoint:
            assert self.gradient_checkpoint_layers <= self.depth_triple_blocks + self.depth_single_blocks, (
                f"Gradient checkpoint layers must be less than or equal to the depth of the model. "
                f"Got gradient_checkpoint_layers={self.gradient_checkpoint_layers} and depth={self.depth_triple_blocks + self.depth_single_blocks}."
            )

        self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", False)

        # Condition projection. Default to linear projection.
        self.condition_projection = model_args.get("condition_projection", "linear")
        self.condition_dim = model_args.get("condition_dim", None)
        self.use_attention_mask = model_args.get("use_attention_mask", False)

        self.patch_size = model_args.get("patch_size", 1)
        self.visual_in_channels = model_args.get("clip_dim", 768)
        self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
        self.out_channels = self.audio_vae_latent_dim
        self.unpatchify_channels = self.out_channels
        self.reverse = model_args.get("reverse", False)

        self.num_heads = model_args.get("num_heads", 24)
        self.hidden_size = model_args.get("hidden_size", 3072)
        self.rope_dim_list = model_args.get("rope_dim_list", None)
        self.mlp_ratio = model_args.get("mlp_ratio", 4.0)
        self.mlp_act_type = model_args.get("mlp_act_type", "gelu_tanh")

        self.qkv_bias = model_args.get("qkv_bias", True)
        self.qk_norm = model_args.get("qk_norm", True)
        self.qk_norm_type = model_args.get("qk_norm_type", "rms")
        self.attn_mode = model_args.get("attn_mode", "torch")

        self.embedder_type = model_args.get("embedder_type", "default")

        # sync condition things
        self.sync_modulation = model_args.get("sync_modulation", False)
        self.add_sync_feat_to_audio = model_args.get("add_sync_feat_to_audio", False)
        self.sync_feat_dim = model_args.get("sync_feat_dim", 768)
        self.sync_in_ksz = model_args.get("sync_in_ksz", 1)

        # condition tokens length
        self.clip_len = model_args.get("clip_length", 64)
        self.sync_len = model_args.get("sync_length", 192)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}")
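        # Hedged summary of the main model_kwargs read above (defaults as coded; actual values
        # come from the YAML config):
        #   depth_triple_blocks=19, depth_single_blocks=38, hidden_size=3072, num_heads=24,
        #   mlp_ratio=4.0, clip_dim=768, audio_vae_latent_dim=128, sync_feat_dim=768,
        #   clip_length=64, sync_length=192, attn_mode="torch".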
         
        # Build audio patchify layer and visual gated linear projection
        self.patch_size = 1
        self.audio_embedder = PatchEmbed1D(self.patch_size, self.audio_vae_latent_dim, self.hidden_size, **factory_kwargs)
        self.visual_proj = SwiGLU(self.visual_in_channels, hidden_dim=self.hidden_size, out_dim=self.hidden_size)

        # condition
        if self.condition_projection == "linear":
            self.cond_in = ConditionProjection(
                self.condition_dim, self.hidden_size, get_activation_layer("silu"), **factory_kwargs
            )
        else:
            raise NotImplementedError(f"Unsupported condition_projection: {self.condition_projection}")

        # time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # visual sync embedder if needed
        if self.sync_in_ksz == 1:
            sync_in_padding = 0
        elif self.sync_in_ksz == 3:
            sync_in_padding = 1
        else:
            raise ValueError(f"Unsupported sync_in_ksz: {self.sync_in_ksz}, expected 1 or 3")
        if self.sync_modulation or self.add_sync_feat_to_audio:
            self.sync_in = nn.Sequential(
                nn.Linear(self.sync_feat_dim, self.hidden_size),
                nn.SiLU(),
                ConvMLP(self.hidden_size, self.hidden_size * 4, kernel_size=self.sync_in_ksz, padding=sync_in_padding),
            )
            self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, self.sync_feat_dim)))

        self.triple_blocks = nn.ModuleList(
            [
                TwoStreamCABlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    mlp_act_type=self.mlp_act_type,
                    qk_norm=self.qk_norm,
                    qk_norm_type=self.qk_norm_type,
                    qkv_bias=self.qkv_bias,
                    attn_mode=self.attn_mode,
                    reverse=self.reverse,
                    interleaved_audio_visual_rope=self.interleaved_audio_visual_rope,
                    **factory_kwargs,
                )
                for _ in range(self.depth_triple_blocks)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    hidden_size=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qk_norm_type=self.qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(self.depth_single_blocks)
            ]
        )

        self.final_layer = FinalLayer1D(
            self.hidden_size, self.patch_size, self.out_channels, get_activation_layer("silu"), **factory_kwargs
        )
        self.unpatchify_channels = self.out_channels

        self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels), requires_grad=True)
        self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim), requires_grad=True)
        nn.init.constant_(self.empty_clip_feat, 0)
        nn.init.constant_(self.empty_sync_feat, 0)
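        # Hedged construction sketch (assumed config loading, not part of the original file):
        #   cfg = OmegaConf.load("configs/hunyuanvideo-foley-xxl.yaml")
        #   model = HunyuanVideoFoley(cfg, dtype=torch.bfloat16, device="cuda")
        #   # __init__ reads hyperparameters from cfg.model_config.model_kwargs via .get(), so any
        #   # key missing from the YAML silently falls back to the defaults summarised above.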
         
    def get_empty_string_sequence(self, bs=None) -> torch.Tensor:
        if bs is None:
            return self.empty_string_feat
        else:
            return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)

    def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
        len = len if len is not None else self.clip_len
        if bs is None:
            return self.empty_clip_feat.expand(len, -1)  # 15s
        else:
            return self.empty_clip_feat.unsqueeze(0).expand(bs, len, -1)  # 15s

    def get_empty_sync_sequence(self, bs=None, len=None) -> torch.Tensor:
        len = len if len is not None else self.sync_len
        if bs is None:
            return self.empty_sync_feat.expand(len, -1)
        else:
            return self.empty_sync_feat.unsqueeze(0).expand(bs, len, -1)
         
    def build_rope_for_audio_visual(self, audio_emb_len, visual_cond_len):
        assert self.patch_size == 1
        # ======================================== Build RoPE for audio tokens ======================================
        target_ndim = 1  # n-d RoPE
        rope_sizes = [audio_emb_len]
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )

        # ========================== Build RoPE for clip tokens =========================
        target_ndim = 1  # n-d RoPE
        rope_sizes = [visual_cond_len]
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        v_freqs_cos, v_freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
            freq_scaling=1.0 * audio_emb_len / visual_cond_len,
        )
        return freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin
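        # Hedged note on freq_scaling: scaling the visual RoPE frequencies by
        # audio_emb_len / visual_cond_len stretches the visual positions onto the audio timeline,
        # so an audio token and the visual token covering the same instant end up with roughly
        # matching rotary phases.  Worked example (assumed lengths): with audio_emb_len=750 and
        # visual_cond_len=64, freq_scaling ~= 11.72, i.e. visual position 8 lines up with audio
        # position ~94.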
         
    def build_rope_for_interleaved_audio_visual(self, total_len):
        assert self.patch_size == 1
        # ========================== Build RoPE for the interleaved audio/visual tokens =========================
        target_ndim = 1  # n-d RoPE
        rope_sizes = [total_len]
        head_dim = self.hidden_size // self.num_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list=rope_dim_list,
            start=rope_sizes,
            theta=10000,
            use_real=True,
            theta_rescale_factor=1.0,
        )
        return freqs_cos, freqs_sin
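        # Hedged note: in the interleaved-RoPE path the caller passes total_len = 2 * audio_seq_len
        # (see forward below), and TwoStreamCABlock interleaves the audio and visual q/k before
        # applying this single table, so temporally adjacent audio and visual tokens receive
        # neighbouring rotary positions.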
         
    def set_attn_mode(self, new_mode):
        for block in self.triple_blocks:
            block.set_attn_mode(new_mode)
        for block in self.single_blocks:
            block.set_attn_mode(new_mode)

    def enable_deterministic(self):
        for block in self.triple_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        for block in self.triple_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()
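        # Hedged note: these helpers assume every block exposes set_attn_mode /
        # enable_deterministic / disable_deterministic; SingleStreamBlock as defined in this file
        # does not declare them, so the loops over self.single_blocks rely on those attributes
        # being available (or on the helpers only being called for the triple-stream blocks).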
         
    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        clip_feat: Optional[torch.Tensor] = None,
        cond: torch.Tensor = None,
        audio_mask: Optional[torch.Tensor] = None,
        cond_mask: torch.Tensor = None,
        sync_feat: Optional[torch.Tensor] = None,
        drop_visual: Optional[List[bool]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        out = {}
        audio = x
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        # Prepare learnable empty conditions for visual condition
        if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
            sync_feat[drop_visual] = self.get_empty_sync_sequence().to(dtype=sync_feat.dtype)
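        # Hedged note: drop_visual is a per-sample boolean mask; rows set to True have their CLIP
        # and sync features replaced by the learnable "empty" embeddings, which is the usual way
        # to train classifier-free guidance over the visual condition.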
         
        # ========================= Prepare time & visual modulation =========================
        vec = self.time_in(t)
        sync_vec = None
        if self.sync_modulation:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, int(sync_feat.shape[1] / 8), 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_vec = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            sync_vec = (
                F.interpolate(sync_vec.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c
            sync_vec = sync_vec + vec.unsqueeze(1)
        elif self.add_sync_feat_to_audio:
            assert sync_feat is not None and sync_feat.shape[1] % 8 == 0
            sync_feat = sync_feat.view(bs, sync_feat.shape[1] // 8, 8, self.sync_feat_dim) + self.sync_pos_emb
            sync_feat = sync_feat.view(bs, -1, self.sync_feat_dim)  # bs, num_segments * 8, channels
            sync_feat = self.sync_in(sync_feat)  # bs, num_segments * 8, c
            add_sync_feat_to_audio = (
                F.interpolate(sync_feat.transpose(1, 2), size=(tl), mode="nearest-exact").contiguous().transpose(1, 2)
            )  # bs, tl, c
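        # Hedged note: sync features arrive in segments of 8 frames; sync_pos_emb adds a learned
        # within-segment position code before the sequence is flattened and resampled
        # (nearest-exact) from its own frame rate to the audio latent length tl, so every latent
        # frame ends up with exactly one synchronisation vector.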
         
        # ========================= Get text, audio and video clip embedding =========================
        cond = self.cond_in(cond)
        cond_seq_len = cond.shape[1]

        audio = self.audio_embedder(x)
        audio_seq_len = audio.shape[1]
        v_cond = self.visual_proj(clip_feat)
        v_cond_seq_len = v_cond.shape[1]

        # ========================= Compute attention mask =========================
        attn_mask = None
        if self.use_attention_mask:
            assert cond_mask is not None
            batch_size = audio.shape[0]
            seq_len = cond_seq_len + v_cond_seq_len + audio_seq_len

            # get default audio_mask and v_cond_mask
            audio_mask = torch.ones((batch_size, audio_seq_len), dtype=torch.bool, device=audio.device)
            v_cond_mask = torch.ones((batch_size, v_cond_seq_len), dtype=torch.bool, device=audio.device)

            # batch_size x seq_len
            concat_mask = torch.cat([cond_mask, v_cond_mask, audio_mask], dim=1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_1 = concat_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            attn_mask_2 = attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
            attn_mask = (attn_mask_1 & attn_mask_2).bool()
            # avoids self-attention weight being NaN for text padding tokens
            attn_mask[:, :, :, 0] = True
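        # Hedged toy example of the mask construction (assumed sizes, not part of the original file):
        #   cond_mask = [[1, 1, 0]]  ->  token 2 is text padding.  The outer product of the
        #   concatenated mask with itself blanks both the row and the column of every padded token,
        #   and re-opening column 0 keeps each query attending to at least one key so that softmax
        #   never normalises an all-masked row to NaN.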
         
        # ========================= Build rope for audio and clip tokens =========================
        if self.interleaved_audio_visual_rope:
            freqs_cos, freqs_sin = self.build_rope_for_interleaved_audio_visual(audio_seq_len * 2)
            v_freqs_cos = v_freqs_sin = None
        else:
            freqs_cos, freqs_sin, v_freqs_cos, v_freqs_sin = self.build_rope_for_audio_visual(
                audio_seq_len, v_cond_seq_len
            )

        # ========================= Pass through DiT blocks =========================
        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        v_freqs_cis = (v_freqs_cos, v_freqs_sin) if v_freqs_cos is not None else None

        if self.add_sync_feat_to_audio:
            add_sync_layer = 0
            assert (
                add_sync_layer < self.depth_triple_blocks
            ), f"The layer that adds the sync feature to the audio stream should be one of the triple-stream blocks (n: {self.depth_triple_blocks})."
        # Triple-stream blocks
        for layer_num, block in enumerate(self.triple_blocks):
            if self.add_sync_feat_to_audio and layer_num == add_sync_layer:
                audio = audio + add_sync_feat_to_audio
            triple_block_args = [audio, cond, v_cond, attn_mask, vec, freqs_cis, v_freqs_cis, sync_vec]
            if (
                self.training
                and self.gradient_checkpoint
                and (self.gradient_checkpoint_layers == -1 or layer_num < self.gradient_checkpoint_layers)
            ):
                audio, cond, v_cond = torch.utils.checkpoint.checkpoint(
                    ckpt_wrapper(block), *triple_block_args, use_reentrant=False
                )
            else:
                audio, cond, v_cond = block(*triple_block_args)
         
     | 
| 714 | 
         
            +
             
     | 
| 715 | 
         
            +
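
# --- Illustrative sketch (an assumption, not this file's definition): `ckpt_wrapper` used above is
# --- typically a thin closure so that torch.utils.checkpoint can re-run the block during the
# --- backward pass instead of storing its activations.
import torch
from torch.utils.checkpoint import checkpoint

def ckpt_wrapper_sketch(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

block = torch.nn.Linear(8, 8)                  # toy stand-in for a DiT block
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(ckpt_wrapper_sketch(block), x, use_reentrant=False)
y.sum().backward()                             # activations are recomputed here, not cached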
        x = audio
        if sync_vec is not None:
            vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
            vec = torch.cat((vec, sync_vec), dim=1)

        freqs_cos, freqs_sin, _, _ = self.build_rope_for_audio_visual(audio_seq_len, v_cond_seq_len)
        if self.add_sync_feat_to_audio:
            vec = add_sync_feat_to_audio + vec.unsqueeze(dim=1)
        if len(self.single_blocks) > 0:
            for layer_num, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    (freqs_cos, freqs_sin),
                ]
                if (
                    self.training
                    and self.gradient_checkpoint
                    and (
                        self.gradient_checkpoint_layers == -1
                        or layer_num + len(self.triple_blocks) < self.gradient_checkpoint_layers
                    )
                ):
                    x = torch.utils.checkpoint.checkpoint(ckpt_wrapper(block), *single_block_args, use_reentrant=False)
                else:
                    x = block(*single_block_args)

        audio = x

        # ========================= Final layer =========================
        if sync_vec is not None:
            vec = sync_vec
        audio = self.final_layer(audio, vec)  # (N, T, patch_size * out_channels)
        audio = self.unpatchify1d(audio, tl)

        if return_dict:
            out["x"] = audio
            return out
        return audio

    def unpatchify1d(self, x, l):
        # x: (N, L, patch_size * C)
        # audio: (N, C, T), T == L * patch_size
        c = self.unpatchify_channels
        p = self.patch_size
        assert l == x.shape[1]

        x = x.reshape(shape=(x.shape[0], l, p, c))
        x = torch.einsum("ntpc->nctp", x)
        audio = x.reshape(shape=(x.shape[0], c, l * p))
        return audio
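
# --- Illustrative sketch (not part of the diff): the shape round-trip performed by unpatchify1d,
# --- assuming patch_size p=4 and C=8 latent channels.
import torch

N, L, p, c = 2, 5, 4, 8
tokens = torch.randn(N, L, p * c)                            # (N, L, patch_size * C) from the final layer
x = tokens.reshape(N, L, p, c)
audio = torch.einsum("nlpc->nclp", x).reshape(N, c, L * p)   # (N, C, T) with T = L * patch_size
print(audio.shape)                                           # torch.Size([2, 8, 20])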
    def params_count(self):
        counts = {
            "triple": sum(
                [
                    sum(p.numel() for p in block.audio_cross_q.parameters())
                    + sum(p.numel() for p in block.v_cond_cross_q.parameters())
                    + sum(p.numel() for p in block.text_cross_kv.parameters())
                    + sum(p.numel() for p in block.audio_self_attn_qkv.parameters())
                    + sum(p.numel() for p in block.v_cond_attn_qkv.parameters())
                    + sum(p.numel() for p in block.audio_mlp.parameters())
                    + sum(p.numel() for p in block.audio_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_self_proj.parameters())
                    + sum(p.numel() for p in block.v_cond_mlp.parameters())
                    for block in self.triple_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters())
                    + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }

        counts["attn+mlp"] = counts["triple"] + counts["single"]
        return counts
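
# --- Illustrative sketch (not part of the diff): params_count() above relies on the same
# --- numel-based counting shown here on a toy module.
import torch.nn as nn

toy = nn.Sequential(nn.Linear(16, 32), nn.Linear(32, 16))
total = sum(p.numel() for p in toy.parameters())
print(total)   # 16*32 + 32 + 32*16 + 16 = 1072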
    	
hunyuanvideo_foley/models/nn/__init__.py
ADDED
File without changes
hunyuanvideo_foley/models/nn/activation_layers.py
ADDED
@@ -0,0 +1,44 @@
import torch.nn as nn
import torch.nn.functional as F

def get_activation_layer(act_type):
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        # Approximate `tanh` requires torch >= 1.13
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")
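
# --- Illustrative sketch (not part of the diff): get_activation_layer returns a constructor,
# --- so it is called twice -- once to pick the factory, once to build the layer instance.
act_layer = get_activation_layer("gelu_tanh")()   # nn.GELU(approximate="tanh")
silu_layer = get_activation_layer("silu")()       # nn.SiLU()
print(act_layer, silu_layer)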
class SwiGLU(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        out_dim: int,
    ):
        """
        Initialize the SwiGLU FeedForward module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            out_dim (int): Output dimension.

        Attributes:
            w1: Linear transformation for the first layer.
            w2: Linear transformation for the second layer.
            w3: Linear transformation for the third layer.

        """
        super().__init__()

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
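
# --- Illustrative sketch (not part of the diff): SwiGLU as a drop-in feed-forward block.
# --- hidden_dim is often chosen around 2/3 of a conventional 4*dim MLP width so the
# --- three-matrix layout keeps a comparable parameter count.
import torch

mlp = SwiGLU(dim=256, hidden_dim=683, out_dim=256)
out = mlp(torch.randn(4, 10, 256))
print(out.shape)   # torch.Size([4, 10, 256])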
    	
hunyuanvideo_foley/models/nn/attn_layers.py
ADDED
@@ -0,0 +1,546 @@
import importlib.metadata
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn import (
        flash_attn_qkvpacked_func,
        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
except ImportError:
    # keep every flash_attn name defined so the pure-PyTorch attention modes still import
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
    flash_attn_varlen_qkvpacked_func = None
    index_first_axis, pad_input, unpad_input = None, None, None
from packaging import version
from transformers.utils.import_utils import _is_package_available

from .norm_layers import get_norm_layer

def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis[0].shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings.   [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
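
# --- Illustrative sketch (not part of the diff): building an interleaved (cos, sin) table and
# --- applying apply_rotary_emb to dummy q/k in the [B, S, H, D] layout (head_first=False).
# --- The base frequency 10000 is the common RoPE default and is only an assumption here.
B, S, H, D = 2, 16, 4, 64
pos = torch.arange(S, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))   # (D/2,)
angles = torch.outer(pos, inv_freq)                                            # (S, D/2)
cos = angles.cos().repeat_interleave(2, dim=-1)                                # (S, D), interleaved pairs
sin = angles.sin().repeat_interleave(2, dim=-1)

q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)
q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)
print(q_rot.shape, k_rot.shape)   # torch.Size([2, 16, 4, 64]) twice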
             
class BasicAttentionLayer(nn.Module):
    def __init__(self, attn_mode="flash", deterministic=False):
        super().__init__()
        self.attn_mode = attn_mode
        self.deterministic = deterministic

    def set_attn_mode(self, new_mode):
        self.attn_mode = new_mode

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False


MEMORY_LAYOUT = {
    "self_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "cross_flash": (
        lambda x: x,
        lambda x: x,
    ),
    "flash_torch_sp": (
        lambda x: x,
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
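
# --- Illustrative sketch (not part of the diff): the "torch" entry above moves heads in front of the
# --- sequence dimension, i.e. [B, S, H, D] -> [B, H, S, D], which is what
# --- F.scaled_dot_product_attention expects; the flash modes keep the [B, S, H, D] layout.
pre, post = MEMORY_LAYOUT["torch"]
q = torch.randn(2, 16, 4, 64)        # [B, S, H, D]
print(pre(q).shape)                  # torch.Size([2, 4, 16, 64])
print(post(pre(q)).shape)            # torch.Size([2, 16, 4, 64])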
# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
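
# --- Illustrative sketch (not part of the diff): what _get_unpad_data returns for a small padded
# --- batch of two sequences with true lengths 3 and 1, padded to length 4.
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 0, 0, 0]])
indices, cu_seqlens, max_len = _get_unpad_data(mask)
print(indices)      # tensor([0, 1, 2, 4]) -- positions of valid tokens in the flattened batch
print(cu_seqlens)   # tensor([0, 3, 4], dtype=torch.int32)
print(max_len)      # 3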
             
# Copied from https://github.com/huggingface/transformers/blob/b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
def is_flash_attn_greater_or_equal(library_version: str):
    if not _is_package_available("flash_attn"):
        return False

    return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)


def get_kv_seqlens_with_mask(attn_mask, k, v):
    indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
    b, s1, a, d = k.shape
    k = index_first_axis(k.reshape(b * s1, a, d), indices_k)
    v = index_first_axis(v.reshape(b * s1, a, d), indices_k)
    kv = torch.stack([k, v], dim=1)
    return cu_seqlens_k, max_seqlen_k, kv


def get_q_seqlens(q):
    bs, s, a, d = q.shape
    cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
    q = q.reshape(bs * s, a, d)
    return cu_seqlens_q, s, q

def flash_attn_no_pad(
    qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None
):
    # adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
    batch_size = qkv.shape[0]
    seqlen = qkv.shape[1]
    nheads = qkv.shape[-2]
    x = rearrange(qkv, "b s three h d -> b s (three h d)")
    # newer flash_attn versions return (x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch),
    # older ones return (x_unpad, indices, cu_seqlens, max_s)
    unpad_results = unpad_input(
        x, key_padding_mask
    )

    if len(unpad_results) == 4:
        x_unpad, indices, cu_seqlens, max_s = unpad_results
    elif len(unpad_results) == 5:
        x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_results
    else:
        raise ValueError("Unexpected number of values returned by unpad_input")

    x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad,
        cu_seqlens,
        max_s,
        dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    output = rearrange(
        pad_input(
            rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
        ),
        "b s (h d) -> b s h d",
        h=nheads,
    )
    return output
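
# --- Illustrative sketch (not part of the diff): calling flash_attn_no_pad with packed qkv and a
# --- key-padding mask. This path needs a CUDA device, fp16/bf16 tensors and the flash_attn
# --- package, so it is guarded and skipped otherwise.
if torch.cuda.is_available() and flash_attn_qkvpacked_func is not None:
    B, S, H, D = 2, 128, 8, 64
    qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.float16)
    key_padding_mask = torch.ones(B, S, dtype=torch.bool, device="cuda")
    key_padding_mask[1, 64:] = False                  # second sample only has 64 valid tokens
    out = flash_attn_no_pad(qkv, key_padding_mask)    # [B, S, H, D]
    print(out.shape)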
def attention(
    q,
    k,
    v,
    mode,
    drop_rate=0,
    attn_mask=None,
    cond_mask=None,
    causal=False,
    deterministic=False,
    cu_seqlens=None,
    max_seqlen=None,
    cu_seqlens_k=None,
    max_seqlen_k=None,
    img_seq_len=None,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        deterministic (bool): Whether to use deterministic attention. (default: False)
        cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        max_seqlen (int): The maximum sequence length in the batch of q.
        cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_k (int): The maximum sequence length in the batch of k and v.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
        if isinstance(q, tuple):
            q = torch.cat(q, dim=1)
        if isinstance(k, tuple):
            k = torch.cat(k, dim=1)
        if isinstance(v, tuple):
            v = torch.cat(v, dim=1)
        pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if "flash" in mode:
        assert (
            flash_attn_qkvpacked_func is not None
        ), "Flash attention is not available. Please install flash_attn first."
        flash_kwargs = dict(dropout_p=drop_rate, causal=causal)
        if deterministic:
            if not is_flash_attn_greater_or_equal("2.4.1"):
                raise ValueError(
                    "Flash attention deterministic mode requires flash_attn>=2.4.1. Please upgrade flash_attn."
                )
            flash_kwargs["deterministic"] = deterministic

        if mode == "self_flash":
            qkv = torch.stack([q, k, v], dim=2)
            if attn_mask is not None:
                raise ValueError("Self attention does not support attention mask")
            x = flash_attn_qkvpacked_func(qkv, **flash_kwargs)

        elif mode == "cross_flash":
            kv = torch.stack([k, v], dim=2)
            if attn_mask is None:
                x = flash_attn_kvpacked_func(q, kv, **flash_kwargs)
            else:
                b, s, a, h = q.shape
                cu_seqlens_q, max_seqlen_q, q = get_q_seqlens(q)
                cu_seqlens_k, max_seqlen_k, kv = get_kv_seqlens_with_mask(attn_mask, k, v)

                attn_output = flash_attn_varlen_kvpacked_func(
                    q,
                    kv,
                    cu_seqlens_q=cu_seqlens_q,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen_q,
                    max_seqlen_k=max_seqlen_k,
                    **flash_kwargs,
                )
                x = attn_output.reshape(b, s, a, h)
    elif mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "vanilla":
         
     | 
| 358 | 
         
            +
                    scale_factor = 1 / math.sqrt(q.size(-1))
         
     | 
| 359 | 
         
            +
             
     | 
| 360 | 
         
            +
                    b, a, s, _ = q.shape
         
     | 
| 361 | 
         
            +
                    s1 = k.size(2)
         
     | 
| 362 | 
         
            +
                    attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
         
     | 
| 363 | 
         
            +
                    if causal:
         
     | 
| 364 | 
         
            +
                        # Only applied to self attention
         
     | 
| 365 | 
         
            +
                        assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
         
     | 
| 366 | 
         
            +
                        temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
         
     | 
| 367 | 
         
            +
                        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         
     | 
| 368 | 
         
            +
                        attn_bias.to(q.dtype)
         
     | 
| 369 | 
         
            +
             
     | 
| 370 | 
         
            +
                    if attn_mask is not None:
         
     | 
| 371 | 
         
            +
                        if attn_mask.dtype == torch.bool:
         
     | 
| 372 | 
         
            +
                            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
         
     | 
| 373 | 
         
            +
                        else:
         
     | 
| 374 | 
         
            +
                            attn_bias += attn_mask
         
     | 
| 375 | 
         
            +
             
     | 
| 376 | 
         
            +
                    # TODO(jarvizhang): Maybe force q and k to be float32 to avoid numerical overflow
         
     | 
| 377 | 
         
            +
                    attn = (q @ k.transpose(-2, -1)) * scale_factor
         
     | 
| 378 | 
         
            +
                    attn += attn_bias
         
     | 
| 379 | 
         
            +
                    attn = attn.softmax(dim=-1)
         
     | 
| 380 | 
         
            +
                    attn = torch.dropout(attn, p=drop_rate, train=True)
         
     | 
| 381 | 
         
            +
                    x = attn @ v
         
     | 
| 382 | 
         
            +
                else:
         
     | 
| 383 | 
         
            +
                    raise NotImplementedError(f"Unsupported attention mode: {mode}")
         
     | 
| 384 | 
         
            +
             
     | 
| 385 | 
         
            +
                if mode in ["torch", "vanilla", "self_flash", "cross_flash"]:
         
     | 
| 386 | 
         
            +
                    x = post_attn_layout(x).contiguous()
         
     | 
| 387 | 
         
            +
                b, s, a, d = x.shape
         
     | 
| 388 | 
         
            +
                out = x.reshape(b, s, -1)
         
     | 
| 389 | 
         
            +
                return out
         
     | 
| 390 | 
         
            +
             
     | 
| 391 | 
         
            +
             
     | 
| 392 | 
         
            +
            class SelfAttentionLayer(BasicAttentionLayer):
         
     | 
| 393 | 
         
            +
                def __init__(
         
     | 
| 394 | 
         
            +
                    self,
         
     | 
| 395 | 
         
            +
                    dim,
         
     | 
| 396 | 
         
            +
                    num_heads,
         
     | 
| 397 | 
         
            +
                    qkv_bias=True,
         
     | 
| 398 | 
         
            +
                    qk_norm=True,
         
     | 
| 399 | 
         
            +
                    attn_drop=0,
         
     | 
| 400 | 
         
            +
                    proj_drop=0,
         
     | 
| 401 | 
         
            +
                    dtype=None,
         
     | 
| 402 | 
         
            +
                    device=None,
         
     | 
| 403 | 
         
            +
                    norm_type="layer",
         
     | 
| 404 | 
         
            +
                    attn_mode="self_flash",
         
     | 
| 405 | 
         
            +
                    deterministic=False,
         
     | 
| 406 | 
         
            +
                ) -> None:
         
     | 
| 407 | 
         
            +
                    factory_kwargs = {"device": device, "dtype": dtype}
         
     | 
| 408 | 
         
            +
                    super().__init__(attn_mode, deterministic)
         
     | 
| 409 | 
         
            +
                    self.dim = dim
         
     | 
| 410 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 411 | 
         
            +
                    assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
         
     | 
| 412 | 
         
            +
                    self.head_dim = self.dim // num_heads
         
     | 
| 413 | 
         
            +
                    self.attn_drop = attn_drop
         
     | 
| 414 | 
         
            +
             
     | 
| 415 | 
         
            +
                    # This assertion is aligned with flash attention
         
     | 
| 416 | 
         
            +
                    assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
         
     | 
| 417 | 
         
            +
             
     | 
| 418 | 
         
            +
                    self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
         
     | 
| 419 | 
         
            +
             
     | 
| 420 | 
         
            +
                    norm_layer = get_norm_layer(norm_type)
         
     | 
| 421 | 
         
            +
                    self.q_norm = (
         
     | 
| 422 | 
         
            +
                        norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 423 | 
         
            +
                    )
         
     | 
| 424 | 
         
            +
                    self.k_norm = (
         
     | 
| 425 | 
         
            +
                        norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 426 | 
         
            +
                    )
         
     | 
| 427 | 
         
            +
             
     | 
| 428 | 
         
            +
                    self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
         
     | 
| 429 | 
         
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         
     | 
| 430 | 
         
            +
             
     | 
| 431 | 
         
            +
                def forward(self, x, freqs_cis=None, attn_mask=None):
         
     | 
| 432 | 
         
            +
                    """
         
     | 
| 433 | 
         
            +
                    Args:
         
     | 
| 434 | 
         
            +
                        x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
         
     | 
| 435 | 
         
            +
                        freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image
         
     | 
| 436 | 
         
            +
                        attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
         
     | 
| 437 | 
         
            +
                    """
         
     | 
| 438 | 
         
            +
                    b, s, d = x.shape
         
     | 
| 439 | 
         
            +
             
     | 
| 440 | 
         
            +
                    # Apply QKV projection
         
     | 
| 441 | 
         
            +
                    qkv = self.Wqkv(x)
         
     | 
| 442 | 
         
            +
                    qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim)  # [b, s, 3, a, d]
         
     | 
| 443 | 
         
            +
                    q, k, v = qkv.unbind(dim=2)  # [b, s, a, d]
         
     | 
| 444 | 
         
            +
             
     | 
| 445 | 
         
            +
                    # Apply QK-Norm if needed
         
     | 
| 446 | 
         
            +
                    q = self.q_norm(q)
         
     | 
| 447 | 
         
            +
                    k = self.k_norm(k)
         
     | 
| 448 | 
         
            +
             
     | 
| 449 | 
         
            +
                    # Apply RoPE if needed
         
     | 
| 450 | 
         
            +
                    if freqs_cis is not None:
         
     | 
| 451 | 
         
            +
                        qq, kk = apply_rotary_emb(q, k, freqs_cis)
         
     | 
| 452 | 
         
            +
                        assert (
         
     | 
| 453 | 
         
            +
                            qq.shape == q.shape and kk.shape == k.shape
         
     | 
| 454 | 
         
            +
                        ), f"qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}"
         
     | 
| 455 | 
         
            +
                        q, k = qq, kk
         
     | 
| 456 | 
         
            +
             
     | 
| 457 | 
         
            +
                    # Apply self attention
         
     | 
| 458 | 
         
            +
                    context = attention(
         
     | 
| 459 | 
         
            +
                        q,
         
     | 
| 460 | 
         
            +
                        k,
         
     | 
| 461 | 
         
            +
                        v,
         
     | 
| 462 | 
         
            +
                        drop_rate=self.attn_drop if self.training else 0,
         
     | 
| 463 | 
         
            +
                        attn_mask=attn_mask,
         
     | 
| 464 | 
         
            +
                        mode=self.attn_mode,
         
     | 
| 465 | 
         
            +
                        deterministic=self.deterministic,
         
     | 
| 466 | 
         
            +
                    )
         
     | 
| 467 | 
         
            +
                    out = self.out_proj(context)
         
     | 
| 468 | 
         
            +
                    out = self.proj_drop(out)
         
     | 
| 469 | 
         
            +
             
     | 
| 470 | 
         
            +
                    return out
         
     | 
| 471 | 
         
            +
             
     | 
| 472 | 
         
            +
             
     | 
| 473 | 
         
            +
            class CrossAttentionLayer(BasicAttentionLayer):
         
     | 
| 474 | 
         
            +
                def __init__(
         
     | 
| 475 | 
         
            +
                    self,
         
     | 
| 476 | 
         
            +
                    qdim,
         
     | 
| 477 | 
         
            +
                    kdim,
         
     | 
| 478 | 
         
            +
                    num_heads,
         
     | 
| 479 | 
         
            +
                    qkv_bias=True,
         
     | 
| 480 | 
         
            +
                    qk_norm=True,
         
     | 
| 481 | 
         
            +
                    attn_drop=0,
         
     | 
| 482 | 
         
            +
                    proj_drop=0,
         
     | 
| 483 | 
         
            +
                    dtype=None,
         
     | 
| 484 | 
         
            +
                    device=None,
         
     | 
| 485 | 
         
            +
                    norm_type="layer",
         
     | 
| 486 | 
         
            +
                    attn_mode="cross_flash",
         
     | 
| 487 | 
         
            +
                    deterministic=False,
         
     | 
| 488 | 
         
            +
                ):
         
     | 
| 489 | 
         
            +
                    factory_kwargs = {"device": device, "dtype": dtype}
         
     | 
| 490 | 
         
            +
                    super().__init__(attn_mode, deterministic)
         
     | 
| 491 | 
         
            +
                    self.qdim = qdim
         
     | 
| 492 | 
         
            +
                    self.kdim = kdim
         
     | 
| 493 | 
         
            +
                    self.num_heads = num_heads
         
     | 
| 494 | 
         
            +
                    assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
         
     | 
| 495 | 
         
            +
                    self.head_dim = self.qdim // num_heads
         
     | 
| 496 | 
         
            +
                    self.attn_drop = attn_drop
         
     | 
| 497 | 
         
            +
             
     | 
| 498 | 
         
            +
                    # This assertion is aligned with flash attention
         
     | 
| 499 | 
         
            +
                    assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
         
     | 
| 500 | 
         
            +
             
     | 
| 501 | 
         
            +
                    self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
         
     | 
| 502 | 
         
            +
                    self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
         
     | 
| 503 | 
         
            +
             
     | 
| 504 | 
         
            +
                    norm_layer = get_norm_layer(norm_type)
         
     | 
| 505 | 
         
            +
                    self.q_norm = (
         
     | 
| 506 | 
         
            +
                        norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 507 | 
         
            +
                    )
         
     | 
| 508 | 
         
            +
                    self.k_norm = (
         
     | 
| 509 | 
         
            +
                        norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
         
     | 
| 510 | 
         
            +
                    )
         
     | 
| 511 | 
         
            +
             
     | 
| 512 | 
         
            +
                    self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
         
     | 
| 513 | 
         
            +
                    self.proj_drop = nn.Dropout(proj_drop)
         
     | 
| 514 | 
         
            +
             
     | 
| 515 | 
         
            +
                def forward(self, x, y, attn_mask=None):
         
     | 
| 516 | 
         
            +
                    """
         
     | 
| 517 | 
         
            +
                    Args:
         
     | 
| 518 | 
         
            +
                        x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
         
     | 
| 519 | 
         
            +
                        y (torch.Tensor): (batch, seq_len1, hidden_dim1)
         
     | 
| 520 | 
         
            +
                        attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
         
     | 
| 521 | 
         
            +
                    """
         
     | 
| 522 | 
         
            +
                    b, s, d = x.shape
         
     | 
| 523 | 
         
            +
                    _, s1, d1 = y.shape
         
     | 
| 524 | 
         
            +
             
     | 
| 525 | 
         
            +
                    q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
         
     | 
| 526 | 
         
            +
                    kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
         
     | 
| 527 | 
         
            +
                    k, v = kv.unbind(dim=2)
         
     | 
| 528 | 
         
            +
             
     | 
| 529 | 
         
            +
                    # Apply QK-Norm if needed
         
     | 
| 530 | 
         
            +
                    q = self.q_norm(q)
         
     | 
| 531 | 
         
            +
                    k = self.k_norm(k)
         
     | 
| 532 | 
         
            +
             
     | 
| 533 | 
         
            +
                    # Apply cross attention
         
     | 
| 534 | 
         
            +
                    context = attention(
         
     | 
| 535 | 
         
            +
                        q,
         
     | 
| 536 | 
         
            +
                        k,
         
     | 
| 537 | 
         
            +
                        v,
         
     | 
| 538 | 
         
            +
                        attn_mask=attn_mask,
         
     | 
| 539 | 
         
            +
                        drop_rate=self.attn_drop if self.training else 0,
         
     | 
| 540 | 
         
            +
                        mode=self.attn_mode,
         
     | 
| 541 | 
         
            +
                        deterministic=self.deterministic,
         
     | 
| 542 | 
         
            +
                    )
         
     | 
| 543 | 
         
            +
                    out = self.out_proj(context)
         
     | 
| 544 | 
         
            +
                    out = self.proj_drop(out)
         
     | 
| 545 | 
         
            +
             
     | 
| 546 | 
         
            +
                    return out
         
     | 
    	
        hunyuanvideo_foley/models/nn/embed_layers.py
    ADDED
    
    | 
         @@ -0,0 +1,136 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
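The attention helper above dispatches between flash-attention, PyTorch SDPA, and a vanilla implementation based on mode. Below is a minimal sketch of calling it directly in the dependency-free "torch" mode; the tensor sizes are illustrative only, and it assumes the repository root is on PYTHONPATH so the module can be imported by the path used in this commit.

# Illustrative sketch, not part of the committed files; shapes are arbitrary.
import torch
from hunyuanvideo_foley.models.nn.attn_layers import attention

b, s, a, d = 2, 16, 8, 64                 # batch, tokens, heads, head_dim
q = torch.randn(b, s, a, d)
k = torch.randn(b, s, a, d)
v = torch.randn(b, s, a, d)

# "torch" mode routes to F.scaled_dot_product_attention, so flash_attn is not required.
out = attention(q, k, v, mode="torch", drop_rate=0.0)
print(out.shape)                          # torch.Size([2, 16, 512]): heads and head_dim are merged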
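In the same spirit, the two layer classes can be exercised end to end. This is a sketch under the assumption that the "torch" attention mode is acceptable (avoiding the flash_attn requirement) and that get_norm_layer("layer"), defined earlier in this file, resolves to a LayerNorm-style class.

# Illustrative sketch, not part of the committed files.
import torch
from hunyuanvideo_foley.models.nn.attn_layers import SelfAttentionLayer, CrossAttentionLayer

self_attn = SelfAttentionLayer(dim=512, num_heads=8, attn_mode="torch")               # head_dim = 64
cross_attn = CrossAttentionLayer(qdim=512, kdim=768, num_heads=8, attn_mode="torch")

x = torch.randn(2, 16, 512)   # latent/audio tokens: (batch, seq_len, dim)
y = torch.randn(2, 77, 768)   # conditioning tokens: (batch, seq_len1, kdim)

h = self_attn(x)              # (2, 16, 512); freqs_cis and attn_mask are optional
h = cross_attn(h, y)          # (2, 16, 512); queries from h, keys/values from y
print(h.shape)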
hunyuanvideo_foley/models/nn/embed_layers.py
ADDED
@@ -0,0 +1,136 @@
import math
import torch
import torch.nn as nn

from ...utils.helper import to_2tuple, to_1tuple

class PatchEmbed1D(nn.Module):
    """1D Audio to Patch Embedding

    A convolution based approach to patchifying a 1D audio w/ embedding projection.

    Based on the impl in https://github.com/google-research/vision_transformer

    Hacked together by / Copyright 2020 Ross Wightman
    """

    def __init__(
        self,
        patch_size=1,
        in_chans=768,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        patch_size = to_1tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        self.proj = nn.Conv1d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
        )
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        assert (
            x.shape[2] % self.patch_size[0] == 0
        ), f"The token number ({x.shape[2]}) of x must be divisible by the patch_size ({self.patch_size[0]})."

        x = self.proj(x)
        if self.flatten:
            x = x.transpose(1, 2)  # BCN -> BNC
        x = self.norm(x)
        return x


class ConditionProjection(nn.Module):
    """
    Projects condition embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
        )
    return embedding


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self,
                 hidden_size,
                 act_layer,
                 frequency_embedding_size=256,
                 max_period=10000,
                 out_size=None,
                 dtype=None,
                 device=None
                 ):
        factory_kwargs = {'dtype': dtype, 'device': device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.normal_(self.mlp[2].weight, std=0.02)

    def forward(self, t):
        t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
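The embedding helpers in this file are small enough to try in isolation. The sketch below uses arbitrary channel counts and sequence lengths, and nn.SiLU purely as an example act_layer.

# Illustrative sketch, not part of the committed files; sizes are arbitrary.
import torch
import torch.nn as nn
from hunyuanvideo_foley.models.nn.embed_layers import PatchEmbed1D, TimestepEmbedder, timestep_embedding

patcher = PatchEmbed1D(patch_size=1, in_chans=128, embed_dim=1024)
tokens = patcher(torch.randn(2, 128, 250))       # (B, C, N) -> (B, N, embed_dim) = (2, 250, 1024)

t = torch.tensor([0.0, 250.0, 999.0])            # one (possibly fractional) timestep per sample
sincos = timestep_embedding(t, dim=256)          # (3, 256) fixed sinusoidal features

embedder = TimestepEmbedder(hidden_size=1024, act_layer=nn.SiLU)
t_emb = embedder(t)                              # (3, 1024) learned projection of those features
print(tokens.shape, sincos.shape, t_emb.shape)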
hunyuanvideo_foley/models/nn/mlp_layers.py
ADDED
@@ -0,0 +1,149 @@
# Modified from timm library:
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from .modulate_layers import modulate
from ...utils.helper import to_2tuple

class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
# only used when use_vanilla is True
class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


class LinearWarpforSingle(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, bias=True, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim, bias=bias, **factory_kwargs)

    def forward(self, x, y):
        z = torch.cat([x, y], dim=2)
        return self.fc(z)

class FinalLayer1D(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Just use LayerNorm for the final layer
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.linear = nn.Linear(hidden_size, patch_size * out_channels, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Here we don't distinguish between the modulate types. Just use the simple one.
        self.adaLN_modulation = nn.Sequential(
            act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        x = self.linear(x)
        return x


class ChannelLastConv1d(nn.Conv1d):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 1)
        x = super().forward(x)
        x = x.permute(0, 2, 1)
        return x


class ConvMLP(nn.Module):

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int = 256,
        kernel_size: int = 3,
        padding: int = 1,
        device=None,
        dtype=None,
    ):
        """
        Convolutional MLP module.

        Args:
            dim (int): Input dimension.
            hidden_dim (int): Hidden dimension of the feedforward layer.
            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.

        Attributes:
            w1: Linear transformation for the first layer.
            w2: Linear transformation for the second layer.
            w3: Linear transformation for the third layer.

        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
        self.w2 = ChannelLastConv1d(hidden_dim, dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)
        self.w3 = ChannelLastConv1d(dim, hidden_dim, bias=False, kernel_size=kernel_size, padding=padding, **factory_kwargs)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
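ConvMLP follows the SwiGLU-style sizing used in LLaMA-type feed-forward blocks: the requested hidden_dim is first scaled by 2/3 and then rounded up to a multiple of multiple_of. A short sketch with illustrative sizes makes the arithmetic concrete.

# Illustrative sketch, not part of the committed files; dims are arbitrary.
import torch
from hunyuanvideo_foley.models.nn.mlp_layers import ConvMLP

# dim=1024, hidden_dim=4096: 2/3 * 4096 = 2730, rounded up to a multiple of 256 -> 2816
mlp = ConvMLP(dim=1024, hidden_dim=4096, kernel_size=3, padding=1)

x = torch.randn(2, 250, 1024)   # channel-last tokens: (batch, seq_len, dim)
y = mlp(x)                      # silu(w1(x)) * w3(x) fed through w2; shape is preserved
print(y.shape)                  # torch.Size([2, 250, 1024])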
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
from typing import Callable
import torch
import torch.nn as nn

class ModulateDiT(nn.Module):
    def __init__(self, hidden_size: int, factor: int, act_layer: Callable, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        # Zero-initialize the modulation
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.act(x))


def modulate(x, shift=None, scale=None):
    if x.ndim == 3:
        shift = shift.unsqueeze(1) if shift is not None and shift.ndim == 2 else None
        scale = scale.unsqueeze(1) if scale is not None and scale.ndim == 2 else None
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale)
    elif scale is None:
        return x + shift
    else:
        return x * (1 + scale) + shift


def apply_gate(x, gate=None, tanh=False):
    if gate is None:
        return x
    if gate.ndim == 2 and x.ndim == 3:
        gate = gate.unsqueeze(1)
    if tanh:
        return x * gate.tanh()
    else:
        return x * gate


def ckpt_wrapper(module):
    def ckpt_forward(*inputs):
        outputs = module(*inputs)
        return outputs

    return ckpt_forward
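ModulateDiT produces conditioning-dependent shift/scale/gate vectors (zero-initialized, so the modulation starts as a no-op), modulate applies the affine part, and apply_gate scales a residual branch. A hedged sketch of how such pieces are typically wired in an adaLN-Zero-style transformer block; the block structure, tensor shapes, and the Linear stand-in sub-layer are assumptions for illustration, not code from this repository:

import torch
import torch.nn as nn
from hunyuanvideo_foley.models.nn.modulate_layers import ModulateDiT, modulate, apply_gate

hidden_size = 64
mod = ModulateDiT(hidden_size, factor=3, act_layer=nn.SiLU)  # zero-initialized modulation
norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
sublayer = nn.Linear(hidden_size, hidden_size)               # stand-in for attention or an MLP

x = torch.randn(2, 10, hidden_size)   # (batch, tokens, dim)
c = torch.randn(2, hidden_size)       # conditioning vector, e.g. a timestep embedding

shift, scale, gate = mod(c).chunk(3, dim=-1)      # three (batch, dim) tensors
h = modulate(norm(x), shift=shift, scale=scale)   # shift/scale broadcast over the token axis
x = x + apply_gate(sublayer(h), gate=gate)        # gated residual; contributes zero at init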
    	
hunyuanvideo_foley/models/nn/norm_layers.py
ADDED
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6,
                 device=None, dtype=None):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.

        """
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.

        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.

        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
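A brief usage sketch (the dimension and eps values are illustrative): get_norm_layer maps a config string to a normalization class, which callers then instantiate with the model width; RMSNorm normalizes over the last dimension in float32, casts back to the input dtype, and applies its learnable scale when elementwise_affine is enabled:

import torch
from hunyuanvideo_foley.models.nn.norm_layers import get_norm_layer

norm_cls = get_norm_layer("rms")   # -> RMSNorm; "layer" -> nn.LayerNorm
norm = norm_cls(256, eps=1e-6)     # per-feature scale, no bias

x = torch.randn(2, 10, 256)
y = norm(x)                        # same shape, RMS-normalized over the last dimension
print(y.shape)                     # torch.Size([2, 10, 256])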