Commit: 116a714
Parent: 83d0454

try using pipe but explicitly set model to be BitNetForCausalLM
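Concretely, the change hands an explicitly instantiated model object to pipeline() instead of relying on the manual generate() path. The loading call itself sits just above the first hunk and is not visible in the diff; below is a minimal sketch of what it presumably looks like, assuming the standard transformers loading API (the checkpoint id and the use of AutoModelForCausalLM are illustrative assumptions, not code taken from the repo).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint id; the Space's actual model id is not shown in this diff.
model_name = "microsoft/bitnet-b1.58-2B-4T"

# With trust_remote_code=True a BitNet checkpoint can resolve to its
# BitNetForCausalLM implementation, which is what the commit message refers to.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
bitnet_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)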
Files changed: utils/models.py (+29, -18)

utils/models.py (CHANGED)
@@ -163,6 +163,17 @@ def run_inference(model_name, context, question):
             torch_dtype=torch.bfloat16,
             trust_remote_code=True,
         )
+        pipe = pipeline(
+            "text-generation",
+            model=bitnet_model,
+            tokenizer=tokenizer,
+            device_map="cuda",
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16,
+            model_kwargs={
+                "attn_implementation": "eager",
+            },
+        )
     elif "icecream" not in model_name.lower():
         pipe = pipeline(
             "text-generation",
@@ -233,24 +244,24 @@ def run_inference(model_name, context, question):
 
         generated_token_ids = output_sequences[0][prompt_tokens_length:]
         result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
-    elif "bitnet" in model_name.lower():
-        formatted = tokenizer.apply_chat_template(
-            text_input,
-            tokenize=True,
-            return_tensors="pt",
-            return_dict=True,
-            **tokenizer_kwargs,
-        ).to(bitnet_model.device)
-        with torch.inference_mode():
-            # Check interrupt before generation
-            if generation_interrupt.is_set():
-                return ""
-            output_sequences = bitnet_model.generate(
-                **formatted,
-                max_new_tokens=512,
-            )
-
-            result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
+    # elif "bitnet" in model_name.lower():
+    #     formatted = tokenizer.apply_chat_template(
+    #         text_input,
+    #         tokenize=True,
+    #         return_tensors="pt",
+    #         return_dict=True,
+    #         **tokenizer_kwargs,
+    #     ).to(bitnet_model.device)
+    #     with torch.inference_mode():
+    #         # Check interrupt before generation
+    #         if generation_interrupt.is_set():
+    #             return ""
+    #         output_sequences = bitnet_model.generate(
+    #             **formatted,
+    #             max_new_tokens=512,
+    #         )
+
+    #         result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
     else: # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
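With the direct generate() path commented out in the second hunk, the BitNet branch is expected to flow through the same pipeline call pattern as the generic else-branch. Below is a rough, self-contained sketch of that downstream call, assuming the pipe built in the first hunk is passed in; the helper name, prompt wording, and the tokenize=False template call are illustrative assumptions, not code from the repo.

# Sketch of invoking the pipeline built in the first hunk, mirroring the
# generic else-branch at the end of the diff.
def answer_with_pipe(pipe, context, question, max_new_tokens=512):
    messages = [
        {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
    ]
    prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # return_full_text=False keeps only the newly generated continuation,
    # playing the role of the manual prompt-length slicing in the old code.
    outputs = pipe(prompt, max_new_tokens=max_new_tokens, return_full_text=False)
    return outputs[0]["generated_text"]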