# Hugging Face Space by dylangamachefl
# Commit eeb6ce3 (verified): "Update app.py"
# app.py
import streamlit as st
import requests
import os
from dotenv import load_dotenv
# --- Configuration ---
# Load environment variables from the .env file located one directory above
# the directory containing this script.
_app_dir = os.path.dirname(__file__)
dotenv_path = os.path.join(_app_dir, "..", ".env")
load_dotenv(dotenv_path=dotenv_path)

# Hugging Face access token; None when HF_TOKEN is not set in the environment.
API_TOKEN = os.environ.get("HF_TOKEN")

# Model to query and the Inference API endpoint derived from its id.
MODEL_ID = "facebook/bart-large-mnli"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"

# Authorization header carrying the token as a Bearer credential.
headers = dict(Authorization=f"Bearer {API_TOKEN}")
def query_hf_api(payload, timeout=60):
    """
    Send a request to the Hugging Face Inference API and return the parsed JSON.

    Args:
        payload: JSON-serializable request body posted to ``API_URL``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, which would leave the
            app stuck on a stalled connection.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.RequestException: On connection failure, timeout,
            or when the server answers with an HTTP error status.
    """
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=timeout
    )
    # Turn 4xx/5xx answers into exceptions so callers handle them in one place.
    response.raise_for_status()
    return response.json()
# --- Page setup ---
st.set_page_config(page_title="Zero-Shot Text Classifier", layout="wide")
st.title("🎯 Zero-Shot Text Classifier")

# Introductory blurb naming the model the app talks to.
_intro_markdown = f"""
This app uses the `{MODEL_ID}` Zero-Shot Classification model from Hugging Face
to classify text into categories you define, even if the model hasn't been explicitly trained on those categories.
Enter a piece of text and a comma-separated list of potential labels.
"""
st.markdown(_intro_markdown)

# --- User inputs ---
# Free-form text to classify, pre-filled with an example sentence.
text_to_classify = st.text_area(
    label="Enter text to classify:",
    value="The new AI regulations will have a significant impact on technology startups.",
    height=100,
)

# Raw comma-separated label string, pre-filled with example labels.
candidate_labels_input = st.text_input(
    label="Enter candidate labels (comma-separated):",
    value="politics, business, technology, sports, education, finance",
)
# --- Classification ---
# Runs only on the script pass triggered by clicking the button: validates the
# inputs, queries the model, and renders either the scores or an error.
if st.button("Classify Text"):
    # Parse the comma-separated labels up front, dropping blank entries, so all
    # input validation can live in a single flat if/elif chain.
    candidate_labels = [
        label.strip()
        for label in candidate_labels_input.split(",")
        if label.strip()
    ]

    if not API_TOKEN:
        # API_TOKEN is read from the HF_TOKEN environment variable, so the
        # message must name that variable for the user to be able to act on it.
        st.error(
            "Hugging Face API token not found. Please set HF_TOKEN in your .env file."
        )
    elif not text_to_classify.strip():
        st.warning("Please enter some text to classify.")
    elif not candidate_labels_input.strip():
        st.warning("Please enter some candidate labels.")
    elif not candidate_labels:
        # Input such as " , ," is non-empty but yields no usable labels.
        st.warning(
            "No valid candidate labels provided after parsing. Please ensure labels are separated by commas and are not empty."
        )
    else:
        payload = {
            "inputs": text_to_classify,
            "parameters": {"candidate_labels": candidate_labels},
        }
        with st.spinner(f"Asking the AI model ({MODEL_ID})..."):
            try:
                result = query_hf_api(payload)
                st.subheader("Classification Results:")
                if (
                    isinstance(result, dict)
                    and "labels" in result
                    and "scores" in result
                ):
                    # Success: show each label next to its score.
                    for label, score in zip(result["labels"], result["scores"]):
                        st.write(f"**Label:** {label} - **Score:** {score:.4f}")
                elif isinstance(result, dict) and "error" in result:
                    # The API answered with an error payload instead of scores.
                    st.error(f"API Error: {result['error']}")
                    if "estimated_time" in result:
                        st.info(
                            f"The model might be loading. Estimated time: {result['estimated_time']:.2f} seconds. Please try again shortly."
                        )
                else:
                    # Unknown shape: dump the raw response to aid debugging.
                    st.error("Received an unexpected response format from the API.")
                    st.json(result)
            except requests.exceptions.RequestException as e:
                st.error(f"API Request Failed: {e}")
                # HTTP errors carry the server's response; show its body too.
                if e.response is not None:
                    st.error(f"Response content: {e.response.text}")
            except Exception as e:
                # Last-resort boundary so a rendering or parsing failure shows
                # up in the page instead of crashing the script run.
                st.error(f"An unexpected error occurred: {e}")

# Horizontal divider below the classifier section.
st.markdown("---")