Spaces:
Runtime error
Runtime error
File size: 3,502 Bytes
8824f88 dd2b7a2 8824f88 cb975cd 749d210 8824f88 dd2b7a2 8824f88 cebc4e3 8824f88 89d85b1 8824f88 46f6bfb d027892 e39d67a 7cb6017 8824f88 bb0bba9 dd2b7a2 8824f88 cb975cd dd2b7a2 8824f88 749d210 8824f88 749d210 8824f88 dd2b7a2 0737a9d 34353a1 0737a9d 1e91ed1 dd2b7a2 8824f88 dd2b7a2 8824f88 dd2b7a2 8824f88 dd2b7a2 8824f88 cb975cd dd2b7a2 cb975cd dd2b7a2 749d210 dd2b7a2 89d85b1 e6d653e 89d85b1 8824f88 dd2b7a2 8824f88 dd2b7a2 8824f88 fc450cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
#!/usr/bin/env python
import os
import re
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Markdown shown above the chat UI; a CPU warning is appended below when no GPU is found.
DESCRIPTION = "# ICONN Lite Chat"

# Module-level value; generate() declares its own ``top_k`` parameter, so
# nothing in the visible file reads this name.
top_k: int = 50

# Token budgets for generation.
MAX_MAX_NEW_TOKENS = 100_000_000
DEFAULT_MAX_NEW_TOKENS = 10_240
# Prompt-length cap (in tokens); overridable through the environment.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

if torch.cuda.is_available():
    # The model and tokenizer are only created when a CUDA device is present.
    model_id = "ICONNAI/ICONN-1-Mini-Beta"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    DESCRIPTION += "\n<p class='warning'>Running on CPU 🥶 This demo does not work on CPU.</p>"
# Matches one complete <think>...</think> span; DOTALL lets it cross newlines.
_THINK_BLOCK_RE = re.compile(r"<think>\s*(.*?)\s*</think>", flags=re.DOTALL)


def _render_thinking_block(match: re.Match) -> str:
    """Render one matched ``<think>`` span as a collapsible ``<details>`` element."""
    import html  # local import keeps this block self-contained

    # The captured text is free-form model output placed inside <pre>; escape
    # it so characters such as "<" and "&" are shown literally instead of
    # being parsed as markup.
    content = html.escape(match.group(1).strip(), quote=False)
    return (
        "<details class='thinking-block'>"
        "<summary>💭 Thinking...</summary>"
        f"<div class='thinking-content'><pre>{content}</pre></div>"
        "</details>"
    )


def wrap_thinking_blocks(text: str) -> str:
    """Replace every closed ``<think>...</think>`` span with a ``<details>`` block.

    Args:
        text: Text that may contain zero or more ``<think>`` spans.

    Returns:
        ``text`` with each complete span replaced by a collapsible HTML
        block holding the stripped, HTML-escaped reasoning. Text outside the
        spans, and a ``<think>`` tag with no closing tag, are returned
        unchanged.
    """
    return _THINK_BLOCK_RE.sub(_render_thinking_block, text)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The full reply accumulated so far, after ``wrap_thinking_blocks``,
        once for every chunk the streamer produces.
    """
    conversation = [*chat_history, {"role": "user", "content": message}]
    # add_generation_prompt=True makes the template end with the assistant-turn
    # header, so the model writes a reply rather than continuing the prompt.
    # enable_thinking is forwarded to the chat template as an extra keyword
    # (presumably switches on <think> output -- confirm against the model's template).
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=True,
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits the cap.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs in a background thread; this generator consumes the
    # streamer and yields partial output as it arrives.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield wrap_thinking_blocks("".join(outputs))
# Chat UI. The values of ``additional_inputs`` are handed to ``generate`` after
# (message, chat_history), so the sliders are listed in the same order as
# generate's remaining parameters: max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    stop_btn=None,
    examples=[
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
)

if __name__ == "__main__":
    # Queue requests (at most 20 waiting) and start the app.
    demo.queue(max_size=20).launch()
|