Spaces:
Sleeping
Sleeping
| ## | |
| from transformers import AutoTokenizer, pipeline | |
| from transformers import T5ForConditionalGeneration | |
| from transformers import PegasusForConditionalGeneration | |
| from transformers import BartForConditionalGeneration | |
| import streamlit as st | |
| # T5 | |
def get_tidy_tab_t5():
    """Return the T5 summarization pipeline, loading and caching it in
    Streamlit session state on first use so reruns reuse the same model."""
    key = 'tidy_tab_t5'
    if key not in st.session_state:
        st.session_state[key] = load_model_t5()
    return st.session_state[key]
def load_model_t5():
    """Build a summarization pipeline from the fine-tuned tidy-tab
    T5-small checkpoint on the Hugging Face Hub."""
    checkpoint = "wgcv/tidy-tab-model-t5-small"
    return pipeline(
        'summarization',
        model=T5ForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_t5(text):
    """Generate a short summary of ``text`` with the T5 pipeline.

    Args:
        text: Input text to condense.

    Returns:
        The generated summary string, or ``None`` when the pipeline is
        unavailable or produced no output.
    """
    tidy_tab_t5 = get_tidy_tab_t5()
    if not tidy_tab_t5:
        return None
    # T5 checkpoints expect an explicit task prefix on the input.
    result = tidy_tab_t5("summarize: " + text, max_length=8, min_length=1)
    if not result:
        return None
    return result[0]['summary_text']
| # pegasus-xsum | |
def get_tidy_tab_pegasus():
    """Return the Pegasus summarization pipeline, loading and caching it
    in Streamlit session state on first use."""
    key = 'tidy_tab_pegasus'
    if key not in st.session_state:
        st.session_state[key] = load_model_pegasus()
    return st.session_state[key]
def load_model_pegasus():
    """Build a summarization pipeline from the fine-tuned tidy-tab
    pegasus-xsum checkpoint on the Hugging Face Hub."""
    checkpoint = "wgcv/tidy-tab-model-pegasus-xsum"
    return pipeline(
        'summarization',
        model=PegasusForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_pegasus(text):
    """Generate a short summary of ``text`` with the Pegasus pipeline.

    Args:
        text: Input text to condense. Unlike T5, Pegasus needs no task
            prefix, so the text is passed through unchanged (the original
            no-op ``text = text`` assignment is removed).

    Returns:
        The generated summary string, or ``None`` when the pipeline is
        unavailable or produced no output.
    """
    tidy_tab_pegasus = get_tidy_tab_pegasus()
    if not tidy_tab_pegasus:
        return None
    result = tidy_tab_pegasus(text, max_length=8, min_length=1)
    if not result:
        return None
    return result[0]['summary_text']
| # Bart-Large | |
def get_tidy_tab_bart():
    """Return the BART summarization pipeline, loading and caching it in
    Streamlit session state on first use."""
    key = 'tidy_tab_bart'
    if key not in st.session_state:
        st.session_state[key] = load_model_bart()
    return st.session_state[key]
def load_model_bart():
    """Build a summarization pipeline from the fine-tuned tidy-tab
    bart-large-cnn checkpoint on the Hugging Face Hub."""
    checkpoint = "wgcv/tidy-tab-model-bart-large-cnn"
    return pipeline(
        'summarization',
        model=BartForConditionalGeneration.from_pretrained(checkpoint),
        tokenizer=AutoTokenizer.from_pretrained(checkpoint),
    )
def predict_model_bart(text):
    """Generate a short summary of ``text`` with the BART pipeline.

    Args:
        text: Input text to condense, passed through unchanged (the
            original no-op ``text = text`` assignment is removed).

    Returns:
        The generated summary string, or ``None`` when the pipeline is
        unavailable or produced no output.
    """
    tidy_tab_bart = get_tidy_tab_bart()
    if not tidy_tab_bart:
        return None
    # NOTE(review): do_sample=True makes the output non-deterministic
    # between calls even with beam search — confirm that is intended.
    result = tidy_tab_bart(text, num_beams=4, max_length=12, min_length=1,
                           do_sample=True)
    if not result:
        return None
    return result[0]['summary_text']