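"""Streamlit dashboard for running inference with Hugging Face models.

Launch with `streamlit run <this_file>.py`; the first run downloads the model
weights from the Hugging Face Hub.
"""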
import streamlit as st
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
)

# Maps each task key shown in the sidebar to the Hugging Face checkpoint that
# backs it. Note these are stock checkpoints: bert-base-uncased gets a randomly
# initialized classification head, so its predictions are meaningless until the
# model is fine-tuned.
MODEL_MAPPING = {
    "text2shellcommands": "t5-small",
    "pentest_ai": "bert-base-uncased",
}


def select_model():
    """
    Adds a dropdown to the Streamlit sidebar for selecting a model.

    Returns:
        str: The selected model key from MODEL_MAPPING.
    """
    st.sidebar.header("Model Configuration")
    selected_model = st.sidebar.selectbox("Select a model", list(MODEL_MAPPING.keys()))
    return selected_model

@st.cache_resource
def load_model_and_tokenizer(model_name):
    """
    Loads the tokenizer and model for the specified Hugging Face model name.
    Uses st.cache_resource so the weights are loaded once per session rather
    than on every Streamlit rerun.

    Args:
        model_name (str): The name of the Hugging Face model to load.

    Returns:
        tuple: A (tokenizer, model) pair, or (None, None) if loading failed.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Pick the model head by task: a seq2seq head for generation models,
        # otherwise a sequence-classification head.
        if "t5" in model_name or "seq2seq" in model_name:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        else:
            model = AutoModelForSequenceClassification.from_pretrained(model_name)
        return tokenizer, model
    except Exception as e:
        st.error(f"An error occurred while loading the model or tokenizer: {e}")
        return None, None
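
# Illustrative usage: the first call downloads the checkpoint from the Hub;
# subsequent reruns reuse the cached objects.
#   tokenizer, model = load_model_and_tokenizer("t5-small")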

def predict_with_model(user_input, model, tokenizer, model_choice):
    """
    Handles predictions using the loaded model and tokenizer.

    Args:
        user_input (str): Text input from the user.
        model: Loaded Hugging Face model.
        tokenizer: Loaded Hugging Face tokenizer.
        model_choice (str): Selected model key from MODEL_MAPPING.

    Returns:
        dict: A dictionary containing the prediction results.
    """
    inputs = tokenizer(user_input, return_tensors="pt", padding=True, truncation=True)
    if model_choice == "text2shellcommands":
        # Seq2seq path: generate an output sequence and decode it to text.
        with torch.no_grad():
            outputs = model.generate(**inputs)
        generated_command = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"Generated Shell Command": generated_command}
    else:
        # Classification path: one forward pass, then argmax over the logits.
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits, dim=-1).item()
        return {
            "Predicted Class": predicted_class,
            "Logits": logits.tolist(),
        }
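
# Illustrative usage: the stock t5-small checkpoint was not trained on
# shell-command generation, so outputs will be rough without fine-tuning.
#   predict_with_model("list all files recursively", model, tokenizer, "text2shellcommands")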

def process_uploaded_file(uploaded_file):
    """
    Reads and processes the uploaded file. Supports text and CSV files.

    Args:
        uploaded_file: The uploaded file object from st.file_uploader.

    Returns:
        str: The file content as a string, or None if it could not be read.
    """
    if uploaded_file is None:
        return None
    try:
        file_type = uploaded_file.type
        if "text" in file_type:
            return uploaded_file.read().decode("utf-8")
        elif "csv" in file_type:
            # Render the CSV as a plain-text table for downstream prediction.
            df = pd.read_csv(uploaded_file)
            return df.to_string()
        else:
            st.error("Unsupported file type. Please upload a text or CSV file.")
            return None
    except Exception as e:
        st.error(f"Error processing file: {e}")
        return None
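
# Note: CSV content is flattened to a single string before prediction; the
# tokenizer call truncates anything past the model's maximum sequence length.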

def main():
    st.title("AI Model Inference Dashboard")
    st.markdown(
        """
        This dashboard allows you to interact with different AI models for inference tasks,
        such as generating shell commands or performing text classification.
        """
    )

    model_choice = select_model()
    model_name = MODEL_MAPPING.get(model_choice)
    tokenizer, model = load_model_and_tokenizer(model_name)

    # Bail out early if loading failed; predict_with_model would otherwise
    # raise on a None model.
    if tokenizer is None or model is None:
        st.stop()

    input_choice = st.radio("Choose Input Method", ("Text Input", "Upload File"))

    if input_choice == "Text Input":
        user_input = st.text_area("Enter your text input:", placeholder="Type your text here...")
        submit_button = st.button("Submit")

        if submit_button and user_input:
            st.write("### Prediction Results:")
            result = predict_with_model(user_input, model, tokenizer, model_choice)
            for key, value in result.items():
                st.write(f"**{key}:** {value}")

    elif input_choice == "Upload File":
        uploaded_file = st.file_uploader("Choose a text or CSV file", type=["txt", "csv"])
        submit_button = st.button("Submit")

        if submit_button and uploaded_file:
            file_content = process_uploaded_file(uploaded_file)
            if file_content:
                st.write("### File Content:")
                st.write(file_content)
                result = predict_with_model(file_content, model, tokenizer, model_choice)
                st.write("### Prediction Results:")
                for key, value in result.items():
                    st.write(f"**{key}:** {value}")
            else:
                st.info("No valid content found in the file.")


if __name__ == "__main__":
    main()