Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import ( | |
| pipeline, | |
| AutoModelForTokenClassification, | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| AutoModelForMaskedLM | |
| ) | |
| import pandas as pd | |
| import spacy | |
| import csv | |
| from io import StringIO | |
# -------------------- PAGE CONFIG --------------------
# Browser tab title/icon and a full-width page layout.
st.set_page_config(
    page_title="PuoBERTa Multi-Task Demo",
    page_icon="🔤",
    layout="wide"
)

# -------------------- HEADER --------------------
# Three columns with a wider middle one; only the middle column is used,
# which centres the logo on the page.
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
    try:
        st.image("logo_transparent_small.png", width=300)
    except Exception:
        # The logo file is optional: fall back to a text banner. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # and Streamlit's own control-flow signals propagate.
        st.write("🔤 PuoBERTa")

st.title("PuoBERTa Multi-Task Demo")
st.markdown("""
A comprehensive demo for Setswana language models including:
- **Mask Filling**: Fill in missing words in sentences
- **POS Tagging**: Identify parts of speech
- **Named Entity Recognition**: Extract entities like people, places, organizations
- **News Classification**: Classify news articles by category
""")
st.markdown("---")

# -------------------- SIDEBAR --------------------
# Static credits and links to the four models used by the tabs.
st.sidebar.header("Model Information")
st.sidebar.markdown("""
**Authors**: Vukosi Marivate, Moseli Mots'Oehli, Valencia Wagner, Richard Lastrucci, Isheanesu Dzingirai
**Paper**: [PuoBERTa: Training and evaluation of a curated language model for Setswana](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141)
**Huggingface Space Creators**: Vukosi Marivate, Zion Van Wyk, Unarine Netshifhefhe, Thapelo Sindane
**Models Used**:
- [dsfsi/PuoBERTa (Mask Filling - Pretrained Model)](https://huggingface.co/dsfsi/PuoBERTa)
- [dsfsi/PuoBERTa-POS (POS Tagging)](https://huggingface.co/dsfsi/PuoBERTa-POS)
- [dsfsi/PuoBERTa-NER (Named Entity Recognition)](https://huggingface.co/dsfsi/PuoBERTa-NER)
- [dsfsi/PuoBERTa-News (News Classification)](https://huggingface.co/dsfsi/PuoBERTa-News)
""")
| # -------------------- CACHING FUNCTIONS -------------------- | |
@st.cache_resource(show_spinner=False)
def _build_mask_filling_pipeline():
    """Build the dsfsi/PuoBERTa fill-mask pipeline (cached across reruns).

    Kept separate from ``load_mask_filling_model`` so that only a
    successfully built pipeline is stored in the cache; if loading raises,
    nothing is cached and the next call retries.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa")
    model = AutoModelForMaskedLM.from_pretrained("dsfsi/PuoBERTa")
    pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, top_k=5)
    # Debug aid: log which mask token the tokenizer expects.
    print(f"Mask token being used: {tokenizer.mask_token}")
    return pipe


def load_mask_filling_model():
    """Return the fill-mask pipeline for dsfsi/PuoBERTa, or None on failure.

    The pipeline yields the top 5 candidates per mask. On any loading
    error the message is shown with ``st.error`` and ``None`` is returned,
    so callers must check the result before using it.
    """
    try:
        return _build_mask_filling_pipeline()
    except Exception as e:
        st.error(f"Failed to load mask filling model: {str(e)}")
        return None
@st.cache_resource(show_spinner=False)
def load_pos_model():
    """Return the token-classification pipeline for dsfsi/PuoBERTa-POS.

    Cached as a shared resource so the model is built once rather than on
    every script rerun. ``aggregation_strategy="simple"`` makes the pipeline
    emit grouped spans keyed by ``entity_group``.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-POS")
    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-POS")
    return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
