Spaces:
Build error
Build error
# -*- coding: utf-8 -*-
"""
@Author : Jiangjie Chen
@Time : 2021/12/13 17:17
@Contact : jjchen19@fudan.edu.cn
@Description: Gradio demo app for LOREN, an interpretable fact verification model.
"""
| import os | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| from prettytable import PrettyTable | |
| import pandas as pd | |
| import torch | |
| import traceback | |
# Settings handed to the Loren pipeline. The checkpoint directories
# ('fc_dir', 'mrc_dir', 'er_dir') are not listed here; they are filled in
# further down, once the model snapshot has been downloaded.
config = {
    "model_type": "roberta",
    "model_name_or_path": "roberta-large",
    "logic_lambda": 0.5,
    "prior": "random",
    "mask_rate": 0.0,
    "cand_k": 1,
    "max_seq1_length": 256,
    "max_seq2_length": 128,
    "max_num_questions": 8,
    "do_lower_case": False,
    "seed": 42,
    # Number of visible CUDA devices; 0 on a CPU-only host.
    "n_gpu": torch.cuda.device_count(),
}
# Fetch the LOREN source tree at start-up and flatten it into the current
# working directory so that ``src.loren`` becomes importable.
# os.system() does not raise when a command fails and its exit status is
# ignored here, so each of these steps is best-effort.
os.system('git clone https://github.com/jiangjiechen/LOREN/')
# Drop the checkout's data/, results/ and models/ directories before moving
# the remaining top-level entries out of LOREN/.
os.system('rm -r LOREN/data/')
os.system('rm -r LOREN/results/')
os.system('rm -r LOREN/models/')
os.system('mv LOREN/* ./')

# Download the 'Jiangjie/loren' snapshot from the Hugging Face Hub and point
# the config at the three sub-directories the pipeline reads from.
model_dir = snapshot_download('Jiangjie/loren')
config['fc_dir'] = os.path.join(model_dir, 'fact_checking/roberta-large/')
config['mrc_dir'] = os.path.join(model_dir, 'mrc_seq2seq/bart-base/')
config['er_dir'] = os.path.join(model_dir, 'evidence_retrieval/')

# Imported here rather than at the top of the file on purpose: the ``src``
# package only exists after the clone/move above has run.
from src.loren import Loren  # noqa: E402

loren = Loren(config, verbose=False)

# Smoke test: push one claim through the pipeline at start-up so that a broken
# set-up fails immediately instead of on the first user request.
try:
    js = loren.check('Donald Trump won the 2020 U.S. presidential election.')
except Exception as e:
    # Still surfaces as ValueError, but explicitly chained to the original
    # exception so its type and traceback remain visible in the logs.
    raise ValueError(e) from e
def highlight_phrase(text, phrase):
    """Fill the ``<mask>`` slots of *text* with *phrase*, rendered bold-italic.

    The text is first passed through ``clean_up_tokenization`` of the
    fact-checking client's tokenizer; every literal ``<mask>`` in the result
    is then replaced by ``<i><b>phrase</b></i>``.

    Args:
        text: String that may contain ``<mask>`` placeholders.
        phrase: Text to insert in place of each placeholder.

    Returns:
        The cleaned text as an HTML fragment.
    """
    cleaned = loren.fc_client.tokenizer.clean_up_tokenization(text)
    return cleaned.replace('<mask>', f'<i><b>{phrase}</b></i>')
def highlight_entity(text, entity):
    """Wrap every occurrence of *entity* in *text* in bold-italic HTML tags.

    Args:
        text: The string to decorate.
        entity: The exact substring to highlight.

    Returns:
        *text* with each occurrence of *entity* replaced by
        ``<i><b>entity</b></i>``. When *entity* is empty (or ``None``) the
        text is returned unchanged.
    """
    # str.replace('', x) would insert the markup between every character,
    # so an empty entity must short-circuit.
    if not entity:
        return text
    return text.replace(entity, f'<i><b>{entity}</b></i>')
