noequal commited on
Commit
b791513
·
1 Parent(s): 713110c

Minimalistic version

Browse files
Files changed (1) hide show
  1. app.py +12 -41
app.py CHANGED
@@ -1,47 +1,18 @@
1
- # Import necessary libraries
2
  import streamlit as st
3
- import pandas as pd
4
- import torch
5
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding, DataCollatorForLanguageModeling
6
- from text_processor import generate_text, classify_text
7
 
8
- # Step 1: Set Up Your Environment
9
- # Environment setup and package installations.
10
 
11
- # Step 2: Data Preparation
12
- # Load and preprocess your CSV dataset.
13
- df = pd.read_csv('stepkids_training_data.csv')
14
 
15
- # Filter out rows with missing label data
16
- df = df.dropna(subset=['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5'])
17
 
18
- text_list = df['Post Text'].tolist()
19
- labels = df[['Theme 1', 'Theme 2', 'Theme 3', 'Theme 4', 'Theme 5']].values.tolist()
 
 
 
 
20
 
21
- # Step 3: Model Selection
22
- # Load your GPT-2 model for text generation.
23
- model_name = "gpt2" # Choose the appropriate GPT-2 model variant
24
- text_gen_model = GPT2LMHeadModel.from_pretrained(model_name)
25
- text_gen_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
26
- text_gen_tokenizer.pad_token = text_gen_tokenizer.eos_token
27
-
28
- # Load your sequence classification model (e.g., BERT)
29
- seq_classifier_model = GPT2ForSequenceClassification.from_pretrained(model_name)
30
- seq_classifier_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
31
- seq_classifier_tokenizer.pad_token = seq_classifier_tokenizer.eos_token
32
-
33
- # Create a title and a text input for the app
34
- st.title('Thematic Analysis with GPT-2 Large')
35
- text = st.text_area('Enter some text')
36
-
37
- # If the text is not empty, perform both text generation and sequence classification
38
- if text:
39
- # Perform text generation
40
- generated_text = generate_text(text, text_gen_model, text_gen_tokenizer)
41
- st.write('Generated Text:')
42
- st.write(generated_text)
43
-
44
- # Perform sequence classification
45
- labels = classify_text(text, seq_classifier_model, seq_classifier_tokenizer)
46
- st.write('Classified Labels:')
47
- st.write(labels)
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
 
 
 
3
 
4
+ # Load the zero-shot classification pipeline
5
+ theme_detection = pipeline('zero-shot-classification')
6
 
7
+ st.title("Theme Detection App")
 
 
8
 
9
+ # Create a textarea for user input
10
+ user_text = st.text_area("Enter Text:", "Type here...")
11
 
12
+ if st.button("Detect Themes"):
13
+ # Perform theme detection
14
+ themes = theme_detection(user_text, ['Theme1', 'Theme2', 'Theme3'])
15
+
16
+ # Display the result
17
+ st.success(f"Detected Themes: {', '.join(themes['labels'])}")
18