Commit b6e0092
Parent(s): e43437a
Add files
- app.py +964 -0
 - app_utils.py +102 -0
 - assets/.DS_Store +0 -0
 - assets/images/.DS_Store +0 -0
 - assets/images/editing/banana.png +0 -0
 - assets/images/editing/cake.png +0 -0
 - assets/images/editing/rabbit.png +0 -0
 - assets/images/interpolation/church1.png +0 -0
 - assets/images/interpolation/church2.png +0 -0
 - assets/images/interpolation/dog1.png +0 -0
 - assets/images/interpolation/dog2.png +0 -0
 - assets/images/interpolation/horse1.png +0 -0
 - assets/images/interpolation/horse2.png +0 -0
 - assets/images/interpolation/land1.png +0 -0
 - assets/images/interpolation/land2.png +0 -0
 - assets/images/interpolation/rabbit1.png +0 -0
 - assets/images/interpolation/rabbit2.png +0 -0
 - assets/images/interpolation/woman1.png +0 -0
 - assets/images/interpolation/woman2.png +0 -0
 - assets/images/inversion/000000029596.jpg +0 -0
 - assets/images/inversion/000000560011.jpg +0 -0
 - nulltxtinv_wrapper.py +450 -0
 - requirements.txt +16 -0
 
    	
app.py ADDED
@@ -0,0 +1,964 @@
################################################################################
# Copyright (C) 2023 Jiayi Guo, Xingqian Xu, Manushree Vasu - All Rights Reserved #
################################################################################

import gradio as gr
import os
import os.path as osp
import PIL
from PIL import Image
import numpy as np
from collections import OrderedDict
from easydict import EasyDict as edict
from functools import partial

import torch
import torchvision.transforms as tvtrans
import time
import argparse
import json
import hashlib
import copy
from tqdm import tqdm

from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from app_utils import auto_dropdown

from huggingface_hub import hf_hub_download

version = "Smooth Diffusion Demo v1.0"
refresh_symbol = "\U0001f504"  # 🔄
recycle_symbol = '\U0000267b'  # ♻

##############
# model_book #
##############

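# Registries backing the UI dropdowns: `choices.diffuser` maps display names to
# Hugging Face model ids, `choices.lora` maps to LoRA weight files (Smooth-LoRA-v1
# is fetched via hf_hub_download at import time), `choices.scheduler` maps to
# diffusers scheduler classes, and `choices.inversion` names the supported
# inversion modes.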
choices = edict()
choices.diffuser = OrderedDict([
    ['SD-v1-5' , "runwayml/stable-diffusion-v1-5"],
    ['OJ-v4' , "prompthero/openjourney-v4"],
    ['RR-v2', "SG161222/Realistic_Vision_V2.0"],
])

choices.lora = OrderedDict([
    ['empty', ""],
    ['Smooth-LoRA-v1', hf_hub_download('shi-labs/smooth-diffusion-lora', 'pytorch_model.bin')],
])

choices.scheduler = OrderedDict([
    ['DDIM', DDIMScheduler],
])

choices.inversion = OrderedDict([
    ['NTI', 'NTI'],
    ['DDIM w/o text', 'DDIM w/o text'],
    ['DDIM', 'DDIM'],
])

default = edict()
default.diffuser = 'SD-v1-5'
default.scheduler = 'DDIM'
default.lora = 'Smooth-LoRA-v1'
default.inversion = 'NTI'
default.step = 50
default.cfg_scale = 7.5
default.framen = 24
default.fps = 16
default.nullinv_inner_step = 10
default.threshold = 0.8
default.variation = 0.8

##########
# helper #
##########
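# Interpolation primitives: `lerp` blends two vectors linearly, while `slerp`
# (spherical linear interpolation) follows the great circle between them, which
# keeps the norm of interpolated Gaussian noise latents roughly constant.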

def lerp(t, v0, v1):
    if isinstance(t, float):
        return v0*(1-t) + v1*t
    elif isinstance(t, (list, np.ndarray)):
        return [v0*(1-ti) + v1*ti for ti in t]

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    # mostly copied from
    # https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
    v0_unit = v0 / np.linalg.norm(v0)
    v1_unit = v1 / np.linalg.norm(v1)
    dot = np.sum(v0_unit * v1_unit)
    if np.abs(dot) > DOT_THRESHOLD:
        return lerp(t, v0, v1)
    # Calculate initial angle between v0 and v1
    theta_0 = np.arccos(dot)
    sin_theta_0 = np.sin(theta_0)
    # Angle at timestep t

    if isinstance(t, float):
        tlist = [t]
    elif isinstance(t, (list, np.ndarray)):
        tlist = t

    v2_list = []

    for ti in tlist:
        theta_t = theta_0 * ti
        sin_theta_t = np.sin(theta_t)
        # Finish the slerp algorithm
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1
        v2_list.append(v2)

    if isinstance(t, float):
        return v2_list[0]
    else:
        return v2_list
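
# Worked example: for orthonormal v0, v1 (dot = 0, so theta_0 = pi/2) at t = 0.5,
# s0 = s1 = sin(pi/4) ~ 0.707, giving 0.707*v0 + 0.707*v1, which is still unit
# length; the lerp midpoint 0.5*v0 + 0.5*v1 has norm ~ 0.707 instead.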

def offset_resize(image, width=512, height=512, left=0, right=0, top=0, bottom=0):

    image = np.array(image)[:, :, :3]
    h, w, c = image.shape
    left = min(left, w-1)
    right = min(right, w - left - 1)
    top = min(top, h-1)  # clamp the top margin against the height, matching the width clamps
    bottom = min(bottom, h - top - 1)
    image = image[top:h-bottom, left:w-right]
    h, w, c = image.shape
    if h < w:
        offset = (w - h) // 2
        image = image[:, offset:offset + h]
    elif w < h:
        offset = (h - w) // 2
        image = image[offset:offset + w]
    image = Image.fromarray(image).resize((width, height))
    return image
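
# offset_resize therefore crops the requested margins, center-crops the result
# to a square, and resizes to 512x512, the native resolution of the SD-v1.x
# models registered above.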

def auto_dtype_device_shape(tlist, v0, v1, func,):
    vshape = v0.shape
    assert v0.shape == v1.shape
    assert isinstance(tlist, (list, np.ndarray))

    if isinstance(v0, torch.Tensor):
        is_torch = True
        dtype, device = v0.dtype, v0.device
        v0 = v0.to('cpu').numpy().astype(float).flatten()
        v1 = v1.to('cpu').numpy().astype(float).flatten()
    else:
        is_torch = False
        dtype = v0.dtype
        assert isinstance(v0, np.ndarray)
        assert isinstance(v1, np.ndarray)
        v0 = v0.astype(float).flatten()
        v1 = v1.astype(float).flatten()

    r = func(tlist, v0, v1)

    if is_torch:
        r = [torch.Tensor(ri).view(*vshape).to(dtype).to(device) for ri in r]
    else:
        r = [ri.astype(dtype) for ri in r]
    return r

auto_lerp = partial(auto_dtype_device_shape, func=lerp)
auto_slerp = partial(auto_dtype_device_shape, func=slerp)
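
# auto_lerp/auto_slerp accept torch tensors or numpy arrays: inputs are flattened
# to float64 numpy, interpolated once per value in tlist, then tensors are
# reshaped back and moved to the original dtype/device (numpy results are cast
# back to the input dtype but stay flattened).
# Illustrative use: auto_slerp(np.linspace(0, 1, 24), xt0, xt1) would yield one
# interpolated latent per frame of a 24-frame video.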

def frames2mp4(vpath, frames, fps):
    import moviepy.editor as mpy
    frames = [np.array(framei) for framei in frames]
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(vpath, fps=fps)

def negseed_to_rndseed(seed):
    if seed < 0:
        seed = np.random.randint(0, np.iinfo(np.uint32).max-100)
    return seed

def regulate_image(pilim):
    w, h = pilim.size
    w = int(round(w/64)) * 64
    h = int(round(h/64)) * 64
    return pilim.resize([w, h], resample=PIL.Image.BILINEAR)

def txt_to_emb(model, prompt):
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",)
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_embeddings

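# Cache keys: an inversion result is addressed by md5(image bytes) plus md5 of
# the sorted-key JSON of its config, so rerunning the same image with the same
# settings is a cache hit, while any changed setting forces a fresh inversion.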
def hash_pilim(pilim):
    hasha = hashlib.md5(pilim.tobytes()).hexdigest()
    return hasha

def hash_cfgdict(cfgdict):
    hashb = hashlib.md5(json.dumps(cfgdict, sort_keys=True).encode('utf-8')).hexdigest()
    return hashb

def remove_earliest_file(path, max_allowance=500, remove_ratio=0.1, ext=None):
    if len(os.listdir(path)) <= max_allowance:
        return
    def get_mtime(fname):
        return osp.getmtime(osp.join(path, fname))
    if ext is None:
        flist = sorted(os.listdir(path), key=get_mtime)
    else:
        flist = [fi for fi in os.listdir(path) if fi.endswith(ext)]
        flist = sorted(flist, key=get_mtime)
    exceedn = max(len(flist)-max_allowance, 0)
    removen = int(max_allowance*remove_ratio)
    removen = max(1, removen) + exceedn
    for fi in flist[0:removen]:
        os.remove(osp.join(path, fi))

def remove_decoupled_file(path, exta='.mp4', extb='.json'):
    tag_a = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(exta)]
    tag_b = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(extb)]
    tag_a_extra = set(tag_a) - set(tag_b)
    tag_b_extra = set(tag_b) - set(tag_a)
    [os.remove(osp.join(path, tagi+exta)) for tagi in tag_a_extra]
    [os.remove(osp.join(path, tagi+extb)) for tagi in tag_b_extra]

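# Core sampling loop: starting from latent xt, step through the scheduler's
# timesteps with classifier-free guidance (nemb is the unconditional context,
# emb the conditional one). Both may be per-timestep lists, which is how the
# optimized null embeddings from NTI are fed back in.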
@torch.no_grad()
def t2i_core(model, xt, emb, nemb, step=30, cfg_scale=7.5, return_list=False):
    from nulltxtinv_wrapper import diffusion_step, latent2image
    model.scheduler.set_timesteps(step)
    xi = xt
    emb = txt_to_emb(model, "") if emb is None else emb
    nemb = txt_to_emb(model, "") if nemb is None else nemb
    if return_list:
        xi_list = [xi.clone()]
    for i, t in enumerate(tqdm(model.scheduler.timesteps)):
        embi = emb[i] if isinstance(emb, list) else emb
        nembi = nemb[i] if isinstance(nemb, list) else nemb
        context = torch.cat([nembi, embi])
        xi = diffusion_step(model, xi, context, t, cfg_scale, low_resource=False)
        if return_list:
            xi_list.append(xi.clone())
    x0 = xi
    im = latent2image(model.vae, x0, return_type='pil')

    if return_list:
        return im, xi_list
    else:
        return im

########
# main #
########

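# `wrapper` holds the demo state: the loaded pipeline (diffuser plus optional
# LoRA and scheduler), per-tab seeds, and the temp/{video,image,inverse} caches,
# each capped at 500 entries.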
class wrapper(object):
    def __init__(self,
                 fp16=False,
                 tag_diffuser=None,
                 tag_lora=None,
                 tag_scheduler=None,):

        self.device = "cuda"
        if fp16:
            self.torch_dtype = torch.float16
        else:
            self.torch_dtype = torch.float32
        self.load_all(tag_diffuser, tag_lora, tag_scheduler)

        self.image_latent_dim = 4
        self.batchsize = 8
        self.seed = {}

        self.cache_video_folder = "temp/video"
        self.cache_video_maxn = 500
        self.cache_image_folder = "temp/image"
        self.cache_image_maxn = 500
        self.cache_inverse_folder = "temp/inverse"
        self.cache_inverse_maxn = 500

    def load_all(self, tag_diffuser, tag_lora, tag_scheduler):
        self.load_diffuser_lora(tag_diffuser, tag_lora)
        self.load_scheduler(tag_scheduler)
        return tag_diffuser, tag_lora, tag_scheduler

    def load_diffuser_lora(self, tag_diffuser, tag_lora):
        self.net = StableDiffusionPipeline.from_pretrained(
            choices.diffuser[tag_diffuser], torch_dtype=self.torch_dtype).to(self.device)
        self.net.safety_checker = None
        if tag_lora != 'empty':
            self.net.unet.load_attn_procs(
                choices.lora[tag_lora], use_safetensors=False,)
        self.tag_diffuser = tag_diffuser
        self.tag_lora = tag_lora
        return tag_diffuser, tag_lora

    def load_scheduler(self, tag_scheduler):
        self.net.scheduler = choices.scheduler[tag_scheduler].from_config(self.net.scheduler.config)
        self.tag_scheduler = tag_scheduler
        return tag_scheduler

    def reset_seed(self, which='ltintp'):
        return -1

    def recycle_seed(self, which='ltintp'):
        if which not in self.seed:
            return self.reset_seed(which=which)
        else:
            return self.seed[which]

    ##########
    # helper #
    ##########

    def precheck_model(self, tag_diffuser, tag_lora, tag_scheduler):
        if (tag_diffuser != self.tag_diffuser) or (tag_lora != self.tag_lora):
            self.load_all(tag_diffuser, tag_lora, tag_scheduler)
        if tag_scheduler != self.tag_scheduler:
            self.load_scheduler(tag_scheduler)

    ########
    # main #
    ########

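    # Plain DDIM inversion: deterministically run the sampler backwards from
    # the input image to a noise latent xt, keeping the prompt embedding fixed.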
    def ddiminv(self, img, cfgdict):
        txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
        from nulltxtinv_wrapper import NullInversion
        null_inversion_model = NullInversion(self.net, step, cfg_scale)
        with torch.no_grad():
            emb = txt_to_emb(self.net, txt)
            nemb = txt_to_emb(self.net, "")
        xt = null_inversion_model.ddim_invert(img, txt)
        data = {
            'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
            'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
            'xt': xt, 'emb': emb, 'nemb': nemb,}
        return data

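    # Null-text inversion (NTI) with on-disk caching: the per-timestep null
    # embedding optimization is expensive, so results are saved to the inverse
    # cache as CPU tensors and moved back to the UNet's device/dtype on a hit;
    # force_reinvert bypasses the cache.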
    def nullinv_or_loadcache(self, img, cfgdict, force_reinvert=False):
        hash = hash_pilim(img) + "--" + hash_cfgdict(cfgdict)
        cdir = self.cache_inverse_folder
        cfname = osp.join(cdir, hash+'.pth')

        if osp.isfile(cfname) and (not force_reinvert):
            cache_data = torch.load(cfname)
            dtype = next(self.net.unet.parameters()).dtype
            device = next(self.net.unet.parameters()).device
            cache_data['xt'] = cache_data['xt'].to(device=device, dtype=dtype)
            cache_data['emb'] = cache_data['emb'].to(device=device, dtype=dtype)
            cache_data['nemb'] = [
                nembi.to(device=device, dtype=dtype)
                    for nembi in cache_data['nemb']]
            return cache_data
        else:
            txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
            inner_step = cfgdict['inner_step']
            from nulltxtinv_wrapper import NullInversion
            null_inversion_model = NullInversion(self.net, step, cfg_scale)
            with torch.no_grad():
                emb = txt_to_emb(self.net, txt)
            xt, nemb = null_inversion_model.null_invert(img, txt, num_inner_steps=inner_step)
            cache_data = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt.to('cpu'),
                'emb' : emb.to('cpu'),
                'nemb' : [nembi.to('cpu') for nembi in nemb],}
            os.makedirs(cdir, exist_ok=True)
            remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
            torch.save(cache_data, cfname)
            data = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt, 'emb' : emb, 'nemb' : nemb,}
            return data

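    # Dual variant used for real-image interpolation: both endpoints are
    # inverted jointly (null_invert_dual) so they share one per-timestep
    # null-embedding trajectory; the cached record is split back into one
    # dict per image.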
| 375 | 
         
            +
                def nullinvdual_or_loadcachedual(self, img0, img1, cfgdict, force_reinvert=False):
         
     | 
| 376 | 
         
            +
        hash = hash_pilim(img0) + "--" + hash_pilim(img1) + "--" + hash_cfgdict(cfgdict)
        cdir = self.cache_inverse_folder
        cfname = osp.join(cdir, hash+'.pth')

        if osp.isfile(cfname) and (not force_reinvert):
            cache_data = torch.load(cfname)
            dtype = next(self.net.unet.parameters()).dtype
            device = next(self.net.unet.parameters()).device
            cache_data['xt0'] = cache_data['xt0'].to(device=device, dtype=dtype)
            cache_data['xt1'] = cache_data['xt1'].to(device=device, dtype=dtype)
            cache_data['emb0'] = cache_data['emb0'].to(device=device, dtype=dtype)
            cache_data['emb1'] = cache_data['emb1'].to(device=device, dtype=dtype)
            cache_data['nemb'] = [
                nembi.to(device=device, dtype=dtype)
                for nembi in cache_data['nemb']]

            cache_data_a = copy.deepcopy(cache_data)
            cache_data_a['xt'] = cache_data_a.pop('xt0')
            cache_data_a['emb'] = cache_data_a.pop('emb0')
            cache_data_a.pop('xt1'); cache_data_a.pop('emb1')

            cache_data_b = cache_data
            cache_data_b['xt'] = cache_data_b.pop('xt1')
            cache_data_b['emb'] = cache_data_b.pop('emb1')
            cache_data_b.pop('xt0'); cache_data_b.pop('emb0')

            return cache_data_a, cache_data_b
        else:
            txt0, txt1, step, cfg_scale, inner_step = \
                cfgdict['txt0'], cfgdict['txt1'], cfgdict['step'], \
                cfgdict['cfg_scale'], cfgdict['inner_step']

            from nulltxtinv_wrapper import NullInversion
            null_inversion_model = NullInversion(self.net, step, cfg_scale)
            with torch.no_grad():
                emb0 = txt_to_emb(self.net, txt0)
                emb1 = txt_to_emb(self.net, txt1)

            xt0, xt1, nemb = null_inversion_model.null_invert_dual(
                img0, img1, txt0, txt1, num_inner_steps=inner_step)
            cache_data = {
                'step' : step, 'cfg_scale' : cfg_scale,
                'txt0' : txt0, 'txt1' : txt1,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt0' : xt0.to('cpu'), 'xt1' : xt1.to('cpu'),
                'emb0' : emb0.to('cpu'), 'emb1' : emb1.to('cpu'),
                'nemb' : [nembi.to('cpu') for nembi in nemb],}
            os.makedirs(cdir, exist_ok=True)
            remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
            torch.save(cache_data, cfname)
            data0 = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt0,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt0, 'emb' : emb0, 'nemb' : nemb,}
            data1 = {
                'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt1,
                'inner_step' : inner_step,
                'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
                'xt' : xt1, 'emb' : emb1, 'nemb' : nemb,}
            return data0, data1

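    # image_inversion: reconstructs an input image from its inverted latent.
    # With 'NTI' the optimized null-text embeddings are replayed during
    # sampling; otherwise a plain DDIM reconstruction is used (with an empty
    # prompt for 'DDIM w/o text').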
    def image_inversion(
            self, img, txt,
            cfg_scale, step,
            inversion, inner_step, force_reinvert):
        from nulltxtinv_wrapper import text2image_ldm
        if inversion == 'DDIM w/o text':
            txt = ''
        if inversion != 'NTI':
            data = self.ddiminv(img, {'txt':txt, 'step':step, 'cfg_scale':cfg_scale,})
        else:
            data = self.nullinv_or_loadcache(
                img, {'txt':txt, 'step':step,
                      'cfg_scale':cfg_scale, 'inner_step':inner_step,
                      'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)

        if inversion == 'NTI':
            img_inv, _ = text2image_ldm(
                self.net, [txt], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=data['nemb'])
        else:
            img_inv, _ = text2image_ldm(
                self.net, [txt], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=None)

        return img_inv

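    # image_editing: prompt-driven editing on top of the inverted latent.
    # The source prompt txt_0 is inverted, then sampling is re-run toward the
    # target prompt txt_1; thresh is forwarded to text2image_ldm_imedit and,
    # judging by its name and the UI slider, appears to control how much of
    # the trajectory retains the source conditioning.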
    def image_editing(
            self, img, txt_0, txt_1,
            cfg_scale, step, thresh,
            inversion, inner_step, force_reinvert):
        from nulltxtinv_wrapper import text2image_ldm_imedit
        if inversion == 'DDIM w/o text':
            txt_0 = ''
        if inversion != 'NTI':
            data = self.ddiminv(img, {'txt':txt_0, 'step':step, 'cfg_scale':cfg_scale,})
            img_edited, _ = text2image_ldm_imedit(
                self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=None)
        else:
            data = self.nullinv_or_loadcache(
                img, {'txt':txt_0, 'step':step,
                      'cfg_scale':cfg_scale, 'inner_step':inner_step,
                      'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
            img_edited, _ = text2image_ldm_imedit(
                self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
                latent=data['xt'], uncond_embeddings=data['nemb'])

        return img_edited

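    # general_interpolation: interpolates between two inverted images. Noise
    # latents are combined with spherical interpolation (slerp), which keeps
    # intermediate latents on the same noise "shell" as the endpoints, while
    # text and null-text embeddings use plain linear interpolation:
    #     slerp(t, x0, x1) = sin((1-t)*w)/sin(w) * x0 + sin(t*w)/sin(w) * x1,
    # with w the angle between x0 and x1. Frames are decoded in batches of
    # self.batchsize.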
    def general_interpolation(
            self, xset0, xset1,
            cfg_scale, step, tlist,):

        xt0, emb0, nemb0 = xset0['xt'], xset0['emb'], xset0['nemb']
        xt1, emb1, nemb1 = xset1['xt'], xset1['emb'], xset1['nemb']
        framen = len(tlist)

        xt_list = auto_slerp(tlist, xt0, xt1)
        emb_list = auto_lerp(tlist, emb0, emb1)

        if isinstance(nemb0, list) and isinstance(nemb1, list):
            assert len(nemb0) == len(nemb1)
            nemb_list = [auto_lerp(tlist, e0, e1) for e0, e1 in zip(nemb0, nemb1)]
            nemb_islist = True
        else:
            nemb_list = auto_lerp(tlist, nemb0, nemb1)
            nemb_islist = False

        im_list = []
        for frameidx in range(0, len(xt_list), self.batchsize):
            batch_end = min(frameidx+self.batchsize, framen)
            xt_batch = torch.cat(
                [xt_list[idx] for idx in range(frameidx, batch_end)], dim=0)
            emb_batch = torch.cat(
                [emb_list[idx] for idx in range(frameidx, batch_end)], dim=0)
            if nemb_islist:
                nemb_batch = []
                for nembi in nemb_list:
                    nembi_batch = torch.cat(
                        [nembi[idx] for idx in range(frameidx, batch_end)], dim=0)
                    nemb_batch.append(nembi_batch)
            else:
                nemb_batch = torch.cat(
                    [nemb_list[idx] for idx in range(frameidx, batch_end)], dim=0)

            im = t2i_core(
                self.net, xt_batch, emb_batch, nemb_batch, step, cfg_scale)
            im_list += im if isinstance(im, list) else [im]

        return im_list

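    # run_iminvs: Gradio callback for the "Image Inversion" tab. Resizes or
    # regulates the input, inverts and reconstructs it, and caches the result
    # PNG under a timestamped name (oldest cached files are evicted first).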
    def run_iminvs(
            self, img, text,
            cfg_scale, step,
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler, ):

        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)

        if force_resize:
            img = offset_resize(img, width, height)
        else:
            img = regulate_image(img)

        recon_output = self.image_inversion(
            img, text, cfg_scale, step,
            inversion, inner_step, force_reinvert)

        idir = self.cache_image_folder
        os.makedirs(idir, exist_ok=True)
        remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
        sname = "time{}_iminvs_{}_{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora,)
        ipath = osp.join(idir, sname+'.png')
        recon_output.save(ipath)

        return [recon_output]

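    # run_imedit: Gradio callback for the "Image Editing" tab; same
    # resize/invert flow as run_iminvs, but re-samples toward the target
    # prompt and saves the edited PNG.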
    def run_imedit(
            self, img, txt_0, txt_1,
            threshold, cfg_scale, step,
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler, ):

        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if force_resize:
            img = offset_resize(img, width, height)
        else:
            img = regulate_image(img)

        edited_img = self.image_editing(
            img, txt_0, txt_1, cfg_scale, step, threshold,
            inversion, inner_step, force_reinvert)

        idir = self.cache_image_folder
        os.makedirs(idir, exist_ok=True)
        remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
        sname = "time{}_imedit_{}_{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora,)
        ipath = osp.join(idir, sname+'.png')
        edited_img.save(ipath)

        return [edited_img]

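    # run_imintp: Gradio callback for the "Image Interpolation" tab. Inverts
    # both endpoint images (sharing one cache entry in the NTI case),
    # interpolates framen latents/embeddings, decodes the frames, and writes
    # an mp4 plus a json sidecar recording the run configuration.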
    def run_imintp(
            self,
            img0, img1, txt0, txt1,
            cfg_scale, step,
            framen, fps,
            force_resize, width, height,
            inversion, inner_step, force_reinvert,
            tag_diffuser, tag_lora, tag_scheduler,):

        self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
        if txt1 == '':
            txt1 = txt0
        if force_resize:
            img0 = offset_resize(img0, width, height)
            img1 = offset_resize(img1, width, height)
        else:
            img0 = regulate_image(img0)
            img1 = regulate_image(img1)

        if inversion == 'DDIM':
            data0 = self.ddiminv(img0, {'txt':txt0, 'step':step, 'cfg_scale':cfg_scale,})
            data1 = self.ddiminv(img1, {'txt':txt1, 'step':step, 'cfg_scale':cfg_scale,})
        elif inversion == 'DDIM w/o text':
            data0 = self.ddiminv(img0, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
            data1 = self.ddiminv(img1, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
        else:
            data0, data1 = self.nullinvdual_or_loadcachedual(
                img0, img1, {'txt0':txt0, 'txt1':txt1, 'step':step,
                             'cfg_scale':cfg_scale, 'inner_step':inner_step,
                             'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)

        tlist = np.linspace(0.0, 1.0, framen)

        iminv0 = t2i_core(self.net, data0['xt'], data0['emb'], data0['nemb'], step, cfg_scale)
        iminv1 = t2i_core(self.net, data1['xt'], data1['emb'], data1['nemb'], step, cfg_scale)
        frames = self.general_interpolation(data0, data1, cfg_scale, step, tlist)

        vdir = self.cache_video_folder
        os.makedirs(vdir, exist_ok=True)
        remove_earliest_file(vdir, max_allowance=self.cache_video_maxn)
        sname = "time{}_imintp_{}_{}_framen{}_fps{}".format(
            int(time.time()), self.tag_diffuser, self.tag_lora, framen, fps)
        vpath = osp.join(vdir, sname+'.mp4')
        frames2mp4(vpath, frames, fps)
        jpath = osp.join(vdir, sname+'.json')
        cfgdict = {
            "method" : "image_interpolation",
            "txt0" : txt0, "txt1" : txt1,
            "cfg_scale" : cfg_scale, "step" : step,
            "framen" : framen, "fps" : fps,
            "force_resize" : force_resize, "width" : width, "height" : height,
            "inversion" : inversion, "inner_step" : inner_step,
            "force_reinvert" : force_reinvert,
            "tag_diffuser" : tag_diffuser, "tag_lora" : tag_lora, "tag_scheduler" : tag_scheduler,}
        with open(jpath, 'w') as f:
            json.dump(cfgdict, f, indent=4)

        return frames, vpath, [iminv0, iminv1]

#################
# get examples #
#################

cache_examples = False

def get_imintp_example():
    case = [
        [
            'assets/images/interpolation/cityview1.png',
            'assets/images/interpolation/cityview2.png',
            'A city view',],
        [
            'assets/images/interpolation/woman1.png',
            'assets/images/interpolation/woman2.png',
            'A woman face',],
        [
            'assets/images/interpolation/land1.png',
            'assets/images/interpolation/land2.png',
            'A beautiful landscape',],
        [
            'assets/images/interpolation/dog1.png',
            'assets/images/interpolation/dog2.png',
            'A realistic dog',],
        [
            'assets/images/interpolation/church1.png',
            'assets/images/interpolation/church2.png',
            'A church',],
        [
            'assets/images/interpolation/rabbit1.png',
            'assets/images/interpolation/rabbit2.png',
            'A cute rabbit',],
        [
            'assets/images/interpolation/horse1.png',
            'assets/images/interpolation/horse2.png',
            'A robot horse',],
    ]
    return case

def get_iminvs_example():
    case = [
        [
            'assets/images/inversion/000000560011.jpg',
            'A mouse is next to a keyboard on a desk',],
        [
            'assets/images/inversion/000000029596.jpg',
            'A room with a couch, table set with dinnerware and a television.',],
    ]
    return case


def get_imedit_example():
    case = [
        [
            'assets/images/editing/rabbit.png',
            'A rabbit is eating a watermelon on the table',
            'A cat is eating a watermelon on the table',
            0.7,],
        [
            'assets/images/editing/cake.png',
            'A chocolate cake with cream on it',
            'A chocolate cake with strawberries on it',
            0.9,],
        [
            'assets/images/editing/banana.png',
            'A banana on the table',
            'A banana and an apple on the table',
            0.8,],
    ]
    return case


#################
# sub interface #
#################

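# Each interface_* builder lays out one Gradio tab and wires the shared
# model-selection dropdowns: changing Diffuser or LoRA reloads everything via
# wrapper_obj.load_all, while changing Scheduler only swaps the scheduler via
# wrapper_obj.load_scheduler.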
def interface_imintp(wrapper_obj):
    with gr.Row():
        with gr.Column():
            img0 = gr.Image(label="Image Input 0", type='pil', elem_id='customized_imbox')
        with gr.Column():
            img1 = gr.Image(label="Image Input 1", type='pil', elem_id='customized_imbox')
        with gr.Column():
            video_output = gr.Video(label="Video Result", format='mp4', elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            txt0 = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                framen = gr.Slider(label="Frame Number", minimum=8, maximum=default.framen, value=default.framen, step=1)
                fps = gr.Slider(label="Video FPS", minimum=4, maximum=default.fps, value=default.fps, step=4)
            with gr.Row():
                button_run = gr.Button("Run")

        with gr.Column():
            with gr.Accordion('Frame Results', open=False):
                frame_output = gr.Gallery(label="Frames", elem_id='customized_imbox')
            with gr.Accordion("Inversion Results", open=False):
                inv_output = gr.Gallery(label="Inversion Results", elem_id='customized_imbox')
            with gr.Accordion('Advanced Settings', open=False):
                with gr.Row():
                    tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
                    tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
                    tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
                with gr.Row():
                    cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
                    step = gr.Number(default.step, label="Step", precision=0)
                with gr.Row():
                    force_resize = gr.Checkbox(label="Force Resize", value=True)
                    inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
                with gr.Row():
                    txt1 = gr.Textbox(label='Optional Different Text Input for Image Input 1', lines=1, placeholder="Input prompt...", )

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_lora.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs = [tag_scheduler],
        outputs = [tag_scheduler],)

    button_run.click(
        wrapper_obj.run_imintp,
        inputs=[img0, img1, txt0, txt1,
                cfg_scale, step,
                framen, fps,
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[frame_output, video_output, inv_output])

    gr.Examples(
        label='Examples',
        examples=get_imintp_example(),
        fn=wrapper_obj.run_imintp,
        inputs=[img0, img1, txt0,],
        outputs=[frame_output, video_output, inv_output],
        cache_examples=cache_examples,)

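# Note: with cache_examples=False, clicking an example in gr.Examples only
# prefills the listed input components; fn is not executed on click, so
# passing a subset of the callback's inputs here is safe (it would not be if
# example caching were enabled).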
def interface_iminvs(wrapper_obj):
    with gr.Row():
        image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
        recon_output = gr.Gallery(label="Reconstruction output", elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
            with gr.Row():
                button_run = gr.Button("Run")

        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
            with gr.Accordion('Advanced Settings', open=False):
                with gr.Row():
                    tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
                    tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
                    tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
                with gr.Row():
                    cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
                    step = gr.Number(default.step, label="Step", precision=0)
                with gr.Row():
                    force_resize = gr.Checkbox(label="Force Resize", value=True)
                    inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_lora.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs = [tag_scheduler],
        outputs = [tag_scheduler],)

    button_run.click(
        wrapper_obj.run_iminvs,
        inputs=[image_input, prompt,
                cfg_scale, step,
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[recon_output])

    gr.Examples(
        label='Examples',
        examples=get_iminvs_example(),
        fn=wrapper_obj.run_iminvs,
        inputs=[image_input, prompt,],
        outputs=[recon_output],
        cache_examples=cache_examples,)


def interface_imedit(wrapper_obj):
    with gr.Row():
        image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
        edited_output = gr.Gallery(label="Edited output", elem_id='customized_imbox')
    with gr.Row():
        with gr.Column():
            prompt_0 = gr.Textbox(label='Source Text', lines=1, placeholder="Source prompt...", )
            prompt_1 = gr.Textbox(label='Target Text', lines=1, placeholder="Target prompt...", )
            with gr.Row():
                button_run = gr.Button("Run")

        with gr.Column():
            with gr.Row():
                inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
                inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
                force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
                threshold = gr.Slider(label="Threshold", minimum=0, maximum=1, value=default.threshold, step=0.1)
            with gr.Accordion('Advanced Settings', open=False):
                with gr.Row():
                    tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
                    tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
                    tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
                with gr.Row():
                    cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
                    step = gr.Number(default.step, label="Step", precision=0)
                with gr.Row():
                    force_resize = gr.Checkbox(label="Force Resize", value=True)
                    inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
                    inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)

    tag_diffuser.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_lora.change(
        wrapper_obj.load_all,
        inputs = [tag_diffuser, tag_lora, tag_scheduler],
        outputs = [tag_diffuser, tag_lora, tag_scheduler],)

    tag_scheduler.change(
        wrapper_obj.load_scheduler,
        inputs = [tag_scheduler],
        outputs = [tag_scheduler],)

    button_run.click(
        wrapper_obj.run_imedit,
        inputs=[image_input, prompt_0, prompt_1,
                threshold, cfg_scale, step,
                force_resize, inp_width, inp_height,
                inversion, inner_step, force_reinvert,
                tag_diffuser, tag_lora, tag_scheduler,],
        outputs=[edited_output])

    gr.Examples(
        label='Examples',
        examples=get_imedit_example(),
        fn=wrapper_obj.run_imedit,
        inputs=[image_input, prompt_0, prompt_1, threshold,],
        outputs=[edited_output],
        cache_examples=cache_examples,)


#############
# Interface #
#############

            if __name__ == '__main__':
         
     | 
| 933 | 
         
            +
                parser = argparse.ArgumentParser()
         
     | 
| 934 | 
         
            +
                parser.add_argument('-p', '--port', type=int, default=None)
         
     | 
| 935 | 
         
            +
                args = parser.parse_args()
         
     | 
| 936 | 
         
            +
                from app_utils import css_empty, css_version_4_11_0
         
     | 
| 937 | 
         
            +
                # css = css_empty
         
     | 
| 938 | 
         
            +
                css = css_version_4_11_0
         
     | 
| 939 | 
         
            +
             
     | 
| 940 | 
         
            +
                wrapper_obj = wrapper(
         
     | 
| 941 | 
         
            +
                    fp16=False, 
         
     | 
| 942 | 
         
            +
                    tag_diffuser=default.diffuser,
         
     | 
| 943 | 
         
            +
                    tag_lora=default.lora,
         
     | 
| 944 | 
         
            +
                    tag_scheduler=default.scheduler)
         
     | 
| 945 | 
         
            +
             
     | 
| 946 | 
         
            +
                if True:
         
     | 
| 947 | 
         
            +
                    with gr.Blocks(css=css) as demo:
         
     | 
| 948 | 
         
            +
                        gr.HTML(
         
     | 
| 949 | 
         
            +
                            """
         
     | 
| 950 | 
         
            +
                            <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
         
     | 
| 951 | 
         
            +
                            <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
         
     | 
| 952 | 
         
            +
                                {}
         
     | 
| 953 | 
         
            +
                            </h1>
         
     | 
| 954 | 
         
            +
                            </div>
         
     | 
| 955 | 
         
            +
                            """.format(version))
         
     | 
| 956 | 
         
            +
             
     | 
| 957 | 
         
            +
                        with gr.Tab('Image Interpolation'):
         
     | 
| 958 | 
         
            +
                            interface_imintp(wrapper_obj)
         
     | 
| 959 | 
         
            +
                        with gr.Tab('Image Inversion'):
         
     | 
| 960 | 
         
            +
                            interface_iminvs(wrapper_obj)
         
     | 
| 961 | 
         
            +
                        with gr.Tab('Image Editing'):
         
     | 
| 962 | 
         
            +
                            interface_imedit(wrapper_obj)
         
     | 
| 963 | 
         
            +
             
     | 
| 964 | 
         
            +
                    demo.launch()
         
     | 
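Since demo.launch honors the parsed port, a typical local run (the port value is illustrative, assuming the requirements.txt dependencies are installed) is:

    python app.py -p 7860

Omitting -p leaves server_port as None, so Gradio picks its default port.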
    	
app_utils.py
ADDED
@@ -0,0 +1,102 @@
import os
import os.path as osp
import cv2
import numpy as np
import numpy.random as npr
import torch
import torch.nn.functional as F
import torchvision.transforms as tvtrans
import PIL.Image
from tqdm import tqdm
from PIL import Image
import copy
import json
from collections import OrderedDict

#######
# css #
#######

css_empty = ""
css_version_4_11_0 = """
    #customized_imbox {
        min-height: 450px;
        max-height: 450px;
    }
    #customized_imbox>div[data-testid="image"] {
        min-height: 450px;
    }
    #customized_imbox>div[data-testid="image"]>span[data-testid="source-select"] {
        max-height: 0px;
    }
    #customized_imbox>div[data-testid="image"]>span[data-testid="source-select"]>button {
        max-height: 0px;
    }
    #customized_imbox>div[data-testid="image"]>div.upload-container>div.image-frame>img {
        position: absolute;
        top: 50%;
        left: 50%;
        transform: translateX(-50%) translateY(-50%);
        width: unset;
        height: unset;
        max-height: 450px;
    }
    #customized_imbox>div.unpadded_box {
        min-height: 450px;
    }
    #myinst {
        font-size: 0.8rem;
        margin: 0rem;
        color: #6B7280;
    }
    #maskinst {
        text-align: justify;
        min-width: 1200px;
    }
    #maskinst>img {
        min-width: 399px;
        max-width: 450px;
        vertical-align: top;
        display: inline-block;
    }
    #maskinst:after {
        content: "";
        width: 100%;
        display: inline-block;
    }
"""
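These selectors only take effect on components created with the matching elem_id, and (as the variable name suggests) they target the Gradio 4.11.0 DOM; the rules pin the image boxes to a 450px frame and hide the source-select buttons. A minimal sketch of how the stylesheet is attached (the component name here is illustrative, not from the commit):

    import gradio as gr
    from app_utils import css_version_4_11_0

    with gr.Blocks(css=css_version_4_11_0) as demo:
        # pinned to a 450px frame by the #customized_imbox rules above
        image_box = gr.Image(elem_id='customized_imbox')
    demo.launch()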
##########
# helper #
##########

def highlight_print(info):
    print('')
    print(''.join(['#']*(len(info)+4)))
    print('# '+info+' #')
    print(''.join(['#']*(len(info)+4)))
    print('')

def auto_dropdown(name, choices_od, value):
    import gradio as gr
    option_list = [pi for pi in choices_od.keys()]
    return gr.Dropdown(label=name, choices=option_list, value=value)
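A quick usage sketch for auto_dropdown; the OrderedDict contents below are made up for illustration (in the app the choices map display names to model tags):

    from collections import OrderedDict
    choices = OrderedDict([('SD-v1.5', 'sd15'), ('SD-v2.1', 'sd21')])
    dd = auto_dropdown('Diffuser', choices, value='SD-v1.5')  # returns a gr.Dropdown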
def load_sd_from_file(target):
    if osp.splitext(target)[-1] == '.ckpt':
        sd = torch.load(target, map_location='cpu')['state_dict']
    elif osp.splitext(target)[-1] == '.pth':
        sd = torch.load(target, map_location='cpu')
    elif osp.splitext(target)[-1] == '.safetensors':
        from safetensors.torch import load_file as stload
        sd = OrderedDict(stload(target, device='cpu'))
    else:
        assert False, "File type must be .ckpt or .pth or .safetensors"
    return sd
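Usage is uniform across the three checkpoint formats; the path below is hypothetical:

    sd = load_sd_from_file('checkpoints/model.safetensors')
    print(len(sd), 'tensors loaded')  # a state dict keyed by parameter name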
         
            +
            def torch_to_numpy(x):
         
     | 
| 99 | 
         
            +
                return x.detach().to('cpu').numpy()
         
     | 
| 100 | 
         
            +
             
     | 
| 101 | 
         
            +
            if __name__ == '__main__':
         
     | 
| 102 | 
         
            +
                pass
         
     | 
    	
assets/.DS_Store
ADDED
Binary file (6.15 kB)

assets/images/.DS_Store
ADDED
Binary file (6.15 kB)

assets/images/editing/banana.png
ADDED

assets/images/editing/cake.png
ADDED

assets/images/editing/rabbit.png
ADDED

assets/images/interpolation/church1.png
ADDED

assets/images/interpolation/church2.png
ADDED

assets/images/interpolation/dog1.png
ADDED

assets/images/interpolation/dog2.png
ADDED

assets/images/interpolation/horse1.png
ADDED

assets/images/interpolation/horse2.png
ADDED

assets/images/interpolation/land1.png
ADDED

assets/images/interpolation/land2.png
ADDED

assets/images/interpolation/rabbit1.png
ADDED

assets/images/interpolation/rabbit2.png
ADDED

assets/images/interpolation/woman1.png
ADDED

assets/images/interpolation/woman2.png
ADDED

assets/images/inversion/000000029596.jpg
ADDED

assets/images/inversion/000000560011.jpg
ADDED
    	
nulltxtinv_wrapper.py
ADDED
@@ -0,0 +1,450 @@
import numpy as np
import torch
import PIL.Image
from tqdm import tqdm
from typing import Optional, Union, List
import warnings
warnings.filterwarnings('ignore')

from torch.optim.adam import Adam
import torch.nn.functional as nnf

from diffusers import DDIMScheduler

##########
# helper #
##########
def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
    if low_resource:
        # run the unconditional and text-conditioned passes separately to save memory
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    # classifier-free guidance: push the prediction away from the unconditional branch
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    return latents
def image2latent(vae, image):
    with torch.no_grad():
        if isinstance(image, PIL.Image.Image):
            image = np.array(image)
        if isinstance(image, np.ndarray):
            dtype = next(vae.parameters()).dtype
            device = next(vae.parameters()).device
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=dtype)
        latents = vae.encode(image)['latent_dist'].mean
        latents = latents * 0.18215  # Stable Diffusion latent scaling factor
    return latents
def latent2image(vae, latents, return_type='np'):
    assert isinstance(latents, torch.Tensor)
    latents = 1 / 0.18215 * latents.detach()
    image = vae.decode(latents)['sample']
    if return_type in ['np', 'pil']:
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        image = (image * 255).astype(np.uint8)
        if return_type == 'pil':
            pilim = [PIL.Image.fromarray(imi) for imi in image]
            pilim = pilim[0] if len(pilim)==1 else pilim
            return pilim
        else:
            return image
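The two converters above are inverses up to VAE reconstruction error. A round-trip sketch, assuming a diffusers StableDiffusionPipeline (the model id and file name are illustrative):

    from diffusers import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5').to('cuda')
    im = PIL.Image.open('input.png').convert('RGB').resize((512, 512))
    z = image2latent(pipe.vae, im)                       # (1, 4, 64, 64) scaled latent
    rec = latent2image(pipe.vae, z, return_type='pil')   # close to im, minus VAE loss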
def init_latent(latent, model, height, width, generator, batch_size):
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents
def txt_to_emb(model, prompt):
    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",)
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    return text_embeddings
@torch.no_grad()
def text2image_ldm(
        model,
        prompt: List[str],
        num_inference_steps: int = 50,
        guidance_scale: Optional[float] = 7.5,
        generator: Optional[torch.Generator] = None,
        latent: Optional[torch.FloatTensor] = None,
        uncond_embeddings=None,
        start_time=50,
        return_type='pil'):

    batch_size = len(prompt)
    height = width = 512
    if latent is not None:
        height = latent.shape[-2] * 8
        width = latent.shape[-1] * 8

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",)
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt",)
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None

    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        if uncond_embeddings_ is None:
            # per-step null-text embeddings: one optimized entry per timestep
            context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
        else:
            context = torch.cat([uncond_embeddings_, text_embeddings])
        latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)

    if return_type in ['pil', 'np']:
        image = latent2image(model.vae, latents, return_type=return_type)
    else:
        image = latents
    return image, latent
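For plain sampling, text2image_ldm only needs a model and a prompt list; reusing the hypothetical pipe from the earlier sketch:

    pil_im, z_init = text2image_ldm(pipe, ['a photo of a dog'], num_inference_steps=50)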
@torch.no_grad()
def text2image_ldm_imedit(
    model,
    thresh,
    prompt: List[str],
    target_prompt: List[str],
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    uncond_embeddings=None,
    start_time=50,
    return_type='pil'
):
    batch_size = len(prompt)
    height = width = 512

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    target_text_input = model.tokenizer(
        target_prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    target_text_embeddings = model.text_encoder(target_text_input.input_ids.to(model.device))[0]

    max_length = text_input.input_ids.shape[-1]
    if uncond_embeddings is None:
        uncond_input = model.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
    else:
        uncond_embeddings_ = None

    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    model.scheduler.set_timesteps(num_inference_steps)
    for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
        # the first (1 - thresh) fraction of steps denoises under the source prompt,
        # the remaining steps under the target prompt
        if i < (1 - thresh) * num_inference_steps:
            if uncond_embeddings_ is None:
                context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
            else:
                context = torch.cat([uncond_embeddings_, text_embeddings])
            latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)
        else:
            if uncond_embeddings_ is None:
                context = torch.cat([uncond_embeddings[i].expand(*target_text_embeddings.shape), target_text_embeddings])
            else:
                context = torch.cat([uncond_embeddings_, target_text_embeddings])
            latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)

    if return_type in ['pil', 'np']:
        image = latent2image(model.vae, latents, return_type=return_type)
    else:
        image = latents
    return image, latent
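Concretely, with num_inference_steps=50 and thresh=0.3, steps 0 through 34 (i < 0.7 × 50) follow the source prompt and the final 15 steps follow target_prompt, so a larger thresh hands more of the trajectory to the edit.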
         
            +
            ###########
         
     | 
| 192 | 
         
            +
            # wrapper #
         
     | 
| 193 | 
         
            +
            ###########
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
            class NullInversion(object):
         
     | 
| 196 | 
         
            +
                def __init__(self, model, num_ddim_steps, guidance_scale, device='cuda'):
         
     | 
| 197 | 
         
            +
                    self.model = model
         
     | 
| 198 | 
         
            +
                    self.device = device
         
     | 
| 199 | 
         
            +
                    self.num_ddim_steps=num_ddim_steps
         
     | 
| 200 | 
         
            +
                    self.guidance_scale = guidance_scale
         
     | 
| 201 | 
         
            +
                    self.tokenizer = self.model.tokenizer
         
     | 
| 202 | 
         
            +
                    self.prompt = None
         
     | 
| 203 | 
         
            +
                    self.context = None
         
     | 
| 204 | 
         
            +
             
     | 
    def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
        # one deterministic DDIM denoising step: x_t -> x_{t-1}
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
        pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
        prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
        return prev_sample

    def next_step(self, noise_pred, timestep, sample):
        # the inverse DDIM step: x_t -> x_{t+1}, used during inversion
        timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
        beta_prod_t = 1 - alpha_prod_t
        next_original_sample = (sample - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
        next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * noise_pred
        next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
        return next_sample
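In equation form, writing \alpha_t for scheduler.alphas_cumprod[t] and \epsilon for the predicted noise, prev_step computes the deterministic DDIM update

    \hat{x}_0 = \frac{x_t - \sqrt{1 - \alpha_t}\,\epsilon}{\sqrt{\alpha_t}}, \qquad
    x_{t-1} = \sqrt{\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \alpha_{t-1}}\,\epsilon

and next_step applies the same update in the opposite direction (x_t -> x_{t+1}), which is what makes deterministic DDIM inversion possible.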
    def get_noise_pred_single(self, latents, t, context):
        noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
        return noise_pred

    def get_noise_pred(self, latents, t, is_forward=True, context=None):
        latents_input = torch.cat([latents] * 2)
        if context is None:
            context = self.context
        guidance_scale = 1 if is_forward else self.guidance_scale
        noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
        if is_forward:
            latents = self.next_step(noise_pred, t, latents)
        else:
            latents = self.prev_step(noise_pred, t, latents)
        return latents
    @torch.no_grad()
    def init_prompt(self, prompt: str):
        uncond_input = self.model.tokenizer(
            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
            return_tensors="pt"
        )
        uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
        text_input = self.model.tokenizer(
            [prompt],
            padding="max_length",
            max_length=self.model.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
        self.context = torch.cat([uncond_embeddings, text_embeddings])
        self.prompt = prompt
    @torch.no_grad()
    def ddim_loop(self, latent, emb):
        # uncond_embeddings, cond_embeddings = self.context.chunk(2)
        all_latent = [latent]
        latent = latent.clone().detach()
        for i in range(self.num_ddim_steps):
            t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
            noise_pred = self.get_noise_pred_single(latent, t, emb)
            latent = self.next_step(noise_pred, t, latent)
            all_latent.append(latent)
        return all_latent
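Note that ddim_loop returns num_ddim_steps + 1 latents ordered from the clean image latent up to the fully noised one; null_optimization below indexes this list from the end (latents[-1], then latents[len(latents) - i - 2]) to walk the trajectory back toward the image.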
    @property
    def scheduler(self):
        return self.model.scheduler
    @torch.no_grad()
    def ddim_invert(self, image, prompt):
        assert isinstance(image, PIL.Image.Image)

        scheduler_save = self.model.scheduler
        self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
        self.model.scheduler.set_timesteps(self.num_ddim_steps)

        with torch.no_grad():
            emb = txt_to_emb(self.model, prompt)
            latent = image2latent(self.model.vae, image)
        ddim_latents = self.ddim_loop(latent, emb)

        self.model.scheduler = scheduler_save
        return ddim_latents[-1]
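A usage sketch for the inversion entry point, again assuming pipe is a diffusers StableDiffusionPipeline on GPU (the prompt is illustrative; the image path is one of the files added in this commit):

    inverter = NullInversion(pipe, num_ddim_steps=50, guidance_scale=7.5)
    im = PIL.Image.open('assets/images/inversion/000000029596.jpg').convert('RGB').resize((512, 512))
    z_T = inverter.ddim_invert(im, 'a photo')  # noise latent whose DDIM rollout approximates im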
    def null_optimization(self, latents, emb, nemb=None, num_inner_steps=10, epsilon=1e-5):
        # force fp32
        dtype = latents[0].dtype
        uncond_embeddings = nemb.float() if nemb is not None else txt_to_emb(self.model, "").float()
        cond_embeddings = emb.float()
        latents = [li.float() for li in latents]
        self.model.unet.to(torch.float32)

        uncond_embeddings_list = []
        latent_cur = latents[-1]
        bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
        for i in range(self.num_ddim_steps):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
            latent_prev = latents[len(latents) - i - 2]
            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
            for j in range(num_inner_steps):
                noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
                noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
                # optimize the null-text embedding so the guided step reproduces the DDIM trajectory
                loss = nnf.mse_loss(latents_prev_rec, latent_prev)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            for j in range(j + 1, num_inner_steps):
                # advance the progress bar for inner steps skipped by early stopping
                bar.update()
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())
            with torch.no_grad():
                context = torch.cat([uncond_embeddings, cond_embeddings])
                latent_cur = self.get_noise_pred(latent_cur, t, False, context)
        bar.close()

        uncond_embeddings_list = [ui.to(dtype) for ui in uncond_embeddings_list]
        self.model.unet.to(dtype)
        return uncond_embeddings_list
    def null_invert(self, im, txt, ntxt=None, num_inner_steps=10, early_stop_epsilon=1e-5):
        assert isinstance(im, PIL.Image.Image)

        # temporarily swap in a fresh DDIM scheduler for the inversion
        scheduler_save = self.model.scheduler
        self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
        self.model.scheduler.set_timesteps(self.num_ddim_steps)

        with torch.no_grad():
            nemb = txt_to_emb(self.model, ntxt) \
                if ntxt is not None else txt_to_emb(self.model, "")
            emb = txt_to_emb(self.model, txt)
            latent = image2latent(self.model.vae, im)

        # ddim inversion
        ddim_latents = self.ddim_loop(latent, emb)
        # nulltext inversion
        uncond_embeddings = self.null_optimization(
            ddim_latents, emb, nemb, num_inner_steps, early_stop_epsilon)

        self.model.scheduler = scheduler_save
        return ddim_latents[-1], uncond_embeddings

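A hedged usage sketch for null_invert; the wrapper instance name invertor and the 512x512 resize are assumptions, not part of this file:

    # Hedged sketch: invert a real image so it can later be reconstructed or edited.
    im = PIL.Image.open("input.jpg").convert("RGB").resize((512, 512))
    x_T, null_embs = invertor.null_invert(im, "a photo of a cat", num_inner_steps=10)
    # x_T: inverted starting latent; null_embs: one embedding per DDIM step.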
    def null_optimization_dual(
            self, latents0, latents1, emb0, emb1, nemb=None,
            num_inner_steps=10, epsilon=1e-5):

        # force fp32 for the optimization; the original dtype is restored at the end
        dtype = latents0[0].dtype
        uncond_embeddings = nemb.float() if nemb is not None else txt_to_emb(self.model, "").float()
        cond_embeddings0, cond_embeddings1 = emb0.float(), emb1.float()
        latents0 = [li.float() for li in latents0]
        latents1 = [li.float() for li in latents1]
        self.model.unet.to(torch.float32)

        uncond_embeddings_list = []
        latent_cur0 = latents0[-1]
        latent_cur1 = latents1[-1]

        bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
        for i in range(self.num_ddim_steps):
            uncond_embeddings = uncond_embeddings.clone().detach()
            uncond_embeddings.requires_grad = True
            # learning rate decays linearly with the DDIM step index
            optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))

            latent_prev0 = latents0[len(latents0) - i - 2]
            latent_prev1 = latents1[len(latents1) - i - 2]

            t = self.model.scheduler.timesteps[i]
            with torch.no_grad():
                noise_pred_cond0 = self.get_noise_pred_single(latent_cur0, t, cond_embeddings0)
                noise_pred_cond1 = self.get_noise_pred_single(latent_cur1, t, cond_embeddings1)
            for j in range(num_inner_steps):
                noise_pred_uncond0 = self.get_noise_pred_single(latent_cur0, t, uncond_embeddings)
                noise_pred_uncond1 = self.get_noise_pred_single(latent_cur1, t, uncond_embeddings)

                noise_pred0 = noise_pred_uncond0 + self.guidance_scale * (noise_pred_cond0 - noise_pred_uncond0)
                noise_pred1 = noise_pred_uncond1 + self.guidance_scale * (noise_pred_cond1 - noise_pred_uncond1)

                latents_prev_rec0 = self.prev_step(noise_pred0, t, latent_cur0)
                latents_prev_rec1 = self.prev_step(noise_pred1, t, latent_cur1)

                # joint reconstruction loss over both trajectories
                loss = nnf.mse_loss(latents_prev_rec0, latent_prev0) + \
                       nnf.mse_loss(latents_prev_rec1, latent_prev1)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_item = loss.item()
                bar.update()
                if loss_item < epsilon + i * 2e-5:
                    break
            # advance the progress bar past any skipped inner steps
            for j in range(j + 1, num_inner_steps):
                bar.update()
            uncond_embeddings_list.append(uncond_embeddings[:1].detach())

            with torch.no_grad():
                context0 = torch.cat([uncond_embeddings, cond_embeddings0])
                context1 = torch.cat([uncond_embeddings, cond_embeddings1])
                latent_cur0 = self.get_noise_pred(latent_cur0, t, False, context0)
                latent_cur1 = self.get_noise_pred(latent_cur1, t, False, context1)

        bar.close()

        uncond_embeddings_list = [ui.to(dtype) for ui in uncond_embeddings_list]
        self.model.unet.to(dtype)
        return uncond_embeddings_list

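Note that this dual variant optimizes a single null embedding shared by both trajectories, minimizing the sum of the two reconstruction losses; one set of per-step embeddings can therefore drive guided sampling from either starting latent, or from an interpolation between them.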
    def null_invert_dual(
            self, im0, im1, txt0, txt1, ntxt=None,
            num_inner_steps=10, early_stop_epsilon=1e-5):
        assert isinstance(im0, PIL.Image.Image)
        assert isinstance(im1, PIL.Image.Image)

        scheduler_save = self.model.scheduler
        self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
        self.model.scheduler.set_timesteps(self.num_ddim_steps)

        with torch.no_grad():
            nemb = txt_to_emb(self.model, ntxt) \
                if ntxt is not None else txt_to_emb(self.model, "")
            latent0 = image2latent(self.model.vae, im0)
            latent1 = image2latent(self.model.vae, im1)
            emb0 = txt_to_emb(self.model, txt0)
            emb1 = txt_to_emb(self.model, txt1)

        # ddim inversion
        ddim_latents_0 = self.ddim_loop(latent0, emb0)
        ddim_latents_1 = self.ddim_loop(latent1, emb1)

        # nulltext inversion
        nembs = self.null_optimization_dual(
            ddim_latents_0, ddim_latents_1, emb0, emb1, nemb, num_inner_steps, early_stop_epsilon)

        self.model.scheduler = scheduler_save
        return ddim_latents_0[-1], ddim_latents_1[-1], nembs
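A hedged sketch of driving interpolation with the dual inversion; the slerp helper and the invertor name are illustrative assumptions, not part of this file (torch as imported at the top of the module):

    # Hedged sketch: invert two images jointly, then interpolate their starting latents.
    x_T0, x_T1, null_embs = invertor.null_invert_dual(im0, im1, "a dog", "a horse")

    def slerp(v0, v1, alpha):
        # spherical interpolation between two equally-shaped latents
        v0f, v1f = v0.flatten(), v1.flatten()
        omega = torch.arccos(torch.clamp((v0f @ v1f) / (v0f.norm() * v1f.norm()), -1., 1.))
        return (torch.sin((1 - alpha) * omega) * v0
                + torch.sin(alpha * omega) * v1) / torch.sin(omega)

    x_T_mid = slerp(x_T0, x_T1, 0.5)  # sample from this with the shared null_embs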
    	
requirements.txt ADDED
@@ -0,0 +1,16 @@
accelerate==0.20.3
bitsandbytes==0.42.0
datasets==2.14.4
diffusers==0.20.1
easydict==1.11
gradio==4.19.2
huggingface_hub==0.19.3
moviepy==1.0.3
opencv_python==4.7.0.72
packaging==23.2
pypatchify==0.1.4
safetensors==0.3.1
tqdm==4.65.0
transformers==4.30.1
wandb==0.16.3
xformers==0.0.17
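These pins install with pip install -r requirements.txt. Note that torch itself is not pinned here, so it has to come from the surrounding environment; xformers==0.0.17 in particular is built against a specific torch release, and a mismatched torch version is a common source of build failures with these pins.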