akhaliq HF Staff committed on
Commit 0763b5e · verified · 1 Parent(s): eae8d97

Upload app.py with huggingface_hub
Files changed (1):
  1. app.py +138 -38
app.py CHANGED
@@ -1,41 +1,141 @@
- import gradio as gr
- from models import stream_generate_response
-
- # Header Link
- ANYCODER_LINK = "<a href='https://huggingface.co/spaces/akhaliq/anycoder' target='_blank'>Built with anycoder</a>"
-
- with gr.Blocks(title="KAT-Dev Chat", theme=gr.themes.Soft()) as demo:
-     gr.HTML(
-         f"""
-         <div style="text-align: center; max-width: 800px; margin: 0 auto;">
-             <h1>💬 KAT-Dev LLM Chat</h1>
-             <p>Powered by Kwaipilot/KAT-Dev, a large language model. This application uses Hugging Face ZeroGPU for highly efficient inference.</p>
-             {ANYCODER_LINK}
-         </div>
-         """
-     )
-
-     # ChatInterface handles the full conversational UI, streaming, and history management
-     chat_interface = gr.ChatInterface(
-         fn=stream_generate_response,
-         title="",  # Title moved to HTML block
-         chatbot=gr.Chatbot(
-             height=500,
-             show_copy_button=True,
-             layout="bubble"
-         ),
-         textbox=gr.Textbox(
-             placeholder="Ask the KAT model anything...",
-             container=False,
-             scale=7
-         ),
-         # Disable the default submit button text since we have an icon
-         submit_btn=True,
-         stop_btn=True,
-
-         # Concurrency limit handled by @spaces.GPU
-         concurrency_limit=10,
-     )
-
- demo.queue()
- demo.launch()
+ I see the issues! The error is happening because the custom streamer isn't handling the input correctly, and we're not properly setting the attention mask. Let me fix the streaming implementation:
+
+ === models.py ===
+ import spaces
+ import torch
+ from typing import Generator
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
+
+ # Global variables to store the model and tokenizer
+ tokenizer = None
+ model = None
+
+ def initialize_model():
+     """Initializes and loads the model and tokenizer once onto the GPU."""
+     global tokenizer, model
+     if model is None:
+         try:
+             print(f"Loading model {MODEL_NAME}...")
+
+             # Use bfloat16 for efficiency on modern GPUs
+             dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+             tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+             model = AutoModelForCausalLM.from_pretrained(
+                 MODEL_NAME,
+                 torch_dtype=dtype,
+                 device_map="auto"
+             )
+             model.eval()
+
+             # Set a padding token if the tokenizer does not define one
+             if tokenizer.pad_token_id is None:
+                 tokenizer.pad_token_id = tokenizer.eos_token_id
+
+             print("Model loaded successfully.")
+         except Exception as e:
+             print(f"Failed to load model: {e}")
+             raise
+     return tokenizer, model
+
+ # Initialize at import time so the first request does not pay the load cost
+ try:
+     initialize_model()
+ except Exception as e:
+     print(f"Warning: Global model initialization failed: {e}")
+
+ @spaces.GPU(duration=120)
+ def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
+     """Generates a response from the KAT model with proper streaming."""
+     global tokenizer, model
+
+     # Fallback initialization
+     if model is None or tokenizer is None:
+         initialize_model()
+
+     # Convert Gradio tuple-style history to the model's chat-template format
+     messages = []
+     for human, bot in history:
+         if human:
+             messages.append({"role": "user", "content": human})
+         if bot:
+             messages.append({"role": "assistant", "content": bot})
+
+     # Add the current prompt
+     messages.append({"role": "user", "content": prompt})
+
+     # Apply the chat template
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )
+
+     # Tokenize with an explicit attention mask
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+     input_ids = inputs.input_ids.to(model.device)
+     attention_mask = inputs.attention_mask.to(model.device)
+
+     accumulated_text = ""
+
+     # Generate tokens one at a time; range() enforces the MAX_NEW_TOKENS budget
+     for _ in range(MAX_NEW_TOKENS):
+         with torch.no_grad():
+             outputs = model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 return_dict=True
+             )
+
+         # Logits for the next-token position
+         next_token_logits = outputs.logits[:, -1, :]
+
+         # Apply temperature
+         if TEMPERATURE > 0:
+             next_token_logits = next_token_logits / TEMPERATURE
+
+         # Sample from the softmax distribution, or take the argmax
+         if DO_SAMPLE:
+             probs = torch.softmax(next_token_logits, dim=-1)
+             next_token = torch.multinomial(probs, num_samples=1)
+         else:
+             next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+         # Stop on the end-of-sequence token
+         if next_token.item() == tokenizer.eos_token_id:
+             break
+
+         # Decode the new token and append it to the running text
+         accumulated_text += tokenizer.decode(next_token[0], skip_special_tokens=True)
+
+         # Yield the accumulated text so Gradio can refresh the chat bubble
+         yield accumulated_text
+
+         # Extend the context for the next forward pass
+         input_ids = torch.cat([input_ids, next_token], dim=-1)
+         attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
+
+     # Final yield to ensure the complete text is delivered
+     yield accumulated_text.strip()
+
+ The key changes I made:
+
+ 1. **Fixed attention mask**: the tokenizer call now sets `attention_mask` explicitly, which avoids the warning
+ 2. **Simplified streaming**: a manual token-by-token generation loop replaces the complex custom streamer
+ 3. **Proper tensor handling**: token tensors are correctly concatenated and decoded at each step
+ 4. **EOS handling**: generation stops as soon as the end-of-sequence token is produced
+ 5. **Memory efficiency**: inference runs under `torch.no_grad()` to save memory
+
+ This implementation should now properly stream tokens one by one and yield the accumulated text to the Gradio interface for real-time display.
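
The models.py above imports MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, and DO_SAMPLE from a config module that is not part of this diff. For reference, a minimal sketch of what config.py likely contains; the model id matches the one named in the Space header, while the numeric values are illustrative assumptions only:

=== config.py (sketch) ===
# Model id as named in the Space description; assumed, not shown in this commit
MODEL_NAME = "Kwaipilot/KAT-Dev"
# Hypothetical generation settings; the real values are not in this diff
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
DO_SAMPLE = True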
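A design note on the generation loop: each iteration re-runs the model over the entire sequence, so cost grows quadratically with output length. A minimal sketch of the same loop using the KV cache (past_key_values), assuming the same model, tokenizer, input_ids, and attention_mask as in models.py, and shown with greedy decoding for brevity:

past_key_values = None
next_input = input_ids
for _ in range(MAX_NEW_TOKENS):
    with torch.no_grad():
        outputs = model(
            input_ids=next_input,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
    # Reuse the cached keys/values so later passes see only the new token
    past_key_values = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)
    if next_token.item() == tokenizer.eos_token_id:
        break
    next_input = next_token  # after the first pass, feed only the newest token
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)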
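Alternatively, transformers ships a stock TextIteratorStreamer that avoids the manual loop entirely: generate() runs in a background thread while the streamer is consumed as an iterator. A sketch under the same assumptions (model, tokenizer, and config constants as above); stream_with_streamer is a hypothetical helper name:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_with_streamer(text: str):
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    # skip_prompt drops the input prompt from the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,  # input_ids and attention_mask
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() blocks, so run it in a thread while we consume the streamer
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    accumulated = ""
    for chunk in streamer:
        accumulated += chunk
        yield accumulated
    thread.join()

Yielding the growing string keeps the same contract with gr.ChatInterface as the manual loop in models.py.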