Commit 
							
							·
						
						a795b9b
	
1
								Parent(s):
							
							6ded867
								
add log to check whether chunking is working
Browse files
    	
        cosmos_transfer1/diffusion/model/model_v2w.py
    CHANGED
    
    | 
         @@ -249,6 +249,7 @@ class DiffusionV2WModel(DiffusionT2WModel): 
     | 
|
| 249 | 
         
             
                    assert condition_latent is not None, "condition_latent should be provided"
         
     | 
| 250 | 
         | 
| 251 | 
         
             
                    # try to add chunking here !!!
         
     | 
| 
         | 
|
| 252 | 
         
             
                    x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
         
     | 
| 253 | 
         
             
                        data_batch,
         
     | 
| 254 | 
         
             
                        guidance,
         
     | 
| 
         @@ -312,6 +313,8 @@ class DiffusionV2WModel(DiffusionT2WModel): 
     | 
|
| 312 | 
         
             
                        Function that takes noisy input and noise level and returns denoised prediction
         
     | 
| 313 | 
         
             
                    """
         
     | 
| 314 | 
         
             
                    if chunking is None:
         
     | 
| 
         | 
|
| 
         | 
|
| 315 | 
         
             
                        if is_negative_prompt:
         
     | 
| 316 | 
         
             
                            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
         
     | 
| 317 | 
         
             
                        else:
         
     | 
| 
         @@ -347,6 +350,8 @@ class DiffusionV2WModel(DiffusionT2WModel): 
     | 
|
| 347 | 
         | 
| 348 | 
         
             
                        return x0_fn
         
     | 
| 349 | 
         
             
                    else:
         
     | 
| 
         | 
|
| 
         | 
|
| 350 | 
         
             
                        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
         
     | 
| 351 | 
         
             
                            if is_negative_prompt:
         
     | 
| 352 | 
         
             
                                condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
         
     | 
| 
         | 
|
| 249 | 
         
             
                    assert condition_latent is not None, "condition_latent should be provided"
         
     | 
| 250 | 
         | 
| 251 | 
         
             
                    # try to add chunking here !!!
         
     | 
| 252 | 
         
            +
                    log.info("x0_fn")
         
     | 
| 253 | 
         
             
                    x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
         
     | 
| 254 | 
         
             
                        data_batch,
         
     | 
| 255 | 
         
             
                        guidance,
         
     | 
| 
         | 
|
| 313 | 
         
             
                        Function that takes noisy input and noise level and returns denoised prediction
         
     | 
| 314 | 
         
             
                    """
         
     | 
| 315 | 
         
             
                    if chunking is None:
         
     | 
| 316 | 
         
            +
                        log.info("no chunking")
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
             
                        if is_negative_prompt:
         
     | 
| 319 | 
         
             
                            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
         
     | 
| 320 | 
         
             
                        else:
         
     | 
| 
         | 
|
| 350 | 
         | 
| 351 | 
         
             
                        return x0_fn
         
     | 
| 352 | 
         
             
                    else:
         
     | 
| 353 | 
         
            +
                        log.info("chunking !!!")
         
     | 
| 354 | 
         
            +
             
     | 
| 355 | 
         
             
                        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
         
     | 
| 356 | 
         
             
                            if is_negative_prompt:
         
     | 
| 357 | 
         
             
                                condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
         
     |