Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import time | |
| import json | |
| from itertools import cycle | |
| import torch | |
| import gradio as gr | |
| from urllib.parse import unquote | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList | |
| from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields | |
| from examples import examples as input_examples | |
| from nuextract_logging import log_event | |
# Token budgets used throughout this file.
# MAX_INPUT_SIZE: upper bound on the tokenized input; used both as the
# tokenizer's truncation `max_length` in predict_chunk and as the limit above
# which gradio_interface_function rejects the input text.
| MAX_INPUT_SIZE = 10_000 | |
# MAX_NEW_TOKENS: passed as `max_new_tokens` to `model.generate` in predict_chunk.
| MAX_NEW_TOKENS = 4_000 | |
# MAX_WINDOW_SIZE: passed as `window_size` to sliding_window_prediction by
# gradio_interface_function.
| MAX_WINDOW_SIZE = 4_000 | |
# HTML shown as the `description` of the Gradio interface. The 10k / 4k figures
# quoted in its NOTE paragraph mirror MAX_INPUT_SIZE and MAX_WINDOW_SIZE above;
# keep them in sync when changing either constant.
| markdown_description = """ | |
| <!DOCTYPE html> | |
| <html lang="en"> | |
| <head> | |
| <meta charset="UTF-8"> | |
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
| <title>NuExtract</title> | |
| </head> | |
| <body> | |
| <img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle;width: 200px; height: 50px;"> | |
| <p>We are a startup developing custom information extraction models. NuExtract is a zero-shot model.</p> | |
| <p>If you want the best performance on your problem, please contact us :).</p> | |
| <br> | |
| <ul> | |
| <li><strong>Webpage</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li> | |
| </ul> | |
| <br> | |
| <h1>NuExtract-v1.5</h1> | |
| <p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction. | |
| It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian). | |
| To use the model, provide an input text and a JSON template describing the information you need to extract.</p> | |
| <ul> | |
| <li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li> | |
| </ul> | |
| <i>NOTE: in this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i> | |
| </body> | |
| </html> | |
| """ | |
def highlight_words(input_text, json_output):
    """Wrap each extracted value found in ``input_text`` in a coloured ``<span>``.

    Args:
        input_text: The text to annotate.
        json_output: Parsed extraction result. Its leaves are obtained with
            ``extract_leaves`` as ``(path, value)`` pairs.

    Returns:
        ``input_text`` in which every case-insensitive occurrence of a leaf
        value that is delimited by a space, newline, tab or the start/end of
        the text is replaced by a ``<span>`` with a background colour. One
        colour is assigned per distinct leaf path; the palette is recycled
        once exhausted. The span shows the URL-unquoted value.

    NOTE(review): neither ``input_text`` nor the values are HTML-escaped here,
    so any markup they contain is passed through to the returned HTML.
    """
    colors = cycle([
        "#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a",
        "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead",
    ])
    color_map = {}
    highlighted_text = input_text

    for path, value in extract_leaves(json_output):
        # Assign the colour first so that skipped values below still consume
        # their palette slot and later paths keep the same colours.
        path_key = tuple(path)
        if path_key not in color_map:
            color_map[path_key] = next(colors)
        color = color_map[path_key]

        # Nothing to highlight for a missing or blank value; an empty pattern
        # would otherwise match between any two adjacent whitespace characters.
        if value is None:
            continue
        needle = str(value)
        if not needle.strip():
            continue

        # re.escape: the value is literal text, not a regular expression.
        # Lookarounds (instead of consuming the delimiters) keep the original
        # whitespace intact and also match at the very start/end of the text
        # and for back-to-back occurrences separated by a single delimiter.
        pattern = rf"(?<![^ \n\t]){re.escape(needle)}(?![^ \n\t])"
        span = f"<span style='background-color: {color};'>{unquote(needle)}</span>"
        # A callable replacement is inserted verbatim, so backslashes in the
        # value are not interpreted as group references or escapes.
        highlighted_text = re.sub(
            pattern, lambda _match: span, highlighted_text, flags=re.IGNORECASE
        )

    return highlighted_text
def predict_chunk(text, template, current, model, tokenizer):
    """Run one extraction pass of the model over a single piece of text.

    Args:
        text: The text to extract information from.
        template: JSON template, inserted verbatim into the prompt.
        current: The extraction state so far; passed through
            ``clean_json_text`` before being inserted into the prompt.
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The decoded text that follows the first ``<|output|>`` marker, passed
        through ``clean_json_text``.

    Raises:
        IndexError: If the decoded sequence contains no ``<|output|>`` marker.
    """
    current = clean_json_text(current)
    # The trailing "{" primes the model to start emitting a JSON object.
    prompt = (
        f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n"
        f"### Text:\n{text}\n\n<|output|>" + "{"
    )
    # Move the inputs to wherever the model actually lives instead of assuming
    # CUDA: the model is loaded with device_map="auto", so a hard-coded "cuda"
    # fails on machines where it was placed on another device.
    encoded = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE
    ).to(model.device)
    generated = model.generate(**encoded, max_new_tokens=MAX_NEW_TOKENS)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # The decoded sequence contains the prompt as well; keep only what comes
    # after the output marker.
    return clean_json_text(decoded.split("<|output|>")[1])
