import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, TorchAoConfig
from threading import Thread
import os, subprocess, torch
from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Float8DynamicActivationFloat8WeightConfig
from torchao.dtypes import Int4CPULayout

#subprocess.run("pip list", shell=True)

# Use a compiled static KV cache only on CPU; on CUDA the model runs eagerly.
IS_COMPILE = not torch.cuda.is_available()
device = "cuda" if torch.cuda.is_available() else "cpu"

# https://huggingface.co/docs/transformers/en/quantization/torchao?examples-CPU=int8-dynamic-and-weight-only
if torch.cuda.is_available():
    quant_config = Float8DynamicActivationFloat8WeightConfig()
else:
    #quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
    quant_config = Int8DynamicActivationInt8WeightConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)

#checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
checkpoint = "unsloth/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32).to(device)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map=device,
    quantization_config=quantization_config,
).eval()

if IS_COMPILE:
    # Warm up the static cache so the first user request doesn't pay the compile cost.
    model.generation_config.cache_implementation = "static"
    input_text = "Warming up."
    input_ids = tokenizer(input_text, return_tensors="pt").to(device)
    output = model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")


# ZeroGPU dynamic duration: this callable receives the same arguments as the
# decorated function and returns the number of GPU seconds to request.
def get_duration(message, history, system_message, max_tokens, temperature, top_p, duration):
    return duration


@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def respond_stream(message, history, system_message, max_tokens, temperature, top_p, duration):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        #attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
        output_scores=False,
    )
    if IS_COMPILE:
        gen_kwargs["cache_implementation"] = "static"
    # Run generate in a background thread; the streamer yields tokens as they arrive.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    partial = ""
    for piece in streamer:
        partial += piece
        yield partial
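
# A minimal local smoke test for respond_stream (a sketch, not part of the app:
# the DEMO_SMOKE_TEST flag and the example history below are made up here).
# `history` arrives in Gradio's "messages" format, i.e. a list of role/content
# dicts, matching type="messages" on the ChatInterface further down.
if os.getenv("DEMO_SMOKE_TEST"):
    _history = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello! How can I help you today?"},
    ]
    # respond_stream yields cumulative partial strings; keep only the last one.
    _last = ""
    for _last in respond_stream("Tell me a joke.", _history,
                                "You are a friendly Chatbot.", 64, 0.7, 0.95, 30):
        pass
    print(_last)
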
@spaces.GPU(duration=get_duration)
@torch.inference_mode()
def respond(message, history, system_message, max_tokens, temperature, top_p, duration):
    messages = [{"role": "system", "content": system_message}] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    gen_kwargs = dict(
        input_ids=inputs["input_ids"],
        #attention_mask=inputs["attention_mask"],
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,
        output_scores=False,
    )
    if IS_COMPILE:
        gen_kwargs["cache_implementation"] = "static"
    outputs = model.generate(**gen_kwargs)
    # Strip the prompt tokens and decode only the newly generated continuation.
    gen_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_ids, skip_special_tokens=True)


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
        gr.Slider(minimum=1, maximum=360, value=30, step=1, label="Duration"),
    ],
)

with gr.Blocks() as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.queue().launch()
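
# To stream tokens to the UI instead of returning one final string, the
# interface could be pointed at the generator handler instead (a sketch;
# ChatInterface accepts generator functions, and respond_stream takes the
# same signature and the same additional_inputs as the blocking version):
#
# chatbot = gr.ChatInterface(
#     respond_stream,
#     type="messages",
#     additional_inputs=[...],  # identical textbox/sliders as above
# )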