akhaliq HF Staff committed on
Commit f36fe6f · verified · 1 Parent(s): ce8ea41

Deploy Gradio app with multiple files

Files changed (4)
  1. app.py +47 -0
  2. config.py +11 -0
  3. models.py +132 -0
  4. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,47 @@
```python
import gradio as gr
from models import stream_generate_response

# Header link
ANYCODER_LINK = "<a href='https://huggingface.co/spaces/akhaliq/anycoder' target='_blank'>Built with anycoder</a>"

with gr.Blocks(title="KAT-Dev Chat", theme=gr.themes.Soft()) as demo:
    gr.HTML(
        f"""
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
            <h1>💬 KAT-Dev LLM Chat</h1>
            <p>Powered by Kwaipilot/KAT-Dev, a large language model. This application uses Hugging Face ZeroGPU for highly efficient inference.</p>
            {ANYCODER_LINK}
        </div>
        """
    )

    # ChatInterface handles the conversational UI, streaming, and history management
    chat_interface = gr.ChatInterface(
        fn=stream_generate_response,
        title="",  # Title is rendered in the gr.HTML block above
        chatbot=gr.Chatbot(
            height=500,
            show_copy_button=True,
            layout="bubble",
        ),
        textbox=gr.Textbox(
            placeholder="Ask the KAT model anything...",
            container=False,
            scale=7,
        ),
        # Streaming comes from the generator fn yielding partial responses,
        # so no extra streaming flag is needed here.
        submit_btn=True,   # show the built-in submit button
        stop_btn=True,     # let the user interrupt generation
        # Cap concurrent requests to this worker; GPU scheduling itself is handled by @spaces.GPU
        concurrency_limit=10,
    )

demo.queue()
demo.launch()
```
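For reference, `gr.ChatInterface` treats a generator `fn` as a streaming handler: each yielded value replaces the bot message shown so far, so the function should yield accumulated text rather than individual deltas. A minimal sketch of that contract, using a hypothetical `echo_stream` function in place of the real model handler:

```python
import time
import gradio as gr

def echo_stream(message: str, history: list):
    # Each yield replaces the bot message, so yield the accumulated text so far.
    acc = ""
    for ch in f"You said: {message}":
        acc += ch
        time.sleep(0.02)  # simulate per-token latency
        yield acc

# gr.ChatInterface(fn=echo_stream).launch()  # swap in stream_generate_response for the real app
```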
config.py ADDED
@@ -0,0 +1,11 @@
```python
import torch

# Model Configuration
MODEL_NAME = "Kwaipilot/KAT-Dev"

# Generation Configuration
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
DO_SAMPLE = True
```
models.py ADDED
@@ -0,0 +1,132 @@
```python
import spaces
import torch
from threading import Thread
from typing import Generator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE

# Global variables for the model and tokenizer.
# They are loaded once at worker startup so later calls only pay for inference.
tokenizer = None
model = None

def initialize_model():
    """Initializes and loads the model and tokenizer once."""
    global tokenizer, model
    if model is None:
        try:
            print(f"Loading model {MODEL_NAME}...")

            # Use bfloat16 for efficiency on modern GPUs (e.g., H100, A100)
            dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=dtype,
                device_map="auto",  # Automatically handles device placement
            )
            model.eval()

            # Set padding token if not defined (common for causal LMs)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id

            print("Model loaded successfully.")
        except Exception as e:
            print(f"Failed to load model: {e}")
            raise
    return tokenizer, model

# Load the model at startup so the first request does not pay the loading cost.
try:
    initialize_model()
except Exception as e:
    print(f"Warning: model initialization at startup failed: {e}. "
          "It will be re-attempted on the first inference call.")


@spaces.GPU(duration=120)
def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
    """
    Generates a response from the KAT model, streaming the output as it is produced.

    Args:
        prompt: The current user input.
        history: The accumulated chat history (list of [user_msg, bot_msg] pairs).

    Yields:
        str: The accumulated response text so far (gr.ChatInterface replaces the
             bot message with each yielded value).
    """
    global tokenizer, model

    # Fallback initialization in case loading at startup failed
    if model is None or tokenizer is None:
        initialize_model()

    # Convert Gradio history format to the model's chat template format
    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})

    # Add the current prompt
    messages.append({"role": "user", "content": prompt})

    # Apply the chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Prepare inputs and move them to the model device
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # TextIteratorStreamer exposes generated text as an iterator, which fits
    # Gradio's generator-based streaming: generation runs in a background
    # thread while this function yields the accumulated text.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        repetition_penalty=1.1,
        min_new_tokens=1,
        pad_token_id=tokenizer.pad_token_id,
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated response as new text arrives
    full_response = ""
    for new_text in streamer:
        full_response += new_text
        yield full_response

    thread.join()
```
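As a rough local smoke test (not part of this commit, and assuming the KAT-Dev weights fit on the local machine): the `spaces` decorators are designed to be no-ops outside a ZeroGPU Space, so `stream_generate_response` should be callable directly to confirm that partial responses stream out:

```python
# Hypothetical local check; model download and memory requirements apply.
from models import stream_generate_response

history = []  # [user_msg, bot_msg] pairs from earlier turns
for partial in stream_generate_response("Write a haiku about GPUs.", history):
    print(f"\r{partial[-80:]}", end="")  # show the tail of the accumulated response
print()
```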
requirements.txt ADDED
@@ -0,0 +1,9 @@
```
gradio>=4.0
torch
transformers
accelerate
numpy
huggingface-hub
bitsandbytes
```