Utiric committed
Commit 54ba978 · verified · 1 Parent(s): 4bdd945

Update app.py

Files changed (1)
  1. app.py +27 -43
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 import time
 import threading
 import torch
@@ -6,15 +7,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStream
 
 MODEL_NAME = "daniel-dona/gemma-3-270m-it"
 
-# CPU optimizations
-torch.set_num_threads(torch.get_num_threads())  # use all cores
-torch.set_float32_matmul_precision("high")  # speed up matmul
+os.environ.setdefault("OMP_NUM_THREADS", str(os.cpu_count() or 1))
+os.environ.setdefault("MKL_NUM_THREADS", os.environ["OMP_NUM_THREADS"])
+os.environ.setdefault("OMP_PROC_BIND", "TRUE")
+
+torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
+torch.set_num_interop_threads(1)
+torch.set_float32_matmul_precision("high")
 
-# Load the model/tokenizer globally
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_NAME,
-    torch_dtype=torch.float32,  # float32 on CPU
+    torch_dtype=torch.float32,
     device_map=None
 )
 model.eval()
@@ -27,71 +31,50 @@ def build_prompt(message, history, system_message, max_ctx_tokens=1024):
         if a:
             msgs.append({"role": "assistant", "content": a})
     msgs.append({"role": "user", "content": message})
-
-    # Trim to fit the token budget
     while True:
-        text = tokenizer.apply_chat_template(
-            msgs, tokenize=False, add_generation_prompt=True
-        )
+        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
         if len(tokenizer(text, add_special_tokens=False).input_ids) <= max_ctx_tokens:
             return text
-        # Drop the oldest user+assistant pair (keep the system message)
         for i in range(1, len(msgs)):
            if msgs[i]["role"] != "system":
                del msgs[i:i+2]
                break
 
 def respond_stream(message, history, system_message, max_tokens, temperature, top_p):
-    # We send the full prompt on the first message; later turns keep the same flow to keep this example simple.
-    # (Since past_key_values is not exposed after generate() with the HF TextIteratorStreamer,
-    # this version does not carry the KV cache across turns; we use streaming + context trimming for speed.)
     text = build_prompt(message, history, system_message)
     inputs = tokenizer([text], return_tensors="pt").to(model.device)
-
-    do_sample = temperature > 0
+    do_sample = bool(temperature and temperature > 0.0)
     gen_kwargs = dict(
         max_new_tokens=max_tokens,
         do_sample=do_sample,
         top_p=top_p,
         temperature=temperature if do_sample else None,
-        use_cache=True,  # enable the KV cache during decoding
+        use_cache=True,
         eos_token_id=tokenizer.eos_token_id,
         pad_token_id=tokenizer.eos_token_id,
     )
-
-    # skip_prompt=True keeps the prompt from being echoed to the output (requires Transformers >= 4.42)
     try:
-        streamer = TextIteratorStreamer(
-            tokenizer, skip_special_tokens=True, skip_prompt=True
-        )
+        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
     except TypeError:
-        # Backward compatibility: without skip_prompt it still works, but a prompt fragment may appear in the first chunk
         streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
-
-    thread = threading.Thread(
-        target=model.generate,
-        kwargs={**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, "streamer": streamer}
-    )
-
+    thread = threading.Thread(target=model.generate, kwargs={**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None}, "streamer": streamer})
     partial_text = ""
-    start_time = None  # marks when the first token arrives
+    start_time = None
    with torch.inference_mode():
         thread.start()
-        for chunk in streamer:
-            if start_time is None:
-                start_time = time.time()
-            partial_text += chunk
-            yield partial_text  # append streaming: previous text + new chunk
-        thread.join()
-
+        try:
+            for chunk in streamer:
+                if start_time is None:
+                    start_time = time.time()
+                partial_text += chunk
+                yield partial_text
+        finally:
+            thread.join()
     end_time = time.time() if start_time is not None else time.time()
-
-    # Count the generated tokens from the final text
-    gen_token_count = len(tokenizer(partial_text, add_special_tokens=False).input_ids)
     duration = max(1e-6, end_time - start_time) if start_time else 0.0
+    gen_token_count = len(tokenizer(partial_text, add_special_tokens=False).input_ids)
     tps = (gen_token_count / duration) if duration > 0 else 0.0
-
-    yield partial_text + f"\n\n⚡ **Hız:** {tps:.2f} token/sn"
+    yield partial_text + f"\n\n⚡ Hız: {tps:.2f} token/sn"
 
 demo = gr.ChatInterface(
     respond_stream,
@@ -104,5 +87,6 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    # queue() works well for reducing Gradio's stream-buffer errors
+    with torch.inference_mode():
+        _ = model.generate(**tokenizer(["Hi"], return_tensors="pt").to(model.device), max_new_tokens=1, do_sample=False, use_cache=True)
     demo.queue().launch()
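The reworked respond_stream is an ordinary Python generator: it builds the prompt, runs model.generate on a background thread through TextIteratorStreamer, yields the cumulative text chunk by chunk, and appends a tokens-per-second footer at the end. A minimal smoke-test sketch for driving it outside Gradio follows; the file name and the argument values (system message, max_tokens, temperature, top_p) are illustrative assumptions, not values taken from the Space's UI.

# smoke_test.py -- hypothetical helper, not part of this commit.
# Importing app loads the tokenizer/model at module level but does not launch
# the Gradio demo (that is guarded by if __name__ == "__main__").
from app import respond_stream

last = ""
for partial in respond_stream(
    "Hello, who are you?",           # message
    [],                              # history: (user, assistant) pairs, empty on the first turn
    "You are a helpful assistant.",  # system_message (assumed value)
    64,                              # max_tokens (assumed value)
    0.7,                             # temperature (assumed value)
    0.95,                            # top_p (assumed value)
):
    last = partial                   # each yield is the full text so far (append-style streaming)

print(last)  # the final yield ends with the "⚡ Hız: ... token/sn" speed footer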
 
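The new __main__ block also runs a one-token generate before demo.queue().launch(), so the first user request does not pay the cold-start cost of the very first forward pass. Below is a standalone timing sketch to check how much that warm-up saves on a given CPU; it reuses only calls that already appear in the diff, the file name is hypothetical, and the numbers are machine-dependent.

# warmup_check.py -- hypothetical timing sketch, not part of this commit
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "daniel-dona/gemma-3-270m-it"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32, device_map=None)
model.eval()

inputs = tokenizer(["Hi"], return_tensors="pt").to(model.device)
with torch.inference_mode():
    t0 = time.time()
    model.generate(**inputs, max_new_tokens=1, do_sample=False, use_cache=True)  # cold call (what the commit warms up)
    t1 = time.time()
    model.generate(**inputs, max_new_tokens=1, do_sample=False, use_cache=True)  # warm call
    t2 = time.time()

print(f"cold: {t1 - t0:.3f}s  warm: {t2 - t1:.3f}s")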