def gradio_formatter(js, output_type):
    """Render part of a LOREN result as a zebra-striped HTML table.

    Args:
        js: Result mapping. For ``'e'`` the keys ``'evidence'`` and
            ``'entities'`` are read; for ``'z'`` the keys
            ``'phrase_veracity'``, ``'claim_phrases'``, ``'cloze_qs'`` and
            ``'evidential'`` are read.
        output_type: ``'e'`` for the evidence table, ``'z'`` for the
            per-phrase veracity table.

    Returns:
        An HTML string: a ``<head>`` carrying the table CSS followed by the
        table itself.

    Raises:
        NotImplementedError: If *output_type* is neither ``'e'`` nor ``'z'``.
    """
    zebra_css = '''
    tr:nth-child(even) {
        background: #f1f1f1;
    }
    thead{
        background: #f1f1f1;
    }'''

    if output_type == 'e':
        data = {
            'Evidence': [
                highlight_entity(sentence, entity)
                for sentence, entity in zip(js['evidence'], js['entities'])
            ],
        }
    elif output_type == 'z':
        p_sup, p_ref, p_nei = [], [], []
        for probs in js['phrase_veracity']:
            # Find the winning class on the raw numbers, before they are
            # turned into strings.
            max_idx = torch.argmax(torch.tensor(probs)).item()
            cells = ['%.4f' % p for p in probs]
            cells[max_idx] = f'<i><b>{cells[max_idx]}</b></i>'
            # Column mapping used here: index 2 -> SUP, 0 -> REF, 1 -> NEI.
            p_sup.append(cells[2])
            p_ref.append(cells[0])
            p_nei.append(cells[1])

        data = {
            'Claim Phrase': js['claim_phrases'],
            'Local Premise': [
                highlight_phrase(question, evidential[0])
                for question, evidential in zip(js['cloze_qs'], js['evidential'])
            ],
            'p_SUP': p_sup,
            'p_REF': p_ref,
            'p_NEI': p_nei,
        }
    else:
        raise NotImplementedError

    data = pd.DataFrame(data)
    pt = PrettyTable(field_names=list(data.columns),
                     align='l', border=True, hrules=1, vrules=1)
    for row in data.values:
        pt.add_row(row)

    # NOTE(review): "bordercolor" is not a CSS property (the CSS name is
    # "border-color"); the string is kept as-is to preserve the rendering.
    table_html = pt.get_html_string(attributes={
        'style': 'border-width: 2px; bordercolor: black'
    }, format=True)
    table_html = f'<head> <style type="text/css"> {zebra_css} </style> </head>\n' + table_html
    # PrettyTable HTML-escapes cell text, which would show the <i><b> markup
    # literally; turn the escaped angle brackets back into real tags.
    table_html = table_html.replace('&lt;', '<').replace('&gt;', '>')
    return table_html
def run(claim):
    """Verify *claim* with LOREN and build the three outputs of the demo UI.

    Args:
        claim: The claim text entered by the user.

    Returns:
        A 3-tuple ``(label, z_html, ev_html)``: the value of
        ``js['claim_veracity']``, the HTML phrase-veracity table and the HTML
        evidence table. If ``loren.check`` raises, the error is logged and
        ``('Oops, something went wrong.', '', '')`` is returned instead.
    """
    try:
        js = loren.check(claim)
    except Exception as error_msg:
        # UI boundary: log the claim and full traceback, then show a generic
        # message rather than letting the exception reach the user.
        exc = traceback.format_exc()
        msg = f'[Error]: {error_msg}.\n[Traceback]: {exc}'
        loren.logger.error(claim)
        loren.logger.error(msg)
        return 'Oops, something went wrong.', '', ''

    label = js['claim_veracity']
    loren.logger.warning(label + str(js))
    ev_html = gradio_formatter(js, 'e')
    z_html = gradio_formatter(js, 'z')
    # Order matches the interface outputs: label text, phrase table, evidence.
    return label, z_html, ev_html
# Gradio front-end: one text box in; the predicted label plus two HTML tables
# (phrase veracity, then evidence) out, in the order returned by run().
# NOTE(review): `layout`, `enable_queue` and a boolean `allow_flagging` look
# like they belong to an older Gradio Interface API - confirm against the
# Gradio version this app is pinned to.
iface = gr.Interface(
    fn=run,
    inputs="text",
    outputs=[
        'text',
        'html',
        'html',
    ],
    examples=['Donald Trump won the U.S. 2020 presidential election.',
              'The first inauguration of Bill Clinton was in the United States.',
              'The Cry of the Owl is based on a book by an American.',
              'Smriti Mandhana is an Indian woman.'],
    title="LOREN",
    layout='horizontal',
    description="LOREN is an interpretable Fact Verification model using Wikipedia as its knowledge source. "
                "This is a demo system for the AAAI 2022 paper: \"LOREN: Logic-Regularized Reasoning for Interpretable Fact Verification\"(https://jiangjiechen.github.io/publication/loren/). "
                "See the paper for more details. You can add a *FLAG* on the bottom to record interesting or bad cases! "
                "(Note that the demo system directly retrieves evidence from an up-to-date Wikipedia, which is different from the evidence used in the paper.)",
    # Flagged examples are written under this directory.
    flagging_dir='results/flagged/',
    allow_flagging=True,
    flagging_options=['Interesting!', 'Error: Claim Phrase Parsing', 'Error: Local Premise',
                      'Error: Require Commonsense', 'Error: Evidence Retrieval'],
    enable_queue=True
)
iface.launch()