Spaces:
Runtime error
| import transformers | |
| import re | |
| from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline | |
| import torch | |
| import gradio as gr | |
| import json | |
| import os | |
| import shutil | |
| import requests | |
| import pandas as pd | |
# Hugging Face Hub id of the editorial-structure token classifier.
editorial_model = "PleIAs/Estienne"

# Run on CUDA when torch reports it available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Token-classification pipeline; with the "simple" aggregation strategy the
# results carry an ``entity_group`` label per span (read by transform_chunks).
token_classifier = pipeline(
    "token-classification",
    model=editorial_model,
    device=device,
    aggregation_strategy="simple",
)

# Separate tokenizer instance; in this file it is only used to count tokens
# when deciding whether and how to split long inputs.
tokenizer = AutoTokenizer.from_pretrained(
    editorial_model,
    model_max_length=512,
)
# Stylesheet emitted at the start of every result built by
# MistralChatBot.predict. Each predicted span is rendered as a
# ``.manuscript`` flex row: a grey italic label (``.annotation``) next to
# the span text (``.content``). Titles, bibliography and paratext spans get
# extra ``*-content`` classes (assigned in transform_chunks).
css = """
<style>
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color:darkgreen !important;
margin-top: -5px; /* Adjust if needed to align with annotation */
}
.paratext-content {
color:#a4a4a4 !important;
margin-top: -5px; /* Adjust if needed to align with annotation */
}
</style>
"""
def preprocess_text(text):
    """Clean one predicted span (a value of the 'word' column) for display.

    Anything shaped like an HTML tag (``<...>``) is deleted, then every run
    of whitespace -- newlines included -- becomes a single space and the
    ends of the string are trimmed.
    """
    untagged = re.sub(r'<[^>]+>', '', text)
    # str.split() with no argument splits on any whitespace run and drops
    # empty fields, so re-joining with ' ' both collapses and trims.
    return ' '.join(untagged.split())
_WHITESPACE_RE = re.compile(r'\s')


def _count_model_tokens(chunk):
    """Count the tokens of ``chunk`` with the module-level ``tokenizer``."""
    return len(tokenizer.tokenize(chunk))


def _bisection_index(chunk):
    """Return where to cut ``chunk``: the first whitespace at or after its
    middle, else the last whitespace before it, else the middle itself."""
    middle = len(chunk) // 2
    after = _WHITESPACE_RE.search(chunk, middle)
    if after:
        return after.start()
    before = [match.start() for match in _WHITESPACE_RE.finditer(chunk, 0, middle)]
    return before[-1] if before else middle


def _bisect_chunk(chunk, max_tokens, count_tokens):
    """Recursively halve ``chunk`` until every piece fits in ``max_tokens``."""
    # A single character cannot be split any further; keep it even if too long.
    if len(chunk) < 2 or count_tokens(chunk) <= max_tokens:
        return [chunk]
    cut = _bisection_index(chunk)
    pieces = []
    # Both halves are strictly shorter than ``chunk`` (a whitespace cut point
    # is removed by strip()), so the recursion always terminates.
    for half in (chunk[:cut].strip(), chunk[cut:].strip()):
        if half:
            pieces.extend(_bisect_chunk(half, max_tokens, count_tokens))
    return pieces


def split_text(text, max_tokens=500, count_tokens=None):
    """Split ``text`` into chunks of at most ``max_tokens`` tokens.

    Newline-separated parts are first packed greedily into chunks. Every
    chunk that is still too long -- a single over-long line, or a text with
    no newlines at all -- is then bisected recursively at whitespace near
    its middle until each piece fits.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per chunk; must be positive.
        count_tokens: Optional callable returning the token count of a
            string. Defaults to counting with the module-level ``tokenizer``.

    Returns:
        A list of chunks. Chunks that already fit are returned unchanged;
        pieces produced by bisection are whitespace-stripped and non-empty.
        A piece can only exceed ``max_tokens`` if it is a single character.

    Raises:
        ValueError: If ``max_tokens`` is not positive (the previous
            implementation looped forever in that case).
    """
    if max_tokens <= 0:
        raise ValueError('max_tokens must be a positive integer')
    if count_tokens is None:
        count_tokens = _count_model_tokens

    # Greedily pack newline-separated parts into chunks within the budget.
    packed = []
    current = ''
    for part in text.split('\n'):
        candidate = f'{current}\n{part}' if current else part
        if count_tokens(candidate) <= max_tokens:
            current = candidate
        else:
            if current:
                packed.append(current)
            current = part
    if current:
        packed.append(current)

    # Bisect *every* chunk that is still oversized, not only the case where
    # packing produced a single chunk.
    return [
        piece
        for chunk in packed
        for piece in _bisect_chunk(chunk, max_tokens, count_tokens)
    ]
