Spaces:
Enderchef
/
Runtime error

File size: 3,502 Bytes
8824f88
dd2b7a2
8824f88
cb975cd
749d210
8824f88
dd2b7a2
8824f88
 
 
 
 
cebc4e3
8824f88
 
89d85b1
8824f88
46f6bfb
d027892
e39d67a
7cb6017
8824f88
 
bb0bba9
dd2b7a2
 
 
 
 
8824f88
cb975cd
 
 
 
 
 
 
 
 
 
 
dd2b7a2
8824f88
 
 
749d210
8824f88
 
 
 
 
 
749d210
8824f88
dd2b7a2
 
 
0737a9d
 
34353a1
0737a9d
1e91ed1
dd2b7a2
8824f88
dd2b7a2
 
8824f88
dd2b7a2
 
 
 
 
 
8824f88
dd2b7a2
 
8824f88
cb975cd
 
dd2b7a2
 
cb975cd
dd2b7a2
 
749d210
dd2b7a2
 
89d85b1
 
e6d653e
89d85b1
8824f88
dd2b7a2
 
8824f88
 
 
 
 
dd2b7a2
 
 
8824f88
 
 
fc450cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python

import html
import os
import re
from collections.abc import Iterator
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Markdown/HTML header rendered above the chat UI. Without CUDA the model
# below is never loaded, so a warning is appended for visitors.
DESCRIPTION = "# ICONN Lite Chat" + (
    ""
    if torch.cuda.is_available()
    else "\n<p class='warning'>Running on CPU 🥶 This demo does not work on CPU.</p>"
)

# NOTE(review): this module-level top_k is not read anywhere in this file;
# generate() declares its own top_k parameter.
top_k: int = 50
MAX_MAX_NEW_TOKENS = 100_000_000
DEFAULT_MAX_NEW_TOKENS = 10_240
# Longest prompt (in tokens) kept before left-truncation; overridable via env var.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Model and tokenizer exist only on GPU hosts.
if torch.cuda.is_available():
    model_id = "ICONNAI/ICONN-1-Mini-Beta"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)


def wrap_thinking_blocks(text: str) -> str:
    """Render every ``<think>...</think>`` span in *text* as a collapsible HTML block.

    Each closed span becomes a ``<details class='thinking-block'>`` element
    whose summary reads "💭 Thinking..." and whose body holds the stripped
    inner text inside ``<pre>``. Text outside the spans is returned unchanged,
    and an unterminated ``<think>`` (e.g. mid-stream) is left as-is.

    The inner text is HTML-escaped because it is raw model output. Without
    escaping, a ``<``/``&`` or tag-like text inside the reasoning would be
    interpreted as markup by the chat renderer and could break the layout.
    """
    def replacer(match: re.Match[str]) -> str:
        # quote=False: apostrophes/quotes are harmless inside <pre> text.
        content = html.escape(match.group(1).strip(), quote=False)
        return (
            "<details class='thinking-block'>"
            "<summary>💭 Thinking...</summary>"
            f"<div class='thinking-content'><pre>{content}</pre></div>"
            "</details>"
        )

    return re.sub(r"<think>\s*(.*?)\s*</think>", replacer, text, flags=re.DOTALL)


# Matches the collapsible block emitted for <think> spans in assistant replies.
_RENDERED_THINKING_RE = re.compile(r"<details class='thinking-block'>.*?</details>", flags=re.DOTALL)


def _history_for_model(chat_history: list[dict]) -> list[dict]:
    """Return *chat_history* as plain role/content dicts with rendered thinking removed.

    Assistant turns in the Gradio history hold the HTML that this app yielded
    for display. Feeding that markup back to the model as its own prior output
    would pollute the prompt, so it is stripped here.
    """
    cleaned = []
    for turn in chat_history:
        content = turn["content"]
        if turn["role"] == "assistant" and isinstance(content, str):
            content = _RENDERED_THINKING_RE.sub("", content).strip()
        cleaned.append({"role": turn["role"], "content": content})
    return cleaned


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled model reply to *message*.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format).
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Yields:
        The full reply generated so far, with ``<think>`` spans rendered by
        ``wrap_thinking_blocks``.
    """
    conversation = [*_history_for_model(chat_history), {"role": "user", "content": message}]

    # add_generation_prompt=True appends the assistant-turn header so the model
    # starts a reply instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt", enable_thinking=True
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; older context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        # UI sliders may deliver numbers as floats; generation expects integers here.
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        top_p=top_p,
        top_k=int(top_k),
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks until done, so it runs in a worker thread while this
    # generator drains the streamer and yields partial text to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    pieces: list[str] = []
    for chunk in streamer:
        pieces.append(chunk)
        yield wrap_thinking_blocks("".join(pieces))
    worker.join()


# Chat UI. ChatInterface calls fn(message, history, *additional_inputs), passing
# the additional inputs positionally. Their order must therefore match
# generate()'s parameters after chat_history:
# max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    stop_btn=None,
    examples=[
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    type="messages",
    description=DESCRIPTION,
    css_paths="style.css",
)

if __name__ == "__main__":
    # Enable request queuing (at most 20 pending events), then serve the app.
    demo.queue(max_size=20)
    demo.launch()