	Upload folder using huggingface_hub

- README.md +6 -7
- app.py +78 -0
- gpt_infer.py +190 -0
- requirements.txt +5 -0
    	
README.md CHANGED

@@ -1,12 +1,11 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: nanochat ZeroGPU Demo
+emoji: 🤖
+colorFrom: indigo
+colorTo: blue
 sdk: gradio
-sdk_version:
+sdk_version: 4.41.0
 app_file: app.py
-pinned: false
 ---

-
+ZeroGPU demo Space serving nanochat SFT/MID/BASE models (CPU inference).
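For orientation, here is a minimal sketch (not part of the commit) of pulling the config and tokenizer for one of the listed checkpoints, using the same hf_hub_download calls that app.py relies on; the repo id and file names assume the layout the app expects.

# Sketch: fetch config.json and tokenizer.json for one listed repo,
# mirroring what app.py does at startup. Assumes config.json and
# tokenizer.json sit at the repo root, as the app expects.
import json
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

repo_id = "loocorez/nanochat-sft-d20-step650"
cfg = json.load(open(hf_hub_download(repo_id, "config.json")))
tok = Tokenizer.from_file(hf_hub_download(repo_id, "tokenizer.json"))
print(cfg["n_layer"], cfg["n_embd"], tok.get_vocab_size())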
    	
app.py ADDED

import os
import json
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer as HFTokenizer
from gpt_infer import GPT, GPTConfig

DEFAULT_MODEL = os.environ.get('NANOCHAT_DEFAULT_MODEL', 'loocorez/nanochat-sft-d20-step650')

def load_model(repo_id: str):
    cfg_path = hf_hub_download(repo_id, 'config.json')
    with open(cfg_path, 'r') as f:
        cfg = json.load(f)
    model_config = GPTConfig(
        sequence_len=cfg.get('n_ctx', 2048),
        vocab_size=cfg['vocab_size'],
        n_layer=cfg['n_layer'],
        n_head=cfg['n_head'],
        n_kv_head=cfg.get('n_kv_head', cfg['n_head']),
        n_embd=cfg['n_embd'],
    )
    model = GPT(model_config)
    model.eval()
    try:
        from safetensors.torch import load_file
        weights_path = hf_hub_download(repo_id, 'model.safetensors')
        sd = load_file(weights_path)
    except Exception:
        weights_path = hf_hub_download(repo_id, 'pytorch_model.bin')
        sd = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(sd, strict=True, assign=True)
    tok_path = hf_hub_download(repo_id, 'tokenizer.json')
    tok = HFTokenizer.from_file(tok_path)
    return model, tok

model_cache = {}

def get_model(repo_id: str):
    if repo_id not in model_cache:
        model_cache[repo_id] = load_model(repo_id)
    return model_cache[repo_id]

@torch.inference_mode()
def generate(repo_id: str, prompt: str, max_tokens: int, temperature: float, top_k: int | None):
    model, tok = get_model(repo_id)
    bos_id = tok.token_to_id('<|bos|>')
    ids = tok.encode(prompt).ids
    if bos_id is not None:
        ids = [bos_id] + ids
    out_tokens = []
    for token in model.generate(ids, max_tokens=max_tokens, temperature=temperature, top_k=top_k or None):
        out_tokens.append(token)
    text = tok.decode(out_tokens, skip_special_tokens=False)
    return text

with gr.Blocks() as demo:
    gr.Markdown('# nanochat (ZeroGPU)')
    gr.Markdown('Select a model and generate text.')
    repo = gr.Dropdown(choices=[
        'loocorez/nanochat-sft-d20-step650',
        'loocorez/nanochat-mid-d20-step765',
        'loocorez/nanochat-base-d20-step21400',
    ], value=DEFAULT_MODEL, label='Model Repo')
    prompt = gr.Textbox(label='Prompt', lines=6)
    with gr.Row():
        max_tokens = gr.Slider(1, 256, value=128, step=1, label='Max tokens')
        temperature = gr.Slider(0.0, 1.5, value=0.8, step=0.05, label='Temperature')
        top_k = gr.Slider(0, 100, value=40, step=1, label='Top-k (0=disabled)')
    btn = gr.Button('Generate')
    output = gr.Textbox(label='Output', lines=10)
    btn.click(fn=lambda r, p, m, t, k: generate(r, p, int(m), float(t), int(k) if int(k) > 0 else None),
              inputs=[repo, prompt, max_tokens, temperature, top_k],
              outputs=output)

if __name__ == '__main__':
    demo.launch()
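The Space wires generate() to a Gradio button, but the same function can be driven from a plain script. A rough sketch (not part of the commit), assuming app.py and gpt_infer.py are importable and the dropdown checkpoints remain available:

# Sketch: call app.py's generate() directly, bypassing the UI.
# Importing app builds the Blocks layout but does not launch it,
# since demo.launch() is guarded by the __main__ check.
from app import generate

text = generate(
    repo_id="loocorez/nanochat-sft-d20-step650",
    prompt="The capital of France is",
    max_tokens=64,
    temperature=0.8,
    top_k=40,
)
print(text)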
    	
gpt_infer.py ADDED

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


def norm(x: torch.Tensor) -> torch.Tensor:
    return F.rms_norm(x, (x.size(-1),))


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    out = out.to(x.dtype)
    return out


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    if n_rep == 1:
        return x
    bs, n_kv_heads, slen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )


@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6
    n_kv_head: int = 6
    n_embd: int = 768


class CausalSelfAttention(nn.Module):
    def __init__(self, config: GPTConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None):
        B, T, C = x.size()
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        cos, sin = cos_sin
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        q, k = norm(q), norm(k)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)
        Tq = q.size(2)
        Tk = k.size(2)
        nrep = self.n_head // self.n_kv_head
        k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
        if kv_cache is None or Tq == Tk:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        elif Tq == 1:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        else:
            attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
            prefix_len = Tk - Tq
            if prefix_len > 0:
                attn_mask[:, :prefix_len] = True
            attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        return y


class MLP(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = self.c_proj(x)
        return x


class Block(nn.Module):
    def __init__(self, config: GPTConfig, layer_idx: int):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, cos_sin, kv_cache=None) -> torch.Tensor:
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        x = x + self.mlp(norm(x))
        return x


class GPT(nn.Module):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embd),
            'h': nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.rotary_seq_len = config.sequence_len * 10
        head_dim = config.n_embd // config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.register_buffer('cos', cos, persistent=False)
        self.register_buffer('sin', sin, persistent=False)
        self.transformer.wte.to(dtype=torch.bfloat16)

    def _precompute_rotary_embeddings(self, seq_len: int, head_dim: int, base: int = 10000, device=None):
        if device is None:
            device = self.transformer.wte.weight.device
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
        t = torch.arange(seq_len, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
        cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :]
        return cos, sin

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache=None, loss_reduction: str = 'mean'):
        B, T = idx.size()
        assert T <= self.cos.size(1)
        assert idx.device == self.cos.device
        assert self.cos.dtype == torch.bfloat16
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)
        softcap = 15
        if targets is not None:
            logits = self.lm_head(x)
            logits = softcap * torch.tanh(logits / softcap)
            logits = logits.float()
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            logits = self.lm_head(x)
            logits = softcap * torch.tanh(logits / softcap)
            return logits

    @torch.inference_mode()
    def generate(self, tokens: list[int], max_tokens: int, temperature: float = 1.0, top_k: int | None = None, seed: int = 42):
        assert isinstance(tokens, list)
        device = self.transformer.wte.weight.device
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        for _ in range(max_tokens):
            logits = self.forward(ids)
            logits = logits[:, -1, :]
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / max(temperature, 1e-6)
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            yield next_ids.item()
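To sanity-check the GPT/GPTConfig interface without downloading weights, a tiny randomly initialized model is enough. A hedged sketch (not part of the commit): the sizes are illustrative, the bfloat16 cast keeps the Linear layers in the same dtype as the embedding table that GPT.__init__ converts to bfloat16, and F.rms_norm requires a recent PyTorch.

# Sketch: stream tokens from generate() on a tiny random model.
# temperature=0.0 takes the greedy (argmax) branch in generate().
import torch
from gpt_infer import GPT, GPTConfig

cfg = GPTConfig(sequence_len=128, vocab_size=256, n_layer=2,
                n_head=4, n_kv_head=2, n_embd=64)
model = GPT(cfg).to(torch.bfloat16).eval()

prompt_tokens = [1, 2, 3]
out = list(model.generate(prompt_tokens, max_tokens=8, temperature=0.0))
print(out)  # eight greedily sampled token ids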
    	
requirements.txt ADDED

gradio==4.41.0
huggingface_hub>=0.24.0
tokenizers>=0.14.0
safetensors>=0.4.3
torch