Commit 1898f4f · 1 Parent(s): 0276240
fixed qwen, disabled gemma3-1b and minicpm, re-enabled cogito

Files changed: utils/models.py (+31 -26)

utils/models.py  CHANGED
```diff
@@ -8,24 +8,24 @@ from .prompts import format_rag_prompt
 from .shared import generation_interrupt
 
 models = {
-    "Gemma-3-1b-it": "google/gemma-3-1b-it",
+    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
+    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
+    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
+    #"Gemma-3-1b-it": "google/gemma-3-1b-it",
     #"Gemma-3-4b-it": "google/gemma-3-4b-it",
-    #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
-    #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
+    "Gemma-2-2b-it": "google/gemma-2-2b-it",
+    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
+    "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
+    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
+    # #"Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
+    # #"MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
     "Qwen3-0.6b": "qwen/qwen3-0.6b",
+    "Qwen3-1.7b": "qwen/qwen3-1.7b",
+    "Qwen3-4b": "qwen/qwen3-4b",
+    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
+    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
 
 }
```
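The dictionary in this hunk maps the display names used elsewhere in the app to Hugging Face Hub repo ids. Since the commit message says a Qwen entry had to be fixed, one low-effort safeguard is to resolve every repo id against the Hub before deploying. A minimal sketch, assuming `huggingface_hub` is installed and the `models` dict can be imported from the module above; the script itself is not part of this repo:

```python
# Sanity-check the registry: every repo id should resolve on the Hub.
# Standalone sketch, not part of utils/models.py.
from huggingface_hub import model_info

from utils.models import models  # the dict defined in the hunk above

for display_name, repo_id in models.items():
    try:
        info = model_info(repo_id)
        print(f"{display_name}: ok ({info.id})")
    except Exception as exc:  # missing, renamed, or gated repo
        print(f"{display_name}: could not resolve {repo_id}: {exc}")
```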
```diff
@@ -145,23 +145,28 @@ def run_inference(model_name, context, question):
             device_map='cuda',
             trust_remote_code=True,
             torch_dtype=torch.bfloat16,
+            model_kwargs={
+                "attn_implementation": "eager",
+            }
         )
 
         text_input = format_rag_prompt(question, context, accepts_sys)
+        if "Gemma-3".lower() not in model_name.lower():
+            formatted = pipe.tokenizer.apply_chat_template(
+                text_input,
+                tokenize=False,
+                **tokenizer_kwargs,
+            )
 
+            input_length = len(formatted)
         # Check interrupt before generation
-        if generation_interrupt.is_set():
-            return ""
 
+            outputs = pipe(formatted, max_new_tokens=512, generation_kwargs={"skip_special_tokens": True})
         #print(outputs[0]['generated_text'])
+            result = outputs[0]['generated_text'][input_length:]
+        else: # don't use apply chat template? I don't know why gemma keeps breaking
+            result = pipe(text_input, max_new_tokens=512, generation_kwargs={"skip_special_tokens": True})[0]['generated_text']
+            result = result[0]['generated_text'][-1]['content']
 
     except Exception as e:
         print(f"Error in inference for {model_name}: {e}")
```
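In the second hunk the pipeline construction gains `model_kwargs={"attn_implementation": "eager"}`, which `transformers.pipeline` forwards to `from_pretrained`. Gemma-2 in particular is usually run with the eager attention path, reportedly because its attention-logit soft-capping is not handled by every fused attention kernel. A sketch of the equivalent explicit construction, with the repo id chosen from the registry above for illustration and a CUDA device assumed:

```python
# What the pipeline call in the diff sets up, written out explicitly.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

repo_id = "google/gemma-2-2b-it"  # one entry from the registry above

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
    attn_implementation="eager",  # what model_kwargs forwards in the diff
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
```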
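For every model except Gemma-3, the new code renders the chat messages built by `format_rag_prompt` to a plain string with `apply_chat_template(..., tokenize=False)`, records `input_length = len(formatted)`, and after generation slices that many characters off `generated_text`, because a text-generation pipeline fed a string echoes the prompt in its output by default. A condensed sketch of that pattern; the messages are illustrative stand-ins (not what `format_rag_prompt` actually builds), and the diff's `**tokenizer_kwargs` is assumed to carry options such as `add_generation_prompt=True`:

```python
# Prompt formatting and prompt stripping, as in the non-Gemma-3 branch.
import torch
from transformers import pipeline

pipe = pipeline(
    "text-generation",
    model="qwen/qwen2.5-1.5b-instruct",  # any entry from the registry above
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Illustrative stand-in for format_rag_prompt(question, context, accepts_sys).
messages = [
    {"role": "system", "content": "Answer only from the provided context."},
    {"role": "user", "content": "Context: ...\n\nQuestion: ..."},
]

# Render the chat template to text instead of token ids.
formatted = pipe.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
input_length = len(formatted)

outputs = pipe(formatted, max_new_tokens=512)
# generated_text starts with the prompt, so drop it by character offset.
answer = outputs[0]["generated_text"][input_length:]
print(answer)
```

The Gemma-3 branch instead hands the raw message list to the pipeline, which returns the conversation (prompt messages plus the new assistant turn) under `generated_text`, so it pulls the reply out with `[-1]['content']` rather than slicing by length.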
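Finally, `generation_interrupt` is imported from `.shared` at the top of the file, and the version before this commit bailed out with `if generation_interrupt.is_set(): return ""` right before generating; `.is_set()` matches the `threading.Event` API, and checking it before the expensive `pipe(...)` call lets the UI abandon a queued request cheaply. A minimal sketch of how such a flag could be defined and used; the contents of `utils/shared.py` are an assumption, since they are not shown in this diff:

```python
# Assumed shape of utils/shared.py: a process-wide cancellation flag.
import threading

generation_interrupt = threading.Event()


def request_stop():
    """Called from the UI (e.g. a stop button) to cancel pending work."""
    generation_interrupt.set()


def run_if_not_cancelled(pipe, formatted):
    """Mirror of the pre-generation check in run_inference."""
    if generation_interrupt.is_set():
        generation_interrupt.clear()  # reset so the next request can run
        return ""
    return pipe(formatted, max_new_tokens=512)[0]["generated_text"]
```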