import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
device_map = "cuda"
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # required: the Llama 3 checkpoint is gated

# Load the model and tokenizer once at startup instead of on every request.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)

def preprocess_messages(message: str, history: list, system_prompt: str) -> str:
    # Rebuild the full conversation, including prior turns
    # (history arrives as (user, assistant) pairs in Gradio's default tuple format).
    messages = [{"role": "system", "content": system_prompt}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def generate_text(prompt: str, max_new_tokens: int, temperature: float) -> str:
    # The chat template already inserted <|begin_of_text|>, so skip the extra BOS.
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    # Llama 3 ends an assistant turn with <|eot_id|> as well as the regular EOS token.
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    outputs = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=max(temperature, 0.1),  # temperature=0 is invalid with do_sample=True
        top_p=0.9,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)

def chat_function(
    message: str,
    history: list,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
) -> str:
    prompt = preprocess_messages(message, history, system_prompt)
    return generate_text(prompt, max_new_tokens, temperature)

gr.ChatInterface(
    chat_function,
    chatbot=gr.Chatbot(height=400),
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    title="LLAMA3 Chat",
    description="Chat with Llama 3",
    theme="soft",
    additional_inputs=[
        gr.Textbox("You shall answer all questions as a very smart AI", label="System Prompt"),
        gr.Slider(512, 4096, value=1024, label="Max New Tokens"),
        gr.Slider(0.0, 1.0, value=0.6, label="Temperature"),
    ],
).launch(debug=True)
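Note: Meta-Llama-3-8B-Instruct is a gated model, so the app assumes HF_TOKEN is available in the environment (for a Space, add it as a repository secret) so that from_pretrained can download the weights; without it, both loading calls will fail and the Space will show a runtime error.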