Support this model device-independent. (#52)
Browse files- Support this model device-independent. (61ddce6553b8beffc5a859ba31726aab1f9b6979)
Co-authored-by: yejinglai <Eraa@users.noreply.huggingface.co>
    	
        speech_conformer_encoder.py
    CHANGED
    
    | @@ -2477,9 +2477,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module): | |
| 2477 | 
             
                        seq_len, batch_size, self.chunk_size, self.left_chunk
         | 
| 2478 | 
             
                    )
         | 
| 2479 |  | 
| 2480 | 
            -
                    if xs_pad. | 
| 2481 | 
            -
                        enc_streaming_mask = enc_streaming_mask. | 
| 2482 | 
            -
                        xs_pad = xs_pad.cuda()
         | 
| 2483 |  | 
| 2484 | 
             
                    input_tensor = xs_pad
         | 
| 2485 | 
             
                    input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
         | 
| @@ -2496,8 +2495,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module): | |
| 2496 | 
             
                        enc_streaming_mask_nc = self._streaming_mask(
         | 
| 2497 | 
             
                            seq_len, batch_size, chunk_size_nc, left_chunk_nc
         | 
| 2498 | 
             
                        )
         | 
| 2499 | 
            -
                        if xs_pad. | 
| 2500 | 
            -
                            enc_streaming_mask_nc = enc_streaming_mask_nc. | 
| 2501 | 
             
                        if masks is not None:
         | 
| 2502 | 
             
                            hs_mask_nc = masks & enc_streaming_mask_nc
         | 
| 2503 | 
             
                        else:
         | 
|  | |
| 2477 | 
             
                        seq_len, batch_size, self.chunk_size, self.left_chunk
         | 
| 2478 | 
             
                    )
         | 
| 2479 |  | 
| 2480 | 
            +
                    if xs_pad.device != "cpu":
         | 
| 2481 | 
            +
                        enc_streaming_mask = enc_streaming_mask.to(xs_pad.device)
         | 
|  | |
| 2482 |  | 
| 2483 | 
             
                    input_tensor = xs_pad
         | 
| 2484 | 
             
                    input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
         | 
|  | |
| 2495 | 
             
                        enc_streaming_mask_nc = self._streaming_mask(
         | 
| 2496 | 
             
                            seq_len, batch_size, chunk_size_nc, left_chunk_nc
         | 
| 2497 | 
             
                        )
         | 
| 2498 | 
            +
                        if xs_pad.device != "cpu":
         | 
| 2499 | 
            +
                            enc_streaming_mask_nc = enc_streaming_mask_nc.to(xs_pad.device)
         | 
| 2500 | 
             
                        if masks is not None:
         | 
| 2501 | 
             
                            hs_mask_nc = masks & enc_streaming_mask_nc
         | 
| 2502 | 
             
                        else:
         | 

