Spaces:
Runtime error
Runtime error
| ############ 1. IMPORTING LIBRARIES ############ | |
| # Import streamlit, requests for API calls, and pandas and numpy for data manipulation | |
| import streamlit as st | |
| import requests | |
| import pandas as pd | |
| import numpy as np | |
| from streamlit_tags import st_tags # to add labels on the fly! | |
############ 2. SETTING UP THE PAGE LAYOUT AND TITLE ############
# Browser-tab title and icon, plus a narrow, centered page layout.
# Keep this as the first Streamlit call in the script: Streamlit has
# historically required `st.set_page_config` to run before any other
# page element is emitted (confirm for the Streamlit version in use).
# NOTE(review): the icon literal below looks like a mis-encoded emoji in the
# scraped source; confirm the intended character.
st.set_page_config(
    page_title="Zero-Shot Text Classifier",
    page_icon="βοΈ",
    layout="centered",
)
############ CREATE THE LOGO AND HEADING ############
# Two side-by-side columns sized 0.32 : 2 -- a narrow one for an image and a
# wide one for the page heading.
c1, c2 = st.columns([0.32, 2])

# Left (narrow) column: an 85-px-wide image loaded from Unsplash.
c1.image(
    "https://images.unsplash.com/photo-1508175800969-525c72a047dd?w=500&auto=format&fit=crop&q=60&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxzZWFyY2h8MTl8fGFmcm8lMjByb2JvdHxlbnwwfHwwfHx8MA%3D%3D",
    width=85,
)

# Right (wide) column: an empty caption, presumably a spacer so the title sits
# lower next to the image, followed by the app title.
c2.caption("")
c2.title("Zero-Shot Text Classifier")
# Session state persists across reruns. This flag records whether the most
# recent form submission passed validation, so later reruns (which do not
# press "Submit") can keep showing results. Create it only on the first run.
if "valid_inputs_received" not in st.session_state:
    st.session_state["valid_inputs_received"] = False
############ SIDEBAR CONTENT ############
st.sidebar.write("")

# Widgets created through `st.sidebar` are rendered in the sidebar rather than
# the main page. The token field is a password input, so the key is masked.
# Surrounding whitespace is stripped: a token pasted with a stray space would
# otherwise be sent verbatim in the Authorization header (and likely rejected).
API_KEY = st.sidebar.text_input(
    "Enter your HuggingFace API key",
    help="Once you have created your HuggingFace account, you can get your free API token in your settings page: https://huggingface.co/settings/tokens",
    type="password",
).strip()

# HuggingFace Inference API endpoint for the distilbart-mnli-12-3 model.
API_URL = "https://api-inference.huggingface.co/models/valhalla/distilbart-mnli-12-3"

# Headers sent with every API request: the token as a Bearer credential.
headers = {"Authorization": f"Bearer {API_KEY}"}

st.sidebar.markdown("---")

# Credits shown at the bottom of the sidebar.
# NOTE(review): the character right after "[Streamlit](https://streamlit.io/)"
# looks like a mis-encoded emoji in the scraped source; confirm the intended one.
st.sidebar.write(
    """
App created by [Charly Wargnier](https://twitter.com/DataChaz) using [Streamlit](https://streamlit.io/)π and [HuggingFace](https://huggingface.co/inference-api)'s [Distilbart-mnli-12-3](https://huggingface.co/valhalla/distilbart-mnli-12-3) model.
"""
)
############ TABBED NAVIGATION ############
# Two tabs: `MainTab` hosts the classifier itself, `InfoTab` holds background
# material about Streamlit.
MainTab, InfoTab = st.tabs(["Main", "Info"])

# Populate the "Info" tab by calling element methods on the tab container.
InfoTab.subheader("What is Streamlit?")
InfoTab.markdown(
    "[Streamlit](https://streamlit.io) is a Python library that allows the creation of interactive, data-driven web applications in Python."
)

InfoTab.subheader("Resources")
InfoTab.markdown(
    """
- [Streamlit Documentation](https://docs.streamlit.io/)
- [Cheat sheet](https://docs.streamlit.io/library/cheatsheet)
- [Book](https://www.amazon.com/dp/180056550X) (Getting Started with Streamlit for Data Science)
"""
)