def transform_chunks(marianne_segmentation):
    """Render classified text spans as annotated HTML rows.

    Args:
        marianne_segmentation: Token-classification output -- a DataFrame or
            anything ``pd.DataFrame`` accepts -- with an ``entity_group``
            label and the span text in ``word`` for each row.

    Returns:
        One ``<div class="manuscript">`` row per non-separator span whose
        normalised text is non-empty, joined with newlines; ``""`` when the
        input has no rows. Span text and labels are HTML-escaped, so user
        text such as ``a < b`` or ``&amp;`` is displayed literally instead of
        being interpreted as markup.
    """
    import html

    segments = pd.DataFrame(marianne_segmentation)
    # No predictions at all: the frame has no columns to filter on.
    if segments.empty:
        return ''
    # Copy so the column assignment below does not write into a slice.
    segments = segments[segments['entity_group'] != 'separator'].copy()
    # '¶' marks the original line breaks; restore them, then normalise.
    segments['word'] = (
        segments['word']
        .astype(str)
        .str.replace('¶', '\n', regex=False)
        .apply(preprocess_text)
    )
    # preprocess_text strips its output, so blank spans are exactly ''.
    segments = segments[segments['word'] != '']

    # Extra CSS class per entity group; any other group gets plain "content".
    content_classes = {
        'title': 'content title-content',
        'bibliography': 'content bibliography-content',
        'paratext': 'content paratext-content',
    }
    rows = []
    for group, word in zip(segments['entity_group'], segments['word']):
        label = html.escape(f'[{group.capitalize()}]')
        text = html.escape(word)
        if group == 'title':
            text = f'<h2>{text}</h2>'
        css_class = content_classes.get(group, 'content')
        rows.append(
            f'<div class="manuscript"><div class="annotation">{label}</div>'
            f'<div class="{css_class}">{text}</div></div>'
        )
    return '\n'.join(rows)
# Wraps the editorial token classifier behind a chatbot-style predict() API.
class MistralChatBot:
    """Identify editorial structures in a text and render them as HTML.

    Despite the name, no generative model is involved: ``predict`` runs the
    module-level ``token_classifier`` pipeline and formats the predicted
    spans with ``transform_chunks``.
    """

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Kept for interface compatibility; predict() never reads it.
        self.system_prompt = system_prompt

    def predict(self, user_message):
        """Classify ``user_message`` and return the annotated result as HTML.

        Args:
            user_message: Raw text from the UI. ``None`` or whitespace-only
                input yields an empty result instead of an exception.

        Returns:
            The module ``css`` followed by an "Edited text" heading and a
            ``<div class="generation">`` holding the rendered spans.
        """
        rendered = ''
        if user_message and user_message.strip():
            # Encode line breaks as the ' ¶ ' marker; transform_chunks maps
            # '¶' in predicted spans back to newlines.
            editorial_text = user_message.replace('\n', ' ¶ ')
            # Texts over 500 tokens are chunked (presumably to stay below the
            # tokenizer's 512-token limit) and classified as one batch.
            if len(tokenizer.tokenize(editorial_text)) > 500:
                batch_prompts = split_text(editorial_text, max_tokens=500)
            else:
                batch_prompts = [editorial_text]
            frames = []
            if batch_prompts:
                frames = [pd.DataFrame(spans) for spans in token_classifier(batch_prompts)]
            # Drop chunks without predicted spans. If none remain, there is
            # nothing to render: an all-empty concat would have no
            # 'entity_group'/'word' columns.
            frames = [frame for frame in frames if not frame.empty]
            if frames:
                rendered = transform_chunks(pd.concat(frames, ignore_index=True))
        return f'{css}<h2 style="text-align:center">Edited text</h2>\n<div class="generation">{rendered}</div>'
# The single editorialisation bot whose predict() backs the Gradio button.
mistral_bot = MistralChatBot()

# Interface metadata.
# NOTE(review): title, description and examples appear unused in this file,
# and the example row carries a "temperature" value that nothing reads --
# confirm whether they can be dropped.
title = "Éditorialisation"
description = "Un outil expérimental d'identification de la structure du texte à partir d'un encoder (Deberta)"
examples = [["Qui peut bénéficier de l'AIP?", 0.7]]  # [user_message, temperature]
# Gradio UI: paste a text, click the button, get the annotated HTML back.
# ``demo`` is bound by the ``with`` statement itself, so no separate
# gr.Blocks() instance is created beforehand.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Editorialize your text</h1>""")
    # NOTE(review): predict() encodes line breaks as '¶', so a multi-line
    # box (lines > 1) may suit this input better -- confirm the intended UX.
    text_input = gr.Textbox(label="Your text", type="text", lines=1)
    text_button = gr.Button("Identify editorial structures")
    text_output = gr.HTML(label="Corrected text")
    # The button sends the text box contents through the classifier.
    text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the app.
    demo.queue().launch()