noequal committed
Commit fa93b6f · 1 Parent(s): 02bfe1e

Create app.py

Files changed (1)
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
+ # Import the necessary libraries
+ import streamlit as st
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline
+ import torch
+
+ # Load the gpt2-large model and tokenizer for text generation
+ gen_model = GPT2LMHeadModel.from_pretrained('gpt2-large')
+ gen_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
+
+ # Load the zero-shot text classification pipeline from HuggingFace
+ classifier = pipeline('zero-shot-classification')
+
+ # Generate a list of candidate labels for a text with GPT-2
+ def generate_labels(text):
+     # Append a [LABEL] marker to prompt the model for labels
+     prompt = text + ' [LABEL]'
+     # Convert the prompt to input ids and an attention mask
+     input_ids = gen_tokenizer.encode(prompt, return_tensors='pt')
+     attention_mask = torch.ones_like(input_ids)
+     # Sample up to 5 new tokens after the prompt
+     outputs = gen_model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=5, do_sample=True, top_p=0.95, pad_token_id=gen_tokenizer.eos_token_id)
+     # Decode only the newly generated tokens, not the prompt
+     generated = gen_tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
+     # Split the generated text by commas
+     labels = generated.split(',')
+     # Remove any [LABEL] marker and surrounding whitespace from the labels
+     labels = [label.replace('[LABEL]', '').strip() for label in labels]
+     # Filter out any empty or duplicate labels
+     labels = list(dict.fromkeys(filter(None, labels)))
+     # Return the labels as a list
+     return labels
+
+ # Create a title and a text input for the app
+ st.title('Thematic Analysis with GPT-2 Large')
+ text = st.text_input('Enter some text to classify')
+
+ # If the text is not empty, generate labels and classify the text
+ if text:
+     # Generate candidate labels from the text
+     labels = generate_labels(text)
+     # Display the generated labels
+     st.write(f'The generated labels are: {", ".join(labels)}')
+     # Classify the text against the generated labels
+     result = classifier(text, labels)
+     # Get the label and the score with the highest probability
+     label = result['labels'][0]
+     score = result['scores'][0]
+     # Display the label and the score
+     st.write(f'The predicted label is: {label}')
+     st.write(f'The probability is: {score:.4f}')
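
For reference (not part of the commit), below is a minimal sketch of exercising the zero-shot classifier used above outside Streamlit; the example sentence and candidate labels are made up, and the output shape shown is the standard dict returned by the HuggingFace zero-shot-classification pipeline.

# Minimal sketch (not part of the commit): calling the zero-shot pipeline directly.
# The example sentence and candidate labels below are illustrative only.
from transformers import pipeline

classifier = pipeline('zero-shot-classification')
result = classifier(
    'The interviewees repeatedly mentioned workload and burnout.',
    ['wellbeing', 'workload', 'career growth'],
)
# result is a dict with keys 'sequence', 'labels', and 'scores';
# labels are sorted by descending score, so index 0 is the top prediction.
print(result['labels'][0], result['scores'][0])

In a Hugging Face Space or locally, the app itself is launched with the Streamlit CLI (streamlit run app.py).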