Spaces:
Runtime error
Runtime error
Commit
·
4c9f7ae
1
Parent(s):
5672f53
Do not use safetensors
Browse files
app.py
CHANGED
|
@@ -38,11 +38,12 @@ print("Loaded vardecoder model successfully.")
|
|
| 38 |
logger.info("Loading fielddecoder model...")
|
| 39 |
|
| 40 |
fielddecoder_model = None
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
|
| 48 |
|
|
@@ -104,7 +105,7 @@ def infer(code):
|
|
| 104 |
|
| 105 |
print(f"Prompt:\n{repr(var_prompt)}")
|
| 106 |
|
| 107 |
-
var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").
|
| 108 |
:, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
|
| 109 |
]
|
| 110 |
var_output = vardecoder_model.generate(
|
|
@@ -131,7 +132,7 @@ def infer(code):
|
|
| 131 |
if len(fields) == 0:
|
| 132 |
field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
|
| 133 |
else:
|
| 134 |
-
field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").
|
| 135 |
:, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
|
| 136 |
]
|
| 137 |
|
|
|
|
| 38 |
logger.info("Loading fielddecoder model...")
|
| 39 |
|
| 40 |
fielddecoder_model = None
|
| 41 |
+
fielddecoder_model = AutoModelForCausalLM.from_pretrained(
|
| 42 |
+
"ejschwartz/resym-fielddecoder",
|
| 43 |
+
torch_dtype=torch.bfloat16,
|
| 44 |
+
use_safetensors=False
|
| 45 |
+
)
|
| 46 |
+
logger.info("Successfully loaded fielddecoder model")
|
| 47 |
|
| 48 |
make_gradio_client = lambda: Client("https://ejschwartz-resym-field-helper.hf.space/")
|
| 49 |
|
|
|
|
| 105 |
|
| 106 |
print(f"Prompt:\n{repr(var_prompt)}")
|
| 107 |
|
| 108 |
+
var_input_ids = tokenizer.encode(var_prompt, return_tensors="pt").to(vardecoder_model.device)[
|
| 109 |
:, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
|
| 110 |
]
|
| 111 |
var_output = vardecoder_model.generate(
|
|
|
|
| 132 |
if len(fields) == 0:
|
| 133 |
field_output = "Failed to parse fields" if field_prompt_result is None else "No fields"
|
| 134 |
else:
|
| 135 |
+
field_input_ids = tokenizer.encode(field_prompt_result, return_tensors="pt").to(fielddecoder_model.device)[
|
| 136 |
:, : MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS
|
| 137 |
]
|
| 138 |
|