akhaliq (HF Staff) committed
Commit e71ab18 · verified · 1 Parent(s): a5660ec

Update models.py

Files changed (1):
  1. models.py  +61 -81
models.py CHANGED
@@ -2,11 +2,10 @@ import spaces
 import torch
 import numpy as np
 from typing import Generator
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, DO_SAMPLE
 
 # Global variables to store the model and tokenizer
-# These are loaded under the GPU context to minimize overhead on subsequent calls.
 tokenizer = None
 model = None
 
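The file pulls its generation settings from a small config module. For orientation, a hypothetical sketch of what config.py could contain — the names match the import above, but every value here is a placeholder, since the real ones are not part of this diff:

    # config.py -- hypothetical sketch; the actual values are not shown in this commit.
    MODEL_NAME = "org/kat-model-id"   # placeholder Hub id for the KAT checkpoint
    MAX_NEW_TOKENS = 512              # placeholder per-reply generation budget
    TEMPERATURE = 0.7                 # placeholder sampling temperature
    DO_SAMPLE = True                  # placeholder: sample rather than decode greedily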
@@ -17,18 +16,18 @@ def initialize_model():
     try:
         print(f"Loading model {MODEL_NAME}...")
 
-        # Use bfloat16 for efficiency on modern GPUs (e.g., H100, A100)
+        # Use bfloat16 for efficiency on modern GPUs
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_NAME,
             torch_dtype=dtype,
-            device_map="auto" # Automatically handles device placement (GPU)
+            device_map="auto"
         )
         model.eval()
 
-        # Set padding token if not defined (common for Causal LMs)
+        # Set padding token if not defined
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
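The dtype guard in initialize_model only checks that CUDA is present, so older GPUs without native bfloat16 would still be handed bfloat16. If that ever matters, a stricter check is a one-liner; a small sketch using only standard torch calls:

    import torch

    # Prefer bfloat16 when the GPU supports it, fall back to float16 on older
    # CUDA devices, and keep float32 on CPU.
    if torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32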
@@ -38,49 +37,33 @@ def initialize_model():
         raise
     return tokenizer, model
 
-# Call initialization immediately to ensure the model is ready when the worker starts up
-# Note: This runs in the global scope, relying on the worker environment managing the GPU context.
+# Call initialization
 try:
     initialize_model()
 except Exception as e:
-    print(f"Warning: Global model initialization failed: {e}. It will be re-attempted during the first inference call.")
-
+    print(f"Warning: Global model initialization failed: {e}")
 
 @spaces.GPU(duration=120)
 def stream_generate_response(prompt: str, history: list) -> Generator[str, None, None]:
     """
-    Generates a response from the KAT model, streaming output token by token.
-
-    Args:
-        prompt: The current user input.
-        history: The accumulated chat history (list of [user_msg, bot_msg] tuples).
-
-    Yields:
-        str: Accumulated text response chunk.
+    Generates a response from the KAT model with proper streaming.
     """
     global tokenizer, model
 
-    # Fallback initialization in case global loading failed
+    # Fallback initialization
     if model is None or tokenizer is None:
         initialize_model()
 
     # Convert Gradio history format to the model's chat template format
     messages = []
     for human, bot in history:
-        # Add past exchanges
         if human:
-            messages.append({
-                "role": "user", "content": human
-            })
+            messages.append({"role": "user", "content": human})
         if bot:
-            messages.append({
-                "role": "assistant", "content": bot
-            })
+            messages.append({"role": "assistant", "content": bot})
 
     # Add the current prompt
-    messages.append({
-        "role": "user", "content": prompt
-    })
+    messages.append({"role": "user", "content": prompt})
 
     # Apply chat template
     text = tokenizer.apply_chat_template(
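The loop above assumes the classic Gradio tuple history, a list of [user_msg, bot_msg] pairs. A short, self-contained illustration with made-up messages of how that history plus the new prompt becomes the messages list handed to apply_chat_template (the lines elided between the hunks presumably also pass tokenize=False, since the resulting text is re-tokenized below):

    history = [
        ["Hi", "Hello! How can I help?"],
        ["Can you write code?", "Sure, just describe what you need."],
    ]
    prompt = "Write a haiku about GPUs."

    messages = []
    for human, bot in history:
        if human:
            messages.append({"role": "user", "content": human})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": prompt})

    # messages now alternates user/assistant turns and ends with the new user
    # prompt, which is the shape chat templates expect before the generation prompt.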
@@ -89,60 +72,57 @@ def stream_generate_response(prompt: str, history: list) -> Generator[str, None,
         add_generation_prompt=True,
     )
 
-    # Prepare inputs and move to model device
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    # Tokenize with attention mask
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    input_ids = inputs.input_ids.to(model.device)
+    attention_mask = inputs.attention_mask.to(model.device)
+
+    # Generate with streaming using yield-based approach
+    accumulated_text = ""
 
-    # Create a custom streamer that works with Gradio
-    class GradioStreamer:
-        def __init__(self, tokenizer):
-            self.tokenizer = tokenizer
-            self.text_queue = []
-            self.generated_text = ""
+    # Generate tokens incrementally
+    for _ in range(MAX_NEW_TOKENS):
+        with torch.no_grad():
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                return_dict=True
+            )
 
-        def put(self, value):
-            # Decode the new tokens and add to queue
-            if isinstance(value, torch.Tensor):
-                new_text = self.tokenizer.decode(value, skip_special_tokens=True)
-                # Only yield the new part
-                if new_text.startswith(self.generated_text):
-                    new_part = new_text[len(self.generated_text):]
-                    if new_part:
-                        self.text_queue.append(new_part)
-                        self.generated_text = new_text
-                else:
-                    # Sometimes the decoding might not align perfectly
-                    self.text_queue.append(new_text)
-                    self.generated_text = new_text
-
-        def end(self):
-            pass
+        # Get next token probabilities
+        next_token_logits = outputs.logits[:, -1, :]
+
+        # Apply temperature
+        if TEMPERATURE > 0:
+            next_token_logits = next_token_logits / TEMPERATURE
 
-        def __iter__(self):
-            return iter(self.text_queue)
-
-    # Create our custom streamer
-    gradio_streamer = GradioStreamer(tokenizer)
-
-    # Generate with streaming
-    input_ids = model_inputs.input_ids
-
-    # Generate tokens one by one for true streaming
-    generated_ids = model.generate(
-        input_ids=input_ids,
-        max_new_tokens=MAX_NEW_TOKENS,
-        do_sample=DO_SAMPLE,
-        temperature=TEMPERATURE,
-        pad_token_id=tokenizer.eos_token_id,
-        streamer=gradio_streamer,
-        repetition_penalty=1.1,
-    )
-
-    # Yield the text as it's generated
-    accumulated_text = ""
-    for new_chunk in gradio_streamer.text_queue:
-        accumulated_text += new_chunk
+        # Apply softmax and sample
+        probs = torch.softmax(next_token_logits, dim=-1)
+        if DO_SAMPLE:
+            next_token = torch.multinomial(probs, num_samples=1)
+        else:
+            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+
+        # Check for EOS token
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+        # Decode the new token
+        new_token_text = tokenizer.decode(next_token[0], skip_special_tokens=True)
+
+        # Update accumulated text
+        accumulated_text += new_token_text
+
+        # Yield the current accumulated text
         yield accumulated_text
-
-    # Final yield to ensure complete text is sent
-    if accumulated_text:
-        yield accumulated_text.strip()
+
+        # Prepare for next iteration
+        input_ids = torch.cat([input_ids, next_token], dim=-1)
+        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
+
+        # Stop if we've reached max tokens
+        if input_ids.shape[-1] >= input_ids.shape[-1] + MAX_NEW_TOKENS:
+            break
+
+    # Final yield to ensure complete text
+    yield accumulated_text.strip()
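Two details of the new decoding loop stand out. The guard near the end compares input_ids.shape[-1] with itself plus MAX_NEW_TOKENS, so it can never fire; the for range(...) already bounds generation, but the comparison presumably meant to use the prompt length. Also, each iteration re-runs the model over the entire growing sequence, so every step costs more than the last; feeding the returned past_key_values back in keeps per-token cost roughly flat. A hedged sketch of the same loop body with both adjustments — it assumes the input_ids, attention_mask, model, tokenizer and config constants defined earlier in the function, and names such as prompt_length and next_input are illustrative:

    # Sketch only: prompt_length, past_key_values and next_input are
    # illustrative names, not part of the commit.
    prompt_length = input_ids.shape[-1]   # length of the encoded prompt
    past_key_values = None                # KV cache returned by the model
    next_input = input_ids                # full prompt on the first step only
    accumulated_text = ""

    for _ in range(MAX_NEW_TOKENS):
        with torch.no_grad():
            outputs = model(
                input_ids=next_input,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
        past_key_values = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]

        if TEMPERATURE > 0:
            next_token_logits = next_token_logits / TEMPERATURE
        probs = torch.softmax(next_token_logits, dim=-1)
        if DO_SAMPLE:
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        if next_token.item() == tokenizer.eos_token_id:
            break

        accumulated_text += tokenizer.decode(next_token[0], skip_special_tokens=True)
        yield accumulated_text

        # Feed only the new token next time; the cache carries the context.
        next_input = next_token
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

        # The stop condition the original check presumably intended.
        if input_ids.shape[-1] >= prompt_length + MAX_NEW_TOKENS:
            break

    yield accumulated_text.strip()

Decoding the full list of generated ids each step, rather than one token at a time, would additionally avoid the occasional broken multi-byte character that per-token decoding can produce.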
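For comparison, the TextStreamer import removed in this commit points at a simpler route that transformers already ships: TextIteratorStreamer is built for exactly this generator pattern and lets model.generate keep handling sampling, repetition penalty and stopping. A hedged sketch of how the generation step of stream_generate_response could look with it, keeping the tokenization and history handling above unchanged:

    from threading import Thread
    from transformers import TextIteratorStreamer

    # Streams decoded text as generate() produces it; skip_prompt drops the
    # echoed input tokens so only the reply is yielded.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=DO_SAMPLE,
        temperature=TEMPERATURE,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # generate() blocks, so it runs in a background thread while this function
    # consumes the streamer and yields partial text back to Gradio.
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    accumulated_text = ""
    for new_text in streamer:
        accumulated_text += new_text
        yield accumulated_text

Whether the background thread or the hand-rolled loop is the better trade-off depends on how much per-token control the Space actually needs.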