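# Page header captured together with this file (not code; kept as a comment so the file parses):
# 雷娃 · add system prompt · a03b7d1 · raw · history blame · 2.85 kB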
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
import json
# Load the model and tokenizer once at import time; every chat request reuses them.
model_name = "inclusionAI/Ling-mini-2.0"
# trust_remote_code is passed to BOTH loaders so the tokenizer is resolved the
# same way as the model (the model load below already opts in to remote code).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",  # use the dtype stored in the checkpoint config
    device_map="auto",  # let transformers/accelerate place weights on available devices
    trust_remote_code=True,  # presumably the repo ships custom modelling code — keep in sync with tokenizer
).eval()  # inference only: disable dropout etc.
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply from the locally loaded model.

    Used as the ``gr.ChatInterface`` callback: Gradio passes the new user
    message, the prior conversation and then the values of the
    ``additional_inputs`` widgets, in order.

    Args:
        message: The latest user message.
        history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (messages-style history).
        system_message: System prompt; when empty, a built-in default
            identity prompt is used instead.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling cumulative probability cutoff.

    Yields:
        The reply text accumulated so far; each yield extends the previous one.
    """
    # Fall back to the default identity prompt when the textbox is left empty.
    if not system_message:
        system_message = "## 你是谁\n\n我是百灵(Ling),一个由蚂蚁集团(Ant Group) 开发的AI智能助手"
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})
    # Debug log of the full prompt sent to the model (system + history + new turn).
    print(f"system_prompt: {json.dumps(messages, ensure_ascii=False, indent=2)}")

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer(
        [text], return_tensors="pt", return_token_type_ids=False
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Forward the UI settings to generate(); previously max_new_tokens was
    # hard-coded to 512 and temperature/top_p were silently ignored.
    generate_kwargs = dict(
        **model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,  # temperature/top_p only take effect when sampling
        temperature=float(temperature),
        top_p=float(top_p),
    )

    # Run generation in a background thread so tokens can be streamed out here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Accumulate and yield the text as the streamer produces it.
    acc_text = ""
    for text_token in streamer:
        acc_text += text_token
        yield acc_text
    # The streamer is exhausted once generation finishes; reap the worker thread.
    thread.join()


# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
# Chat UI. ``type="messages"`` makes Gradio pass history as role/content dicts.
# The ``additional_inputs`` widgets are forwarded to ``respond`` positionally
# after (message, history): system message, max new tokens, temperature, top-p.
# An empty system message makes ``respond`` substitute its default prompt.
chatbot = gr.ChatInterface(
    fn=respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(label="System message", value=""),
        gr.Slider(label="Max new tokens", minimum=1, maximum=32000, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)
# Page layout: a sidebar holding Gradio's login button, with the chat UI
# rendered beside it.
# NOTE(review): respond() takes no OAuth token, so the login is not used by generation.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()