def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
    """Extract information from ``text`` chunk by chunk, yielding after each chunk.

    The text is split with ``split_document``; every chunk is sent to
    ``predict_chunk`` together with the prediction from the previous chunk
    (the template itself for the first chunk), so the result is refined
    incrementally.

    Args:
        template: JSON template as a string.
        text: Full input text.
        model: Model forwarded to ``predict_chunk``.
        tokenizer: Tokenizer forwarded to ``split_document`` and ``predict_chunk``.
        window_size: Forwarded to ``split_document``.
        overlap: Forwarded to ``split_document``.

    Yields:
        ``(progress, prediction, highlighted)`` tuples, one per chunk:
        a ``"Processed chunk i/n"`` message, the prediction after
        ``sync_empty_fields`` serialised as JSON with ``indent=4``, and the
        full ``text`` with the predicted values highlighted as HTML.
    """
    chunks = split_document(text, window_size, overlap, tokenizer)
    total = len(chunks)

    # The first chunk starts from the bare template as its "current" state.
    prev = template
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i}...")
        pred = predict_chunk(chunk, template, prev, model, tokenizer)

        # Handle broken output
        pred = handle_broken_output(pred, prev)

        # Parse once and reuse for both highlighting and syncing.
        pred_json = json.loads(pred)

        # Highlight against the whole document, not just the current chunk.
        highlighted_pred = highlight_words(text, pred_json)

        # The template is re-parsed on every iteration so sync_empty_fields
        # always receives a fresh object.
        synced_pred = sync_empty_fields(pred_json, json.loads(template))
        synced_pred = json.dumps(synced_pred, indent=4)

        yield f"Processed chunk {i+1}/{total}", synced_pred, highlighted_pred

        # The un-synced prediction is what gets fed to the next chunk.
        prev = pred
######
# Load the model and tokenizer
model_name = "numind/NuExtract-v1.5"
# Use HF_TOKEN when it is set; otherwise pass None so the hub client falls
# back to its default behaviour (anonymous access, or a locally cached token
# if one exists). Passing True instead makes a token mandatory and fails on
# machines that have none stored.
auth_token = os.environ.get("HF_TOKEN") or None
# `token` is the current name of the deprecated `use_auth_token` argument.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=auth_token,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token)
model.eval()
def gradio_interface_function(template, text):
    """Gradio callback: validate the inputs, then stream extraction results.

    Args:
        template: JSON template entered by the user, as a string.
        text: Input text entered by the user.

    Yields:
        ``(progress, prediction, highlighted_html)`` tuples. For a template
        that is not valid JSON, or a text longer than ``MAX_INPUT_SIZE``
        tokens, a single tuple carrying the error message in the second slot
        is yielded and the generator ends. Otherwise one tuple is yielded per
        chunk processed by ``sliding_window_prediction``.
    """
    # Reject invalid JSON. Only parsing failures are caught: ValueError covers
    # json.JSONDecodeError, TypeError a non-string template, RecursionError an
    # excessively nested document.
    try:
        json.loads(template)
    except (ValueError, TypeError, RecursionError):
        yield "", "Invalid JSON template", ""
        return  # End the function since there was an error

    if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
        yield "", "Input text too long for space. Download model to use unrestricted.", ""
        return  # End the function since there was an error

    # Initialize the sliding window prediction process
    prediction_generator = sliding_window_prediction(
        template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE
    )

    # Relay each intermediate result so the UI updates after every chunk.
    full_pred = None
    for progress, full_pred, html_content in prediction_generator:
        yield progress, full_pred, html_content

    # Log the final prediction; skipped when no chunk produced one, in which
    # case there is nothing to record.
    if full_pred is not None:
        log_event(text, template, full_pred)
# Build the Gradio UI: two text inputs (template and document) feed the
# streaming extraction callback, which fills three outputs (progress message,
# JSON prediction, highlighted HTML).
_input_components = [
    gr.Textbox(lines=2, placeholder="Enter Template here...", label="Template"),
    gr.Textbox(lines=2, placeholder="Enter input Text here...", label="Input Text"),
]
_output_components = [
    gr.Textbox(label="Progress"),
    gr.Textbox(label="Model Output"),
    gr.HTML(label="Model Output with Highlighted Words"),
]

iface = gr.Interface(
    fn=gradio_interface_function,
    inputs=_input_components,
    outputs=_output_components,
    description=markdown_description,
    examples=input_examples,
)

# Start serving as soon as the module is executed.
iface.launch(debug=True, share=True)