import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, TorchAoConfig
from threading import Thread
import os, subprocess, torch
from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Float8DynamicActivationFloat8WeightConfig
from torchao.dtypes import Int4CPULayout
#subprocess.run("pip list", shell=True)
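# Static-KV-cache ("compile") path is used only on CPU; on CUDA the default cache is used.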
IS_COMPILE = not torch.cuda.is_available()
device = "cuda" if torch.cuda.is_available() else "cpu"
# https://huggingface.co/docs/transformers/en/quantization/torchao?examples-CPU=int8-dynamic-and-weight-only
if torch.cuda.is_available():
    quant_config = Float8DynamicActivationFloat8WeightConfig()
else:
    #quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
    quant_config = Int8DynamicActivationInt8WeightConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
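# Weights are quantized on the fly during from_pretrained via the TorchAoConfig above.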
#checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
checkpoint = "unsloth/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).to(device)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map=device,
    quantization_config=quantization_config,
).eval()
if IS_COMPILE:
    model.generation_config.cache_implementation = "static"
    # Warm-up generation so the one-time static-cache setup cost is paid at startup,
    # not on the first user request.
    input_text = "Warming up."
    input_ids = tokenizer(input_text, return_tensors="pt").to(device)
    output = model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
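# spaces.GPU accepts a callable for `duration`: it receives the same arguments as the
# decorated function and returns the ZeroGPU allocation in seconds (the "Duration" slider).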
def get_duration(message, history, system_message, max_tokens, temperature, top_p, duration):
    return duration
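# Streaming handler: generate() runs in a background thread while TextIteratorStreamer yields decoded text.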
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def respond_stream(message, history, system_message, max_tokens, temperature, top_p, duration):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
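    # skip_prompt keeps the echoed prompt out of the stream; skip_special_tokens drops markers such as <eos>.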
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        #attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
        output_scores=False,
    )
    if IS_COMPILE:
        gen_kwargs["cache_implementation"] = "static"
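    # generate() blocks until completion, so run it in a worker thread and consume the streamer here.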
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    for piece in streamer:
        partial += piece
        yield partial
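# Non-streaming handler wired into the ChatInterface below; returns the full completion at once.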
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def respond(message, history, system_message, max_tokens, temperature, top_p, duration):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        #attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
        output_scores=False,
    )
    if IS_COMPILE:
        gen_kwargs["cache_implementation"] = "static"
    outputs = model.generate(**gen_kwargs)
    # Slice off the prompt tokens so only the newly generated completion is decoded.
    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1, maximum=360, value=30, step=1, label="Duration (s)"),
    ],
)
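# Render the ChatInterface inside a Blocks container.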
with gr.Blocks() as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.queue().launch()