@st.cache_resource(show_spinner=False)
def load_ner_model():
    """Return the token-classification pipeline for dsfsi/PuoBERTa-NER.

    Cached as a shared resource so the model is built once rather than on
    every script rerun. ``aggregation_strategy="simple"`` makes the pipeline
    emit grouped spans keyed by ``entity_group``.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-NER")
    model = AutoModelForTokenClassification.from_pretrained("dsfsi/PuoBERTa-NER")
    return pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
@st.cache_resource(show_spinner=False)
def load_news_classification_model():
    """Return the text-classification pipeline for dsfsi/PuoBERTa-News.

    Cached as a shared resource so the model is built once rather than on
    every script rerun. ``return_all_scores=True`` is kept as-is so the
    pipeline keeps returning a score for every category.
    """
    tokenizer = AutoTokenizer.from_pretrained("dsfsi/PuoBERTa-News")
    model = AutoModelForSequenceClassification.from_pretrained("dsfsi/PuoBERTa-News")
    return pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=True)
| # -------------------- UTILITY FUNCTIONS -------------------- | |
def get_correct_mask_token(text, tokenizer):
    """Rewrite common mask placeholders to the tokenizer's own mask token.

    Args:
        text: Input sentence that may contain ``[MASK]`` or ``<mask>``.
        tokenizer: Any object exposing a ``mask_token`` string attribute.

    Returns:
        ``text`` with every ``[MASK]`` and ``<mask>`` replaced by
        ``tokenizer.mask_token``.

    Raises:
        ValueError: If the tokenizer defines no mask token (``None``), which
            would otherwise surface as an opaque ``TypeError`` from
            ``str.replace``.
    """
    mask_token = tokenizer.mask_token
    if mask_token is None:
        raise ValueError("Tokenizer does not define a mask token")
    # Each placeholder spelling is replaced exactly once.
    for placeholder in ("[MASK]", "<mask>"):
        text = text.replace(placeholder, mask_token)
    return text
# Usage (mask filling tab): normalise the placeholder before calling the pipeline.
#   corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
#   results = mask_filler(corrected_input)
def merge_entities(output):
    """Merge directly adjacent spans that share the same entity group.

    Two consecutive items are merged when the later one starts exactly where
    the earlier one ends and both have the same ``entity_group``. The merged
    span concatenates ``word`` and extends ``end``; every other field
    (e.g. ``score``) keeps the value from the first span of the run.

    Args:
        output: Sequence of dicts with at least ``word``, ``start``, ``end``
            and ``entity_group`` keys.

    Returns:
        A new list of new dicts; the dicts in ``output`` are not modified.
    """
    merged = []
    for ent in output:
        previous = merged[-1] if merged else None
        if (
            previous is not None
            and ent["start"] == previous["end"]
            and ent["entity_group"] == previous["entity_group"]
        ):
            previous["word"] += ent["word"]
            previous["end"] = ent["end"]
        else:
            # Copy so that later merges never write into the caller's data.
            merged.append(dict(ent))
    return merged
def create_spacy_display(text, entities, task_type="ner"):
    """Render labelled spans of ``text`` as highlighted HTML via displaCy.

    Args:
        text: The original input text the spans refer to.
        entities: Iterable of dicts with ``entity_group``, ``start`` and
            ``end`` keys (character offsets into ``text``).
        task_type: ``"ner"`` relabels ``PER`` as ``PERSON``; any other value
            leaves labels untouched.

    Returns:
        An HTML snippet (style tag plus a bordered, horizontally scrollable
        container) or ``"<p>Error rendering visualization</p>"`` when the
        displaCy rendering step fails.
    """
    ents = []
    for ent in entities:
        label = ent["entity_group"]
        if task_type == "ner" and label == "PER":
            label = "PERSON"
        ents.append({"start": ent["start"], "end": ent["end"], "label": label})
    # Input format for displaCy's manual mode.
    spacy_display = {"text": text, "ents": ents, "title": None}

    # Highlight colours per label.
    colors = {
        # POS colors
        "PRON": "#FF9999",
        "VERB": "#99FF99",
        "DET": "#9999FF",
        "PROPN": "#FFFF99",
        "CCONJ": "#FFCC99",
        "PUNCT": "#CCCCCC",
        "NUM": "#FFCCFF",
        "NOUN": "#FFB366",
        "ADJ": "#B366FF",
        "ADP": "#66FFB3",
        # NER colors
        "PERSON": "#85DCDF",
        "PER": "#85DCDF",
        "LOC": "#DF85DC",
        "ORG": "#DCDF85",
        "MISC": "#85ABDF",
    }

    try:
        html = spacy.displacy.render(
            spacy_display,
            style="ent",
            manual=True,
            minify=True,
            options={"colors": colors},
        )
    except Exception:
        # Best-effort visualisation: degrade to a placeholder, but do not
        # swallow KeyboardInterrupt/SystemExit as a bare except would.
        return "<p>Error rendering visualization</p>"

    return (
        "\n"
        "<style>mark.entity { display: inline-block; }</style>\n"
        "<div style='overflow-x:auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem;'>\n"
        f"{html}\n"
        "</div>\n"
    )
def get_input_text(tab_name, examples):
    """Render the input-method widgets for one tab and return the chosen text.

    Args:
        tab_name: Prefix used to build unique widget keys for this tab.
        examples: Options offered in the "Example Text" select box.

    Returns:
        The selected example, the text-area content, or the uploaded file's
        text. For CSV uploads, the non-null values of the chosen column are
        joined with newlines. Returns ``""`` when "Upload File" is selected
        but nothing has been uploaded yet.
    """
    input_method = st.radio(
        "Select Input Method",
        ['Example Text', 'Write Text', 'Upload File'],
        key=f"{tab_name}_input_method"
    )
    if input_method == 'Example Text':
        return st.selectbox("Example Sentences", examples, key=f"{tab_name}_examples")
    if input_method == 'Write Text':
        return st.text_area("Enter text", height=100, key=f"{tab_name}_text_input")
    if input_method == 'Upload File':
        uploaded = st.file_uploader("Upload text or CSV file", type=["txt", "csv"], key=f"{tab_name}_file")
        if uploaded:
            # Case-insensitive so "DATA.CSV" is parsed as CSV, not raw text.
            if uploaded.name.lower().endswith('.csv'):
                df = pd.read_csv(uploaded)
                st.write("CSV Preview:", df.head())
                col = st.selectbox("Choose column with text", df.columns, key=f"{tab_name}_csv_col")
                return "\n".join(df[col].dropna().astype(str).tolist())
            # utf-8-sig decodes plain UTF-8 and also drops a leading BOM,
            # which would otherwise end up in the model input.
            return uploaded.read().decode("utf-8-sig")
    return ""
# -------------------- TABS --------------------
tab1, tab2, tab3, tab4 = st.tabs(["🎭 Mask Filling", "🏷️ POS Tagging", "🔍 Named Entity Recognition", "📰 News Classification"])

# -------------------- MASK FILLING TAB --------------------
with tab1:
    st.header("Mask Filling")
    st.write("Fill in the blanks in Setswana sentences using `<mask>` token.")
    mask_examples = [
        "Ke rata go <mask> dijo tsa Batswana.",
        "Botswana ke naga e e <mask> mo Afrika Borwa.",
        "Bana ba <mask> sekolo ka Mosupologo.",
        "Re tshwanetse go <mask> tikologo ya rona."
    ]
    mask_input = get_input_text("mask", mask_examples)
    if st.button("Fill Masks", key="mask_button") and mask_input.strip():
        # Normalise first, then validate. These are two separate checks so
        # that text written with [MASK] is converted AND still sent to the
        # model instead of stopping after the conversion notice.
        if "[MASK]" in mask_input:
            mask_input = mask_input.replace("[MASK]", "<mask>")
            st.info("Converted [MASK] to <mask> format")
        if "<mask>" not in mask_input:
            st.warning("Please include <mask> token in your text.")
        else:
            with st.spinner("Filling masks..."):
                mask_filler = load_mask_filling_model()
                # The loader returns None after reporting its own error, so
                # there is nothing more to do (or to dereference) here.
                if mask_filler is not None:
                    try:
                        corrected_input = get_correct_mask_token(mask_input, mask_filler.tokenizer)
                        results = mask_filler(corrected_input)
                        # One mask -> flat list of candidates; several masks
                        # -> one candidate list per mask. Treat both alike.
                        if results and isinstance(results[0], dict):
                            results = [results]
                        st.subheader("Predictions")
                        for mask_number, candidates in enumerate(results, 1):
                            if len(results) > 1:
                                st.markdown(f"**Mask {mask_number}**")
                            for i, result in enumerate(candidates, 1):
                                confidence = result['score'] * 100
                                st.write(f"**{i}.** {result['sequence']} (confidence: {confidence:.1f}%)")
                    except Exception as e:
                        st.error(f"Error: {str(e)}")
                        # Debug information, reusing the already loaded model.
                        st.info(f"Input text: {mask_input}")
                        st.info(f"Model mask token: {mask_filler.tokenizer.mask_token}")
# -------------------- POS TAGGING TAB --------------------
# Tags the chosen text, then shows the tags as a table and as highlighted text.
with tab2:
    st.header("Parts of Speech Tagging")
    st.write("Identify grammatical parts of speech in Setswana text.")
    pos_examples = [
        "Moso ono mo dikgang tsa ura le ura, o tsoga le Oarabile Moamogwe go simolola ka 05:00 - 10:00",
        "Batho ba le bantsi ba rata go bala dikgang tsa Setswana.",
        "Ke ithutile Setswana kwa sekolong sa me.",
        "Dikgomo di ja bojang mo tshimong."
    ]
    pos_input = get_input_text("pos", pos_examples)
    pos_clicked = st.button("Run POS Tagging", key="pos_button")
    if pos_clicked and pos_input.strip():
        with st.spinner("Running POS tagging..."):
            try:
                # Load the tagger, run it, and merge touching spans of the same tag.
                pos_spans = merge_entities(load_pos_model()(pos_input))
                if not pos_spans:
                    st.info("No POS tags identified.")
                else:
                    # Results table with scores rounded to four decimals.
                    pos_columns = ['word', 'entity_group', 'score', 'start', 'end']
                    pos_table = pd.DataFrame(pos_spans)[pos_columns]
                    pos_table['score'] = pos_table['score'].round(4)
                    st.subheader("POS Tags")
                    st.dataframe(pos_table, use_container_width=True)
                    # Highlighted rendering of the same spans.
                    st.subheader("Visual Display")
                    st.markdown(
                        create_spacy_display(pos_input, pos_spans, "pos"),
                        unsafe_allow_html=True,
                    )
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- NER TAB --------------------
# Extracts named entities from the chosen text and shows them as a table
# and as highlighted text.
with tab3:
    st.header("Named Entity Recognition")
    st.write("Extract named entities like people, places, and organizations from Setswana text.")
    ner_examples = [
        "Oarabile Moamogwe o tswa Gaborone mme o bereka kwa University of Botswana.",
        "Motswana yo o tumileng Mpho Balopi o ne a kopana le Rre Khama kwa Presidential Palace.",
        "Botswana Democratic Party e ne ya kopana le African National Congress.",
        "Bank of Botswana e mo Gaborone e laola economy ya naga."
    ]
    ner_input = get_input_text("ner", ner_examples)
    run_ner = st.button("Run NER", key="ner_button")
    if run_ner and ner_input.strip():
        with st.spinner("Running NER..."):
            try:
                recognizer = load_ner_model()
                raw_spans = recognizer(ner_input)
                # Join touching spans that carry the same entity label.
                found = merge_entities(raw_spans)
                if found:
                    # Results table with scores rounded to four decimals.
                    ner_table = pd.DataFrame(found)[['word', 'entity_group', 'score', 'start', 'end']]
                    ner_table['score'] = ner_table['score'].round(4)
                    st.subheader("Named Entities")
                    st.dataframe(ner_table, use_container_width=True)
                    # Highlighted rendering of the same spans.
                    st.subheader("Visual Display")
                    ner_markup = create_spacy_display(ner_input, found, "ner")
                    st.markdown(ner_markup, unsafe_allow_html=True)
                else:
                    st.info("No named entities found.")
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- NEWS CLASSIFICATION TAB --------------------
# Classifies the chosen text and shows the top five categories as progress
# bars, with the full ranking available in an expander.
with tab4:
    st.header("News Classification")
    st.write("Classify Setswana news articles into different categories.")
    # Category mapping: model label -> Setswana display name.
    categories = {
        "arts_culture_entertainment_and_media": "Botsweretshi, setso, boitapoloso le bobegakgang",
        "crime_law_and_justice": "Bosenyi, molao le bosiamisi",
        "disaster_accident_and_emergency_incident": "Masetlapelo, kotsi le tiragalo ya maemo a tshoganyetso",
        "economy_business_and_finance": "Ikonomi, tsa kgwebo le tsa ditšhelete",
        "education": "Thuto",
        "environment": "Tikologo",
        "health": "Boitekanelo",
        "politics": "Dipolotiki",
        "religion_and_belief": "Bodumedi le tumelo",
        "society": "Setšhaba"
    }
    news_examples = [
        "Puso ya Botswana e solofeditse gore e tla oketsa dithuso tsa thuto mo dikolong tsa poraemari.",
        "Dipalo tsa bosenyi di oketsegile mo torong ya Gaborone ka pakeng tse di fetileng.",
        "Setšhaba sa Botswana se keteka matsalo a Rre le Mme ba ba ratanang thata.",
        "Boemelo jwa economy ya Botswana bo tsweletse sentle ka ngwaga ono."
    ]
    news_input = get_input_text("news", news_examples)
    if st.button("Classify News", key="news_button") and news_input.strip():
        with st.spinner("Classifying news..."):
            try:
                classifier = load_news_classification_model()
                # Truncate so long articles or uploads stay within the
                # model's maximum input length instead of raising.
                results = classifier(news_input, truncation=True)
                # Accept both output shapes: a list of scores wrapped in an
                # outer list, or the flat list of scores itself.
                if results and isinstance(results[0], list):
                    label_scores = results[0]
                else:
                    label_scores = results
                # Unknown labels fall back to the raw model label.
                predictions = {
                    categories.get(pred['label'], pred['label']): round(float(pred['score']), 4)
                    for pred in label_scores
                }
                # Sort by confidence, highest first.
                sorted_predictions = dict(sorted(predictions.items(), key=lambda x: x[1], reverse=True))
                st.subheader("Classification Results")
                # Display the top five as progress bars.
                for category, confidence in list(sorted_predictions.items())[:5]:
                    st.write(f"**{category}**")
                    # Clamp defensively: st.progress only accepts 0.0-1.0.
                    st.progress(min(max(confidence, 0.0), 1.0))
                    st.write(f"Confidence: {confidence:.1%}")
                    st.write("")
                # Display full results table.
                with st.expander("View All Categories"):
                    results_df = pd.DataFrame([
                        {"Category": cat, "Confidence": conf}
                        for cat, conf in sorted_predictions.items()
                    ])
                    st.dataframe(results_df, use_container_width=True)
            except Exception as e:
                st.error(f"Error: {str(e)}")
# -------------------- FOOTER --------------------
# Horizontal rule followed by the BibTeX citation and project links.
footer_markdown = """
### 📚 Citation
```bibtex
@inproceedings{marivate2023puoberta,
title = {PuoBERTa: Training and evaluation of a curated language model for Setswana},
author = {Vukosi Marivate and Moseli Mots'Oehli and Valencia Wagner and Richard Lastrucci and Isheanesu Dzingirai},
year = {2023},
booktitle= {Artificial Intelligence Research. SACAIR 2023. Communications in Computer and Information Science},
url= {https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17},
keywords = {NLP},
preprint_url = {https://arxiv.org/abs/2310.09141},
dataset_url = {https://github.com/dsfsi/PuoBERTa},
software_url = {https://huggingface.co/dsfsi/PuoBERTa}
}
```
**Links**: [Paper](https://link.springer.com/chapter/10.1007/978-3-031-49002-6_17) | [Preprint/Arxiv](https://arxiv.org/abs/2310.09141) | [GitHub](https://github.com/dsfsi/PuoBERTa) | [HuggingFace](https://huggingface.co/dsfsi/PuoBERTa)
"""
for footer_part in ("---", footer_markdown):
    st.markdown(footer_part)