Spaces:
Runtime error
Runtime error
Commit
Β·
ff079ce
1
Parent(s):
3f5ced9
initial commit
Browse files
gen.py
CHANGED
|
@@ -2,8 +2,8 @@ import torch
|
|
| 2 |
import sys
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 4 |
import json
|
| 5 |
-
import jsonschema
|
| 6 |
-
from jsonschema import validate
|
| 7 |
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
|
| 9 |
|
|
@@ -182,20 +182,31 @@ def generate(event):
|
|
| 182 |
|
| 183 |
|
| 184 |
output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
|
| 185 |
-
|
| 186 |
-
user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n") # Calculate user prompt length
|
| 187 |
|
| 188 |
json_start_index = output_text.find("<json>")
|
| 189 |
json_end_index = output_text.find("</json>")
|
| 190 |
|
| 191 |
if json_start_index != -1 and json_end_index != -1:
|
| 192 |
-
json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip()
|
| 193 |
|
| 194 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
try:
|
| 196 |
validate(instance=json.loads(json_string), schema=your_json_schema)
|
| 197 |
return json_string
|
| 198 |
-
except ValidationError as e:
|
| 199 |
return f"Error: Invalid JSON - {e}"
|
| 200 |
|
| 201 |
else:
|
|
|
|
| 2 |
import sys
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
| 4 |
import json
|
| 5 |
+
import jsonschema
|
| 6 |
+
from jsonschema import validate, ValidationError
|
| 7 |
|
| 8 |
tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
|
| 9 |
|
|
|
|
| 182 |
|
| 183 |
|
| 184 |
output_text = tokenizer.decode(tokens[0], skip_special_tokens=False)
|
| 185 |
+
user_prompt_length = len(f"<bos><start_of_turn>user\n{prompt}\n{event}<end_of_turn>\n<start_of_turn>model\n")
|
|
|
|
| 186 |
|
| 187 |
json_start_index = output_text.find("<json>")
|
| 188 |
json_end_index = output_text.find("</json>")
|
| 189 |
|
| 190 |
if json_start_index != -1 and json_end_index != -1:
|
| 191 |
+
json_string = output_text[max(json_start_index + 6, user_prompt_length):json_end_index].strip()
|
| 192 |
|
| 193 |
+
# Remove any leading/trailing non-JSON characters (if present)
|
| 194 |
+
if not json_string.startswith("{"):
|
| 195 |
+
first_brace_index = json_string.find("{")
|
| 196 |
+
if first_brace_index != -1:
|
| 197 |
+
json_string = json_string[first_brace_index:]
|
| 198 |
+
|
| 199 |
+
if not json_string.endswith("}"):
|
| 200 |
+
last_brace_index = json_string.rfind("}")
|
| 201 |
+
if last_brace_index != -1:
|
| 202 |
+
json_string = json_string[:last_brace_index + 1]
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Validate JSON
|
| 206 |
try:
|
| 207 |
validate(instance=json.loads(json_string), schema=your_json_schema)
|
| 208 |
return json_string
|
| 209 |
+
except jsonschema.exceptions.ValidationError as e:
|
| 210 |
return f"Error: Invalid JSON - {e}"
|
| 211 |
|
| 212 |
else:
|