# app.py — last commit "Update app.py" (a120af4, verified) by geekyrakshit
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Model identifier passed to both from_pretrained calls below.
model_id = "PhysicsWallahAI/Aryabhata-1.0"
# Tokenizer and model are loaded once at import time and reused by greet.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): no dtype or device is specified here, so library defaults
# apply — confirm that is intended given the @spaces.GPU decorator on greet.
model = AutoModelForCausalLM.from_pretrained(model_id)
# Strings handed to GenerationConfig as stop strings and to strip_bad_tokens
# as the candidate suffixes to trim from the decoded text.
# NOTE(review): "<im_start|>" lacks the "|" after "<" that the other markers
# have — confirm whether that spelling is intentional.
# NOTE(review): the fourth and fifth entries begin with an invisible character
# before the visible text — confirm it is intentional before editing them.
stop_strings = [
    "<|im_end|>",
    "<|end|>",
    "<im_start|>",
    "⁠```python\n",
    "⁠<|im_start|>",
    "]}}]}}]",
]
def strip_bad_tokens(s, stop_strings):
    """Remove one trailing stop string from ``s``, if present.

    Candidates are checked in the order given and only the first match is
    removed.

    Args:
        s: Text to clean up.
        stop_strings: Iterable of candidate suffixes.

    Returns:
        ``s`` without the first matching suffix, or ``s`` unchanged when no
        candidate matches.
    """
    for suffix in stop_strings:
        # An empty suffix "matches" every string, and s[:-0] is s[:0] (the
        # empty string), so skipping it avoids wiping the whole text.
        if suffix and s.endswith(suffix):
            return s[: -len(suffix)]
    return s
# Generation settings passed to model.generate in greet: a cap of 4096 new
# tokens, with the stop_strings list supplied as the stop strings.
generation_config = GenerationConfig(max_new_tokens=4096, stop_strings=stop_strings)
@spaces.GPU
def greet(prompt: str) -> str:
    """Generate the model's reply to ``prompt``.

    The prompt is placed in a two-message chat (a fixed system instruction
    plus the user's text), rendered with the tokenizer's chat template, and
    passed to ``model.generate`` with the module-level generation config.

    Args:
        prompt: The user's question.

    Returns:
        The decoded completion only (the rendered prompt is not echoed), with
        one trailing stop string removed by ``strip_bad_tokens``.
    """
    messages = [
        {
            "role": "system",
            "content": "Think step-by-step; put only the final answer inside \\boxed{}.",
        },
        {"role": "user", "content": prompt},
    ]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Keep the input tensors on whatever device the model weights live on.
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs, generation_config=generation_config, tokenizer=tokenizer
    )
    # For a causal LM, generate() returns the prompt tokens followed by the
    # new tokens; drop the prompt so the system/user text is not echoed back.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = outputs[0][prompt_length:]
    return strip_bad_tokens(
        tokenizer.decode(completion_ids, skip_special_tokens=True), stop_strings
    )
# Gradio interface that feeds a text input to greet and shows its text output.
# NOTE(review): the title spells "Aryabhatta" while model_id uses "Aryabhata" —
# confirm which spelling is intended.
demo = gr.Interface(fn=greet, inputs="text", outputs="text", title="Aryabhatta Demo")
# Start serving the interface.
demo.launch()