Spaces:
Sleeping
Sleeping
# Import necessary libraries
import streamlit as st
import pandas as pd
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, DataCollatorForLanguageModeling
from text_processor import generate_text, classify_text

# Step 1: Set Up Your Environment
# Environment setup and package installations.

# Step 2: Data Preparation
# Load the CSV dataset of posts annotated with five theme columns.
df = pd.read_csv('stepkids_training_data.csv')

# Drop rows missing any of the five theme labels so every remaining row
# has a complete label vector.
df = df.dropna(subset=['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5'])

# Training inputs and multi-label targets. Named `theme_labels` (not
# `labels`) so this module-level value is not silently clobbered by the
# per-request classification result assigned further down the script.
text_list = df['Post Text'].tolist()
theme_labels = df[['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5']].values.tolist()
# Step 3: Model Selection
# Load a GPT-2 model and tokenizer for text generation.
model_name = "gpt2"  # Choose the appropriate GPT-2 model variant
text_gen_model = GPT2LMHeadModel.from_pretrained(model_name)
text_gen_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 defines no pad token; reuse EOS so padding/batching works.
text_gen_tokenizer.pad_token = text_gen_tokenizer.eos_token

# Load a GPT-2 sequence-classification head for the five theme labels.
# num_labels=5 matches the five 'Theme N' columns in the training data
# (from_pretrained defaults to 2, which would give the wrong logit width;
# the head is freshly initialized either way since "gpt2" has no
# classification weights).
seq_classifier_model = GPT2ForSequenceClassification.from_pretrained(model_name, num_labels=5)
seq_classifier_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
seq_classifier_tokenizer.pad_token = seq_classifier_tokenizer.eos_token
# GPT2ForSequenceClassification pools the last non-padding token, so the
# model config must know the pad token id or padded batches raise at runtime.
seq_classifier_model.config.pad_token_id = seq_classifier_tokenizer.pad_token_id
# Streamlit UI: page title plus a free-text input area.
st.title('Thematic Analysis with GPT-2 Large')
text = st.text_area('Enter some text')

# Once the user has entered text, run both generation and classification.
if text:
    # Text generation via the GPT-2 language-model head.
    generated_text = generate_text(text, text_gen_model, text_gen_tokenizer)
    st.write('Generated Text:')
    st.write(generated_text)

    # Theme classification. Use a distinct name so this per-request result
    # does not shadow the module-level training labels defined above.
    predicted_labels = classify_text(text, seq_classifier_model, seq_classifier_tokenizer)
    st.write('Classified Labels:')
    st.write(predicted_labels)