Spaces:
Runtime error
Runtime error
| import torch | |
# copy from:https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L84
def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Precompute unit-magnitude complex rotations for rotary embeddings.

    For position ``t`` in ``range(end)`` and index ``i`` in
    ``range(dim // 2)`` the entry is ``exp(1j * t * theta ** (-2 * i / dim))``.

    Args:
        dim: Size of the feature dimension being rotated; ``dim // 2``
            frequencies are produced.
        end: Number of positions (rows) to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        complex64 tensor of shape ``(end, dim // 2)``, created on the
        default device.
    """
    half = dim // 2
    # Even exponents 0, 2, 4, ... normalised by dim; the slice keeps exactly
    # dim // 2 of them when dim is odd.
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    # angles[t, i] = t * inv_freq[i]
    angles = torch.outer(positions, inv_freq).float()
    magnitudes = torch.ones_like(angles)
    return torch.polar(magnitudes, angles)  # complex64
# modified from:
# https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L95
def google_apply_rotary_emb(x: torch.Tensor,
                            freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding using the half-split (Gemma) layout.

    The last dimension of ``x`` (size ``D``) is split into two halves. The
    first half is treated as the real parts and the second half as the
    imaginary parts of ``D // 2`` complex numbers, which are multiplied by
    ``freqs_cis``. The rotated real parts are written to the first half of
    the output's last dimension, and the rotated imaginary parts to the
    second half.

    Args:
        x: Real tensor whose last dimension ``D`` is even. Any number of
            leading dimensions is accepted.
        freqs_cis: Complex tensor broadcastable against
            ``x.shape[:-1] + (D // 2,)``. It is presumably produced by
            ``precompute_freqs_cis`` and sliced/reshaped by the caller;
            confirm at the call site.

    Returns:
        Tensor with the dtype of ``x``. Its leading dimensions are the
        broadcast of ``x``'s and ``freqs_cis``'s leading dimensions, and its
        last dimension is ``D``.
    """
    # Upcast before forming complex numbers: view_as_complex does not accept
    # every low-precision dtype (e.g. bfloat16).
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.float(), 2, dim=-1), dim=-1))
    # (..., D // 2, 2): [..., 0] holds rotated real parts, [..., 1] imaginary.
    rotated = torch.view_as_real(x_ * freqs_cis).type_as(x)
    # Re-pack as [real parts | imaginary parts] along the last dim. This
    # works for any rank, unlike a reshape that assumes exactly 4 dims.
    return torch.cat((rotated[..., 0], rotated[..., 1]), dim=-1)
def llama_apply_rotary_emb(x: torch.Tensor,
                           freqs_cis: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding using the interleaved (LLaMA) layout.

    Consecutive pairs ``(x[..., 2*i], x[..., 2*i + 1])`` of the last
    dimension are treated as the real and imaginary parts of a complex
    number. Each pair is multiplied by ``freqs_cis`` and written back to the
    same interleaved positions.

    Args:
        x: Real tensor whose last dimension ``D`` is even. Any number of
            leading dimensions is accepted.
        freqs_cis: Complex tensor broadcastable against
            ``x.shape[:-1] + (D // 2,)``. It is presumably produced by
            ``precompute_freqs_cis`` and sliced/reshaped by the caller;
            confirm at the call site.

    Returns:
        Tensor with the dtype of ``x``. Its leading dimensions are the
        broadcast of ``x``'s and ``freqs_cis``'s leading dimensions, and its
        last dimension is ``D``.
    """
    # Upcast so view_as_complex gets a supported dtype (e.g. not bfloat16).
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # (..., D // 2, 2) -> (..., D). Flattening just the last two dims keeps
    # the pair interleaving and works for any rank; flatten(3) assumed 4-D.
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(-2)
    return x_out.type_as(x)
# Registry mapping a rotary-embedding layout name to the function that
# applies it ("google": half-split layout, "llama": interleaved pairs).
WENET_APPLY_ROTARY_EMB = dict(
    google=google_apply_rotary_emb,
    llama=llama_apply_rotary_emb,
)