InfoTab.subheader("Deploy")
InfoTab.markdown(
    "You can quickly deploy Streamlit apps using [Streamlit Community Cloud](https://streamlit.io/cloud) in just a few clicks."
)
with MainTab:
    # Opening blurb of the "Main" tab; the empty st.write() calls on either
    # side act as vertical spacers.
    _intro_markdown = """
Classify keyphrases on the fly with this mighty app. No training needed!
"""
    st.write("")
    st.markdown(_intro_markdown)
    st.write("")
# The inputs live in a form (`st.form`): changing a widget inside it does not
# rerun the app; all values are sent to Streamlit together when "Submit" is
# pressed.
with MainTab:
    with st.form(key="my_form"):
        ############ ST TAGS ############
        # Editable tag input (streamlit_tags component) holding the candidate
        # labels. The defaults are three user-intent categories; `maxtags=3`
        # caps how many labels can be entered.
        labels_from_st_tags = st_tags(
            value=["Transactional", "Informational", "Navigational"],
            maxtags=3,
            suggestions=["Transactional", "Informational", "Navigational"],
            label="",
        )

        # Maximum number of keyphrases classified per submission (one API call
        # each); extra lines are dropped with a notice. Adjust as needed.
        MAX_KEY_PHRASES = 50

        # Sample keyphrases pre-filled in the text area; replace with your own.
        new_line = "\n"
        pre_defined_keyphrases = [
            "I want to buy something",
            "We have a question about a product",
            "I want a refund through the Google Play store",
            "Can I have a discount, please",
            "Can I have the link to the product page?",
        ]
        # One keyphrase per line, the format the text area expects.
        keyphrases_string = new_line.join(pre_defined_keyphrases)

        # Text area where users paste the keyphrases to classify, one per line.
        text = st.text_area(
            "Enter keyphrases to classify",
            keyphrases_string,
            height=200,
            # Tooltip shown when hovering over the text area.
            help=(
                "At least two keyphrases for the classifier to work, one per line, "
                f"{MAX_KEY_PHRASES} keyphrases max in 'unlocked mode'. "
                "You can tweak 'MAX_KEY_PHRASES' in the code to change this"
            ),
            key="1",
        )

        # Turn the raw text into the list of keyphrases to classify:
        # 1. split into lines (`text` is rebound to this raw line list, as before);
        # 2. trim surrounding whitespace so padded and whitespace-only lines are
        #    handled -- `filter(None, ...)` alone keeps "   " because it is truthy;
        # 3. drop duplicates, keeping first-seen order (dict preserves insertion);
        # 4. drop empty lines.
        text = text.split("\n")
        linesList = list(dict.fromkeys(line.strip() for line in text))
        linesList = [line for line in linesList if line]

        # Cap the number of keyphrases and tell the user when lines are dropped.
        # NOTE(review): the emoji prefix below looks mis-encoded in the scraped
        # source; confirm the intended character.
        if len(linesList) > MAX_KEY_PHRASES:
            st.info(
                f"βοΈ Note that only the first {MAX_KEY_PHRASES} keyphrases will be reviewed to preserve performance. Fork the repo and tweak 'MAX_KEY_PHRASES' in the code to increase that limit."
            )
            linesList = linesList[:MAX_KEY_PHRASES]

        submit_button = st.form_submit_button(label="Submit")
############ CONDITIONAL STATEMENTS ############
# Gate the rest of the app on the form inputs:
# - no submission on this run and no earlier valid one: stop quietly;
# - a submission with no keyphrases, no labels, or a single label: warn,
#   forget any earlier valid submission, and stop;
# - a valid submission: remember it in session state so later reruns (which do
#   not press "Submit") keep showing the results.
# Every failing branch calls st.stop(), so falling through this chain means the
# inputs are valid and classification can proceed.
# NOTE(review): the emoji prefix of the warnings looks mis-encoded in the
# scraped source; confirm the intended character.
with MainTab:
    if not submit_button and not st.session_state.valid_inputs_received:
        st.stop()
    elif submit_button and not linesList:
        # Check the cleaned keyphrase list rather than `text`: `text` holds the
        # result of `str.split("\n")`, which is never empty ("".split("\n") is
        # [""]), so testing it could never catch an empty submission.
        st.warning("βοΈ There is no keyphrases to classify")
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submit_button and not labels_from_st_tags:
        st.warning("βοΈ You have not added any labels, please add some! ")
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submit_button and len(labels_from_st_tags) == 1:
        st.warning("βοΈ Please make sure to add at least two labels for classification")
        st.session_state.valid_inputs_received = False
        st.stop()
    elif submit_button:
        st.session_state.valid_inputs_received = True
