import os
import logging
from huggingface_hub import InferenceClient
import gradio as gr
from requests.exceptions import ConnectionError  # note: intentionally shadows the builtin ConnectionError

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Hugging Face Inference Client
try:
    client = InferenceClient(
        model="mistralai/Mistral-7B-Instruct-v0.3",
        token=os.getenv("HF_TOKEN"),  # Ensure HF_TOKEN is set in your environment
        timeout=30,
    )
except Exception as e:
    logger.error(f"Failed to initialize InferenceClient: {e}")
    raise

def format_prompt(message, history):
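    """Build a Mistral-instruct prompt: each user turn is wrapped in
    [INST]...[/INST] and each assistant reply is terminated with </s>."""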
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
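    """Stream a completion from the Inference API, yielding the growing
    output string so gr.ChatInterface can render tokens as they arrive."""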
    try:
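        # Keep temperature strictly positive; the generation backend rejects
        # a temperature of 0 when do_sample=True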
        temperature = float(temperature)
        if temperature < 1e-2:
            temperature = 1e-2
        top_p = float(top_p)

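        # do_sample=True enables sampling; the fixed seed makes repeated
        # identical requests reproducible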
        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            seed=42,
        )

        formatted_prompt = format_prompt(prompt, history)
        logger.info("Sending request to Hugging Face API")

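        # stream=True returns an iterator of token chunks; with details=True
        # each chunk carries token metadata, so the text is at response.token.text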
        stream = client.text_generation(
            formatted_prompt,
            **generate_kwargs,
            stream=True,
            details=True,
            return_full_text=False,
        )
        output = ""

        for response in stream:
            output += response.token.text
            yield output

    except ConnectionError as e:
        logger.error(f"Network error: {e}")
        yield "Error: Unable to connect to the Hugging Face API. Please check your internet connection and try again."
    except Exception as e:
        logger.error(f"Error during text generation: {e}")
        yield f"Error: {str(e)}"

# Define additional inputs for Gradio interface
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
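
# Note: gr.ChatInterface forwards the slider values to generate(), so the
# slider defaults take precedence over the function's argument defaults above.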

# Create a Chatbot object
chatbot = gr.Chatbot(height=450, layout="bubble")

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.HTML("<h1><center>🤖 Mistral-7B-Chat 💬</center></h1>")
    gr.ChatInterface(
        fn=generate,
        chatbot=chatbot,
        additional_inputs=additional_inputs,
        examples=[
            ["Give me the code for Binary Search in C++"],
            ["Explain the chapter of The Grand Inquisitor from The Brothers Karamazov."],
            ["Explain Newton's second law."],
        ],
    )

if __name__ == "__main__":
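    # HF_TOKEN must be set in the environment before launching, e.g.
    #   HF_TOKEN=hf_xxx python app.py   ("app.py" is an assumed filename)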
    logger.info("Starting Gradio application")
    demo.launch()