John6666 committed
Commit c7c4c72 · verified · 1 Parent(s): 8b1e84b

Upload 2 files

Files changed (2)
  1. app.py +32 -14
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,21 +2,35 @@ import spaces
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, TorchAoConfig
 from threading import Thread
-import torch
-from torchao.quantization import Int8DynamicActivationInt8WeightConfig
-import subprocess
+import os, subprocess, torch
+from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Float8DynamicActivationFloat8WeightConfig
+from torchao.dtypes import Int4CPULayout
 
-subprocess.run("pip list", shell=True)
 
-quant_config = Int8DynamicActivationInt8WeightConfig()
+#subprocess.run("pip list", shell=True)
+
+IS_COMPILE = False if torch.cuda.is_available() else True
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+# https://huggingface.co/docs/transformers/en/quantization/torchao?examples-CPU=int8-dynamic-and-weight-only
+if torch.cuda.is_available():
+    quant_config = Float8DynamicActivationFloat8WeightConfig()
+else:
+    #quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
+    quant_config = Int8DynamicActivationInt8WeightConfig()
+
 quantization_config = TorchAoConfig(quant_type=quant_config)
 #checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
 checkpoint = "unsloth/gemma-3-4b-it"
-device = "cuda" if torch.cuda.is_available() else "cpu"
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-                                             device_map=device, quantization_config=quantization_config)
 #model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).to(device)
+model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+                                             device_map=device, quantization_config=quantization_config).eval()
+if IS_COMPILE:
+    model.generation_config.cache_implementation = "static"
+    input_text = "Warming up."
+    input_ids = tokenizer(input_text, return_tensors="pt").to(device)
+    output = model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
 
 def get_duration(message, history, system_message, max_tokens, temperature, top_p, duration):
     return duration
@@ -32,7 +46,7 @@ def respond_stream(message, history, system_message, max_tokens, temperature, to
         add_generation_prompt=True,
         return_tensors="pt",
         return_dict=True,
-    ).to(device)
+    ).to(model.device)
 
     streamer = TextIteratorStreamer(
         tokenizer, skip_prompt=True, skip_special_tokens=True
@@ -46,8 +60,10 @@ def respond_stream(message, history, system_message, max_tokens, temperature, to
         temperature=temperature,
         top_p=top_p,
         eos_token_id=tokenizer.eos_token_id,
-        cache_implementation="static",
+        num_beams=1,
+        output_scores=False,
     )
+    if IS_COMPILE: gen_kwargs["cache_implementation"] = "static"
     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
 
@@ -65,12 +81,11 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, du
         messages,
         tokenize=True,
         add_generation_prompt=True,
-        padding=True,
         return_tensors="pt",
         return_dict=True,
-    ).to(device)
+    ).to(model.device)
 
-    outputs = model.generate(
+    gen_kwargs = dict(
         input_ids=inputs["input_ids"],
         #attention_mask=inputs["attention_mask"],
         max_new_tokens=max_tokens,
@@ -78,8 +93,11 @@ def respond(message, history, system_message, max_tokens, temperature, top_p, du
         temperature=temperature,
         top_p=top_p,
         eos_token_id=tokenizer.eos_token_id,
-        cache_implementation="static",
+        num_beams=1,
+        output_scores=False,
     )
+    if IS_COMPILE: gen_kwargs["cache_implementation"] = "static"
+    outputs = model.generate(**gen_kwargs)
 
     gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
     return tokenizer.decode(gen_ids, skip_special_tokens=True)
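
The net effect of the app.py changes is a hardware-dependent torchao setup: Float8 dynamic activation/weight quantization when CUDA is available, Int8 dynamic activation with Int8 weights on CPU, an eval-mode model, and a static KV cache plus a short warm-up generation on the CPU (IS_COMPILE) path before the first request. Below is a minimal standalone sketch of that loading pattern, not the Space itself; it swaps in the smaller commented-out checkpoint (HuggingFaceTB/SmolLM2-135M-Instruct) so it can run on CPU, and the prompt and max_new_tokens values are purely illustrative.

# Minimal sketch of the quantized-loading pattern introduced in this commit.
# Assumes a transformers version whose TorchAoConfig accepts a torchao config
# object as quant_type (as app.py does) and a matching recent torchao release.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# GPU: float8 dynamic activation/weight quantization; CPU: int8 dynamic.
if torch.cuda.is_available():
    quant_config = Float8DynamicActivationFloat8WeightConfig()
else:
    quant_config = Int8DynamicActivationInt8WeightConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"  # smaller stand-in for unsloth/gemma-3-4b-it
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map=device,
    quantization_config=quantization_config,
).eval()

# On the CPU path the commit pins a static KV cache and runs a short
# warm-up generation so compiled kernels are built before the first request.
if not torch.cuda.is_available():
    model.generation_config.cache_implementation = "static"
    warmup = tokenizer("Warming up.", return_tensors="pt").to(device)
    model.generate(**warmup, max_new_tokens=10, cache_implementation="static")

# Single non-streaming reply, mirroring respond() in app.py.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Say hello in one sentence."}],
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)
outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=32)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))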
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 huggingface_hub[hf_xet]
 torch
 torchao
+triton
 transformers
 accelerate
 peft
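
The only requirements.txt change is the new triton entry. The commit does not say why it was added; a plausible reading is that the static-cache/compile path enabled in app.py pulls in Triton-backed kernels, but that is an inference rather than something stated in the diff. A small diagnostic snippet for checking that the dependency stack resolves, using the package names exactly as listed above:

# Print versions of the packages pinned in requirements.txt; purely diagnostic.
import importlib

for pkg in ("huggingface_hub", "torch", "torchao", "triton", "transformers", "accelerate", "peft"):
    try:
        mod = importlib.import_module(pkg)
        print(pkg, getattr(mod, "__version__", "unknown version"))
    except ImportError as exc:  # e.g. triton may be absent on some CPU-only images
        print(pkg, "not importable:", exc)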