############ MAKING THE API CALL ############
def query(payload, timeout=120):
    """Send one request to the HuggingFace Inference API and return its JSON body.

    Args:
        payload: JSON-serialisable request body. Here it carries the keyphrase
            (``inputs``), the candidate labels (``parameters``) and request
            ``options``.
        timeout: Seconds ``requests`` waits on the server before raising
            ``requests.Timeout``. Without a timeout a stalled connection would
            block the app forever; the value is generous because
            ``wait_for_model`` lets the server hold the request while the model
            loads.

    Returns:
        The decoded JSON response, whatever its shape (success or error).

    Raises:
        requests.RequestException: on connection failures or timeouts.
        ValueError: if the response body is not valid JSON.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()


def _api_error_message(api_json_output):
    """Return a user-facing error message if the response is not a classification.

    The results table built from these responses reads their ``labels`` and
    ``scores`` fields, so a usable response must be a dict with both. Anything
    else is a failure. When it is a dict with an ``error`` field (presumably how
    the Inference API reports problems such as an invalid token -- confirm),
    that text is surfaced.

    Returns:
        ``None`` for a usable result, otherwise the message to display.
    """
    if isinstance(api_json_output, dict):
        if "labels" in api_json_output and "scores" in api_json_output:
            return None
        if "error" in api_json_output:
            return f"The HuggingFace API returned an error: {api_json_output['error']}"
    return f"Unexpected response from the HuggingFace API: {api_json_output!r}"


with MainTab:
    # Nothing to classify: stop before building an empty results table, which
    # would lack the columns the wrangling step expects.
    if not linesList:
        st.warning("There are no keyphrases to classify.")
        st.stop()

    # One API call per keyphrase. The payload holds the keyphrase, the
    # candidate labels, and `wait_for_model` set to True to avoid timeouts
    # while the model is loading.
    list_for_api_output = []
    for row in linesList:
        try:
            api_json_output = query(
                {
                    "inputs": row,
                    "parameters": {"candidate_labels": labels_from_st_tags},
                    "options": {"wait_for_model": True},
                }
            )
        except (requests.RequestException, ValueError) as err:
            # Network failure, timeout, or a response body that is not JSON.
            st.error(f"Could not get a valid response from the HuggingFace API: {err}")
            st.stop()

        # Report API-side failures (e.g. a bad token) instead of crashing later
        # on missing "labels"/"scores" columns.
        error_message = _api_error_message(api_json_output)
        if error_message is not None:
            st.error(error_message)
            st.stop()

        list_for_api_output.append(api_json_output)

    # Build the results table once, after every keyphrase has been classified.
    df = pd.DataFrame(list_for_api_output)

    st.success("β Done!")
    st.caption("")
    st.markdown("### Check the results!")
    st.caption("")
############ DATA WRANGLING ON THE RESULTS ############
# Reshape the raw API records into a compact table:
# keyphrase | top label | top score (as a percentage string).
with MainTab:
    # Each row's "scores" is a list of floats; render every entry as a
    # percentage string with two decimals (0.8731 -> "87.31%").
    df["classification scores"] = [
        [f"{score:.2%}" for score in scores] for scores in df["scores"]
    ]
    df = df.rename(columns={"sequence": "keyphrase"})

    # Labels (and their scores) come sorted best-first, so element 0 of each
    # list is the top prediction.
    df = df.assign(
        label=df["labels"].str[0],
        accuracy=df["classification scores"].str[0],
    )
    df = df.drop(columns=["scores", "labels", "classification scores"])

    # Number the displayed rows from 1 instead of 0.
    df.index = np.arange(1, len(df) + 1)

    st.write(df)
# Two equal-width columns below the results; only the left one (`cs`) is used,
# for the CSV download button.
with MainTab:
    cs, c1 = st.columns([2, 2])

    def convert_df(df):
        """Serialise *df* (index included) to UTF-8 encoded CSV bytes."""
        return df.to_csv().encode("utf-8")

    # Note: nothing is cached here -- the CSV is rebuilt on every rerun.
    csv = convert_df(df)
    cs.caption("")
    cs.download_button(
        label="Download results",
        data=csv,
        file_name="classification_results.csv",
        mime="text/csv",
    )