Spaces:
Sleeping
Sleeping
| # coding=utf-8 | |
| # author: xusong <xusong28@jd.com> | |
| # time: 2022/8/23 17:08 | |
| import time | |
| import torch | |
| import gradio as gr | |
| from info import article | |
| from transformers import FillMaskPipeline | |
| from transformers import BertTokenizer | |
| from kplug.modeling_kplug import KplugForMaskedLM | |
| from pycorrector.bert.bert_corrector import BertCorrector | |
| from pycorrector import config | |
| from loguru import logger | |
# Device id handed to the transformers pipeline: the first CUDA GPU when one is
# available, otherwise -1 (CPU).
if torch.cuda.is_available():
    device_id = 0
else:
    device_id = -1

# Extra page CSS for the Gradio Interface: hides elements with class
# ``category-legend`` (presumably the HighlightedText legend — confirm against
# the installed Gradio version).
css = """
.category-legend {display: none !important}
"""
class KplugCorrector(BertCorrector):
    """pycorrector ``BertCorrector`` whose fill-mask model is a KPLUG encoder.

    The correction logic (e.g. ``bert_correct``) is inherited unchanged from
    ``BertCorrector``. Only the masked-LM pipeline it queries is replaced: a
    ``FillMaskPipeline`` built from a ``KplugForMaskedLM`` checkpoint.
    """

    def __init__(self, bert_model_dir=config.bert_model_dir, device=device_id,
                 model_name_or_path="eson/kplug-base-encoder"):
        """Load the KPLUG tokenizer and model and wrap them in a fill-mask pipeline.

        Args:
            bert_model_dir: Accepted only for signature compatibility with
                ``BertCorrector``; nothing is loaded from it here.
            device: Device id forwarded to ``FillMaskPipeline``
                (-1 for CPU, a non-negative CUDA device index otherwise).
            model_name_or_path: Checkpoint (Hub id or local path) used for both
                the tokenizer and the model.
        """
        # Deliberately start the MRO lookup *after* BertCorrector, so that
        # BertCorrector.__init__ is skipped (presumably to avoid loading its
        # default BERT model) while the next base class's initializer still runs.
        super(BertCorrector, self).__init__()
        self.name = 'kplug_corrector'
        t1 = time.time()
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        model = KplugForMaskedLM.from_pretrained(model_name_or_path)
        self.model = FillMaskPipeline(model=model, tokenizer=tokenizer, device=device)
        # The pipeline was just constructed (construction raises on failure), so
        # the mask token can be read unconditionally.
        self.mask = self.model.tokenizer.mask_token
        # Report the checkpoint that was actually loaded, not the unused bert_model_dir.
        logger.debug('Loaded kplug model: %s, spend: %.3f s.' % (model_name_or_path, time.time() - t1))
# One shared corrector for the whole app; its model is loaded once, at import time.
corrector = KplugCorrector(device=device_id)

# Demo inputs, also offered as the Gradio examples. The first two sentences
# contain deliberate typos; the last one is already correct.
error_sentences = [
    '少先队员因该为老人让坐',
    '机七学习是人工智能领遇最能体现智能的一个分知',
    '今天心情很好',
]
def mock_data():
    """Return a canned ``(corrected_sentence, errors)`` pair for offline UI checks.

    Each error tuple is laid out as ``(original, replacement, start, end)``.
    These are the same positions ``correct`` reads from ``bert_correct``'s output.
    """
    fixes = [
        ('七', '器', 1, 2),
        ('遇', '域', 10, 11),
    ]
    return '机器学习是人工智能领域最能体现智能的一个分知', fixes
def correct(sent):
    """Correct ``sent`` and format the result for the two Gradio outputs.

    Returns:
        A pair ``(highlighted, errs)``:

        * ``highlighted`` is ``{"text": corrected_sentence, "entities": [...]}``,
          the format ``gr.HighlightedText`` expects
          (see https://www.gradio.app/docs/highlightedtext). It contains one
          entity per correction, spanning ``start``..``end``.
        * ``errs`` is the raw correction list from ``corrector.bert_correct``,
          shown as-is by the JSON output.

    Each item of ``errs`` is indexed as ``(original, replacement, start, end)``.
    Presumably these are character offsets; confirm against pycorrector's
    ``bert_correct``.
    """
    corrected_sent, errs = corrector.bert_correct(sent)
    # For offline UI testing, mock_data() returns a canned (corrected_sent, errs) pair.
    print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
    # "score" is a fixed placeholder: the error tuples consumed here carry no confidence.
    entities = [
        {"entity": "纠错", "score": 0.5, "word": err[1], "start": err[2], "end": err[3]}
        for err in errs
    ]
    return {"text": corrected_sent, "entities": entities}, errs
def test():
    """Console smoke check: correct each demo sentence and print the outcome."""
    for sentence in error_sentences:
        fixed, details = corrector.bert_correct(sentence)
        print("original sentence:{} => {}, err:{}".format(sentence, fixed, details))
def _build_interface():
    """Wire ``correct`` to a text input, a highlighted-text output and a JSON output."""
    text_input = gr.Textbox(
        label="输入文本",
        value="少先队员因该为老人让坐",
    )
    # NOTE(review): show_legend=True, yet the page CSS hides ``.category-legend``;
    # presumably intentional — confirm which behavior is wanted.
    highlighted = gr.HighlightedText(label="文本纠错", show_legend=True)
    raw_json = gr.JSON()
    return gr.Interface(
        fn=correct,
        inputs=text_input,
        outputs=[highlighted, raw_json],
        examples=error_sentences,
        title="文本纠错(Corrector)",
        description='自动对汉语文本中的拼写、语法、标点等多种问题进行纠错校对,提示错误位置并返回修改建议',
        article=article,
        css=css,
    )


corr_iface = _build_interface()
if __name__ == "__main__":
    # Serve the demo UI. For a console-only smoke run, call test() instead.
    corr_iface.launch()