hotmemeh committed
Commit 144f336 · verified · 1 Parent(s): 741e6dd

Update app.py

Files changed (1)
  1. app.py +23 -16
app.py CHANGED
@@ -1,36 +1,43 @@
 import gradio as gr
-from transformers import pipeline
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
-# Auto-select model based on device
-if torch.cuda.is_available():
 MODEL_NAME = "darkc0de/XortronCriminalComputingConfig"
-    device = 0
-else:
-    MODEL_NAME = "gpt2"  # CPU fallback
-    device = -1
 
-print(f"Loading model: {MODEL_NAME} on {'GPU' if device == 0 else 'CPU'}")
+print(f"Loading model: {MODEL_NAME}")
 
-# Load Hugging Face pipeline
-generator = pipeline("text-generation", model=MODEL_NAME, device=device)
+# Load tokenizer & model
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
-# Streaming response function
+# device_map="auto" lets it use GPU if available, otherwise CPU (warning: very slow on CPU)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    device_map="auto",
+    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    low_cpu_mem_usage=True,
+)
+
+generator = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device=0 if torch.cuda.is_available() else -1,
+)
+
+# Streaming response
 def respond(message, history):
     output = generator(
         message,
-        max_new_tokens=256,  # use this instead of max_length
-        num_return_sequences=1,
+        max_new_tokens=256,
         do_sample=True,
         temperature=0.7,
-        truncation=True,  # fixes truncation warning
+        truncation=True,
     )[0]["generated_text"]
 
-    # Stream output in chunks
     for i in range(0, len(output), 20):
         yield {"role": "assistant", "content": output[: i + 20]}
 
-# Build the Gradio chat
+# Build Gradio chat
 chat = gr.ChatInterface(
     fn=respond,
     type="messages",
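
Because the new version loads the model with device_map="auto" and then also passes device= to pipeline(...), the two placement mechanisms can conflict: depending on the installed transformers version, creating a pipeline with an explicit device for an accelerate-dispatched model may emit a warning or raise a ValueError. A minimal sketch of the safer construction, assuming the same model and tokenizer objects as in the diff above (this is an illustrative variant, not part of the commit):

# Sketch: let the pipeline inherit the placement chosen by device_map="auto"
# instead of passing device= explicitly; some transformers versions refuse an
# explicit device when the model was dispatched by accelerate.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)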
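
Note also that respond() generates the full completion first and only then yields it in 20-character slices, so the chat UI starts updating only after generation has finished. For token-by-token streaming, transformers provides TextIteratorStreamer, which runs generate() in a background thread and yields text as it is produced. The sketch below is a hypothetical alternative (respond_streaming is not part of the committed file) that assumes the same model and tokenizer objects and yields the same messages-format dictionaries, so gr.ChatInterface could consume it unchanged:

from threading import Thread

from transformers import TextIteratorStreamer


def respond_streaming(message, history):
    # Hypothetical variant of respond(): stream partial text as tokens arrive
    # instead of slicing the finished string into fixed-size chunks.
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            streamer=streamer,
        ),
    ).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield {"role": "assistant", "content": partial}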