Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from utils import ( | |
| device, | |
| jina_tokenizer, | |
| jina_model, | |
| embeddings_predict_relevance, | |
| stsb_model, | |
| stsb_tokenizer, | |
| ms_model, | |
| ms_tokenizer, | |
| cross_encoder_predict_relevance | |
| ) | |
def predict(system_prompt, user_prompt, selected_model):
    """Classify a user prompt as on-topic or off-topic for a system prompt.

    Args:
        system_prompt: Text of the system prompt.
        user_prompt: Text of the user prompt to check against it.
        selected_model: Identifier of the model to run; one of the three
            names handled below.

    Returns:
        A Markdown string with the predicted label and the off-topic
        probability as a percentage. When ``selected_model`` is not one of
        the supported identifiers (for example ``None`` because nothing was
        selected), a Markdown error message is returned instead.
    """
    # Pick the scoring function and its model/tokenizer pair. Names are only
    # looked up inside the branch that is taken.
    if selected_model == "jinaai/jina-embeddings-v2-small-en":
        scorer = embeddings_predict_relevance
        model, tokenizer = jina_model, jina_tokenizer
    elif selected_model == "cross-encoder/stsb-roberta-base":
        scorer = cross_encoder_predict_relevance
        model, tokenizer = stsb_model, stsb_tokenizer
    elif selected_model == "cross-encoder/ms-marco-MiniLM-L-6-v2":
        scorer = cross_encoder_predict_relevance
        model, tokenizer = ms_model, ms_tokenizer
    else:
        # Without this branch the code below would fail with
        # UnboundLocalError when no (or an unknown) model is selected.
        return "**Error**: Please select one of the supported models before checking content."

    predicted_label, probabilities = scorer(
        system_prompt, user_prompt, model, tokenizer, device
    )

    # Entry [0][1] is reported as the off-topic probability.
    # NOTE(review): assumes probabilities is indexable as [row][class] with
    # class 1 = off-topic — confirm against utils.
    probability_off_topic = probabilities[0][1] * 100
    label = "Off-topic" if predicted_label == 1 else "On-topic"

    # Built line by line so the function's indentation never leaks into the
    # rendered Markdown.
    return (
        "\n"
        "**Prediction Summary**:\n"
        f"- **Predicted Label**: {label}\n"
        f"- **Probability of Off-topic**: {probability_off_topic:.3f}%\n"
    )
# Gradio UI: two prompt boxes, a model picker, a run button, and a Markdown
# area that shows the string returned by predict().
# NOTE(review): the nesting of components inside the rows is reconstructed —
# confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True) as app:
    gr.Markdown("# Off-Topic Classification using Fine-tuned Embeddings and Cross-Encoder Models")

    with gr.Row():
        system_prompt = gr.Textbox(label="System Prompt")
        user_prompt = gr.Textbox(label="User Prompt")

    with gr.Row():
        selected_model = gr.Dropdown(
            [
                "jinaai/jina-embeddings-v2-small-en",
                "cross-encoder/stsb-roberta-base",
                "cross-encoder/ms-marco-MiniLM-L-6-v2",
            ],
            # Preselect the first model explicitly so the click handler never
            # receives None, whatever the installed Gradio version's default.
            value="jinaai/jina-embeddings-v2-small-en",
            label="Select a model",
        )

    # Button to run the prediction
    get_classfication = gr.Button("Check Content")
    output_result = gr.Markdown(label="Classification and Probabilities")

    # Wire the button: the three inputs are passed positionally to predict()
    # and its returned Markdown string is rendered in output_result.
    get_classfication.click(
        fn=predict,
        inputs=[system_prompt, user_prompt, selected_model],
        outputs=output_result,
    )

# Only start the server when this file is executed as a script.
if __name__ == "__main__":
    app.launch()