Zubiiiiiii294 committed on
Commit 22f1de6 · verified · 1 Parent(s): 3cfc12b

Update app.py

Files changed (1):
  1. app.py +31 -22
app.py CHANGED
@@ -1,33 +1,42 @@
 import os
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
 
-model_id = os.getenv("MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.3")
+model_id = "mistralai/Mistral-7B-Instruct-v0.3"
 
-tokenizer = AutoTokenizer.from_pretrained(model_id)
+# Load tokenizer and model with correct settings
+tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)  # Important fix
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
-    device_map="auto",
-    load_in_4bit=True,
-    trust_remote_code=True
+    torch_dtype=torch.float16,
+    device_map="auto"
 )
 
-streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-def generate_response(message, history):
-    messages = history + [[message, ""]]
-    prompt = ""
-    for user, bot in messages:
-        prompt += f"<|user|>\n{user}\n<|assistant|>\n{bot}\n"
-    prompt += f"<|user|>\n{message}\n<|assistant|>\n"
-
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    output = model.generate(**inputs, max_new_tokens=512, streamer=streamer)
-    response = tokenizer.decode(output[0], skip_special_tokens=True).split("<|assistant|>")[-1].strip()
-
-    return response
-
-demo = gr.ChatInterface(fn=generate_response, title="Vynix AI", theme="soft")
-
-if _name_ == "_main_":
-    demo.launch()
+# Create generation pipeline
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=512,
+    do_sample=True,
+    top_k=50,
+    top_p=0.95,
+    temperature=0.7,
+    repetition_penalty=1.1
+)
+
+# Define Gradio UI
+def chat_fn(message, history):
+    prompt = f"[INST] {message.strip()} [/INST]"
+    output = pipe(prompt)[0]['generated_text']
+    return output.replace(prompt, "").strip()
+
+chatbot = gr.ChatInterface(
+    fn=chat_fn,
+    title="🤖 Vynix AI - Powered by Mistral",
+    description="Ask anything! Built using Mistral-7B-Instruct-v0.3.",
+)
+
+# Launch the app
+chatbot.launch()
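
Note: the new `chat_fn` receives the conversation `history` from `gr.ChatInterface` but never uses it, so each turn is answered without context. A minimal sketch of one way to fold the history back into Mistral's `[INST]` prompt format, assuming the classic Gradio pair-style history of `(user, assistant)` tuples; the `build_prompt` helper is hypothetical and not part of this commit:

def build_prompt(message, history):
    # Hypothetical helper: replays prior (user, assistant) turns in
    # Mistral's instruction format before appending the new message.
    prompt = ""
    for user, assistant in history:
        prompt += f"[INST] {user.strip()} [/INST] {assistant.strip()}</s>"
    prompt += f"[INST] {message.strip()} [/INST]"
    return prompt

def chat_fn(message, history):
    prompt = build_prompt(message, history)
    output = pipe(prompt)[0]["generated_text"]
    # The pipeline echoes the prompt by default; slicing it off is more
    # robust than str.replace when the prompt text recurs in the output.
    return output[len(prompt):].strip()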