Spaces:
Runtime error
Runtime error
| import datetime | |
| import gradio as gr | |
| import torch | |
| from cache_system import CacheHandler | |
| from header import article, header | |
| from newspaper import Article | |
| from prompts import summarize_clickbait_short_prompt | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| BitsAndBytesConfig, | |
| GenerationConfig, | |
| LogitsProcessorList, | |
| TextStreamer, | |
| ) | |
| from utils import StopAfterTokenIsGenerated | |
| total_runs = 0 | |
| # Cargar el tokenizador | |
| tokenizer = AutoTokenizer.from_pretrained("somosnlp/NoticIA-7B") | |
| # Cargamos el modelo en 4 bits para usar menos VRAM | |
| # Usamos bitsandbytes por que es lo más sencillo de implementar para la demo aunque no es ni lo más rápido ni lo más eficiente | |
| quantization_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_compute_dtype=torch.bfloat16, | |
| bnb_4bit_use_double_quant=True, | |
| ) | |
| model = AutoModelForCausalLM.from_pretrained( | |
| "somosnlp/NoticIA-7B", | |
| torch_dtype=torch.bfloat16, | |
| device_map="auto", | |
| #quantization_config=quantization_config, | |
| ) | |
| print(f"Model loaded in {model.device}") | |
| # Parámetros de generación. | |
| generation_config = GenerationConfig( | |
| max_new_tokens=128, # Los resúmenes son cortos, no necesitamos más tokens | |
| min_new_tokens=1, # No queremos resúmenes vacíos | |
| do_sample=True, # Un poquito mejor que greedy sampling | |
| num_beams=1, | |
| use_cache=True, # Eficiencia | |
| top_k=40, | |
| top_p=0.1, | |
| repetition_penalty=1.1, # Ayuda a evitar que el modelo entre en bucles | |
| encoder_repetition_penalty=1.1, # Favorecemos que el modelo cite el texto original | |
| temperature=0.15, # temperature baja para evitar que el modelo genere texto muy creativo. | |
| ) | |
| # Stop words, para evitar que el modelo genere tokens que no queremos. | |
| stop_words = [ | |
| "<s>", | |
| "</s>", | |
| "\\n", | |
| "[/INST]", | |
| "[INST]", | |
| "### User:", | |
| "### Assistant:", | |
| "###", | |
| "<start_of_turn>", | |
| "<end_of_turn>", | |
| "<end_of_turn>\\n", | |
| "<eos>", | |
| ] | |
| # Creamos un logits processor para detener la generación cuando el modelo genere un stop word | |
| stop_criteria = LogitsProcessorList( | |
| [ | |
| StopAfterTokenIsGenerated( | |
| stops=[ | |
| torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False)) | |
| for stop_word in stop_words.copy() | |
| ], | |
| eos_token_id=tokenizer.eos_token_id, | |
| ) | |
| ] | |
| ) | |
| def generate_text(url: str) -> (str, str): | |
| """ | |
| Dada una URL de una noticia, genera un resumen de una sola frase que revela la verdad detrás del titular. | |
| Args: | |
| url (str): URL de la noticia. | |
| Returns: | |
| str: Titular de la noticia. | |
| str: Resumen de la noticia. | |
| """ | |
| global cache_handler | |
| global total_runs | |
| total_runs += 1 | |
| print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}") | |
| url = url.strip() | |
| if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"): | |
| yield ( | |
| "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.", | |
| "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌", | |
| "Error", | |
| ) | |
| return ( | |
| "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.", | |
| "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌", | |
| "Error", | |
| ) | |
| # 1) Download the article | |
| # progress(0, desc="🤖 Accediendo a la noticia") | |
| # First, check if the URL is in the cache | |
| headline, text, resumen = cache_handler.get_from_cache(url, 0) | |
| if headline is not None and text is not None and resumen is not None: | |
| yield headline, resumen | |
| return headline, resumen | |
| else: | |
| try: | |
| article = Article(url) | |
| article.download() | |
| article.parse() | |
| headline = article.title | |
| text = article.text | |
| except Exception as e: | |
| print(e) | |
| headline = None | |
| text = None | |
| if headline is None or text is None: | |
| yield ( | |
| "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.", | |
| "❌❌❌ Inténtalo de nuevo ❌❌❌", | |
| "Error", | |
| ) | |
| return ( | |
| "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.", | |
| "❌❌❌ Inténtalo de nuevo ❌❌❌", | |
| "Error", | |
| ) | |
| # progress(0.5, desc="🤖 Leyendo noticia") | |
| try: | |
| prompt = summarize_clickbait_short_prompt(headline=headline, body=text) | |
| formatted_prompt = tokenizer.apply_chat_template( | |
| [{"role": "user", "content": prompt}], | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| model_inputs = tokenizer( | |
| [formatted_prompt], return_tensors="pt", add_special_tokens=False | |
| ) | |
| streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True) | |
| model_output = model.generate( | |
| **model_inputs.to(model.device), | |
| streamer=streamer, | |
| generation_config=generation_config, | |
| logits_processor=stop_criteria, | |
| ) | |
| yield headline, streamer | |
| resumen = tokenizer.batch_decode( | |
| model_output, | |
| skip_special_tokens=True, | |
| clean_up_tokenization_spaces=True, | |
| )[0].replace("<|end_of_turn|>", "") | |
| resumen = resumen.split("GPT4 Correct Assistant:")[-1] | |
| except Exception as e: | |
| print(e) | |
| yield ( | |
| "🤖 Error en la generación.", | |
| "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌", | |
| "Error", | |
| ) | |
| return ( | |
| "🤖 Error en la generación.", | |
| "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌", | |
| "Error", | |
| ) | |
| cache_handler.add_to_cache( | |
| url=url, title=headline, text=text, summary_type=0, summary=resumen | |
| ) | |
| yield headline, resumen | |
| hits, misses, cache_len = cache_handler.get_cache_stats() | |
| print( | |
| f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%." | |
| ) | |
| return headline, resumen | |
| # Usamos una cache para guardar las últimas URL procesadas | |
| # Los usuarios seguramente introducirán en un mismo día la misma URL varias veces, por que | |
| # diferentes personas querrán ver el resumen de la misma noticia. | |
| # La cache se encarga de guardar los resúmenes de las noticias para que no tengamos que volver a generarlos. | |
| # La cache tiene un tamaño máximo de 1000 elementos, cuando se llena, se elimina el elemento más antiguo. | |
| cache_handler = CacheHandler(max_cache_size=1000) | |
| demo = gr.Interface( | |
| generate_text, | |
| inputs=[ | |
| gr.Textbox( | |
| label="🌐 URL de la noticia", | |
| info="Introduce la URL de la noticia que deseas resumir.", | |
| value="https://somosnlp.org/", | |
| interactive=True, | |
| ) | |
| ], | |
| outputs=[ | |
| gr.Textbox( | |
| label="📰 Titular de la noticia", | |
| interactive=False, | |
| placeholder="Aquí aparecerá el título de la noticia", | |
| ), | |
| gr.Textbox( | |
| label="🗒️ Resumen", | |
| interactive=False, | |
| placeholder="Aquí aparecerá el resumen de la noticia.", | |
| ), | |
| ], | |
| # headline="⚔️ Clickbait Fighter! ⚔️", | |
| thumbnail="https://huggingface.co/datasets/Iker/NoticIA/resolve/main/assets/logo.png", | |
| theme="JohnSmith9982/small_and_pretty", | |
| description=header, | |
| article=article, | |
| cache_examples=False, | |
| concurrency_limit=1, | |
| examples=[ | |
| "https://www.huffingtonpost.es/virales/le-compra-abrigo-abuela-97nos-reaccion-fantasia.html", | |
| "https://emisorasunidas.com/2023/12/29/que-pasara-el-15-de-enero-de-2024/", | |
| "https://www.huffingtonpost.es/virales/llega-espana-le-llama-atencion-nombres-propios-persona.html", | |
| "https://www.infobae.com/que-puedo-ver/2023/11/19/la-comedia-familiar-y-navidena-que-ya-esta-en-netflix-y-puedes-ver-en-estas-fiestas/", | |
| "https://www.cope.es/n/1610984", | |
| ], | |
| submit_btn="Generar resumen", | |
| stop_btn="Detener generación", | |
| clear_btn="Limpiar", | |
| allow_flagging=False, | |
| ) | |
| demo.queue(max_size=None) | |
| demo.launch(share=False) | |