Spaces: Running on Zero

Commit · 593a8e7
Parent(s): 0226e6c

separate bitnet from generic pipeline

Files changed: utils/models.py (+26 -2)
utils/models.py CHANGED

@@ -11,7 +11,7 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     StoppingCriteria,
-
+    BitNetForCausalLM
 )
 from .prompts import format_rag_prompt
 from .shared import generation_interrupt
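The new import assumes a transformers release that ships BitNetForCausalLM; on older versions the module fails at import time. A minimal, hypothetical guard (not part of this commit) could defer that failure to the bitnet branch instead:

# Hypothetical guard, not in this commit: only recent transformers releases
# expose BitNetForCausalLM, so tolerate its absence at import time.
try:
    from transformers import BitNetForCausalLM
except ImportError:
    BitNetForCausalLM = None  # the "bitnet" branch would then need to raise a clear error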
@@ -156,7 +156,14 @@ def run_inference(model_name, context, question):
 
     print("REACHED HERE BEFORE pipe")
     print(f"Loading model {model_name}...")
-    if "
+    if "bitnet" in model_name.lower():
+        bitnet_model = BitNetForCausalLM.from_pretrained(
+            model_name,
+            device_map="cuda",
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+    elif "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
             model=model_name,
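For orientation, here is a minimal standalone sketch of the loading path the new branch takes, outside the Space's run_inference wiring. The checkpoint id microsoft/bitnet-b1.58-2B-4T is only an illustrative assumption, not something this commit references:

import torch
from transformers import AutoTokenizer, BitNetForCausalLM

model_name = "microsoft/bitnet-b1.58-2B-4T"  # illustrative checkpoint, assumed

# Load the BitNet checkpoint directly instead of through pipeline(), mirroring the diff.
tokenizer = AutoTokenizer.from_pretrained(model_name)
bitnet_model = BitNetForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",           # place weights on the GPU, as in the diff
    torch_dtype=torch.bfloat16,  # bf16 weights, as in the diff
    trust_remote_code=True,      # allow checkpoint-provided modeling code
)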
@@ -226,7 +233,24 @@ def run_inference(model_name, context, question):
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
         result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
+    elif "bitnet" in model_name.lower():
+        formatted = tokenizer.apply_chat_template(
+            text_input,
+            tokenize=True,
+            return_tensors="pt",
+            return_dict=True,
+            **tokenizer_kwargs,
+        ).to(device)
+        with torch.inference_mode():
+            # Check interrupt before generation
+            if generation_interrupt.is_set():
+                return ""
+            output_sequences = bitnet_model.generate(
+                **formatted,
+                max_new_tokens=512,
+            )
 
+        result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
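Continuing the loading sketch above, this is roughly how the new generation branch turns a chat-formatted prompt into text, and why the decode slices from formatted['input_ids'].shape[-1]: generate() returns the prompt tokens followed by the completion, so slicing at the prompt length keeps only the newly generated part. The message content is a placeholder, not the Space's actual prompt:

messages = [{"role": "user", "content": "Answer using the provided context."}]  # placeholder prompt

inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,          # returns a dict with input_ids and attention_mask
).to(bitnet_model.device)

with torch.inference_mode():
    output_sequences = bitnet_model.generate(**inputs, max_new_tokens=512)

# generate() echoes the prompt, so drop the first input_ids-length tokens before decoding.
prompt_len = inputs["input_ids"].shape[-1]
result = tokenizer.decode(output_sequences[0][prompt_len:], skip_special_tokens=True)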