Spaces:
Sleeping
Sleeping
| # Import the necessary libraries | |
| import streamlit as st | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel, pipeline | |
| import torch | |
# Streamlit re-executes this whole script on every user interaction. Loading
# the models through cached loader functions keeps a single copy alive for the
# lifetime of the server process instead of reloading gpt2-large on each rerun.
# NOTE: st.cache_resource requires Streamlit >= 1.18.


@st.cache_resource
def _load_generator():
    """Load the gpt2-large model and tokenizer used for text generation.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = GPT2LMHeadModel.from_pretrained('gpt2-large')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
    return model, tokenizer


@st.cache_resource
def _load_classifier():
    """Load the zero-shot text classification pipeline from HuggingFace."""
    return pipeline('zero-shot-classification')


# Module-level names are unchanged so the rest of the script keeps working.
gen_model, gen_tokenizer = _load_generator()
classifier = _load_classifier()
| # Define a function that takes a text as input and returns a list of labels as output | |
def generate_labels(text, max_new_tokens=5):
    """Sample a short continuation from GPT-2 and parse it into labels.

    The prompt is ``text`` followed by the literal marker `` [LABEL]``. Only
    the newly generated tokens are decoded, split on commas, stripped, and
    de-duplicated (first occurrence wins).

    Args:
        text: The input text to generate labels for.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt. Defaults to 5, the budget the original code intended.

    Returns:
        A list of unique, non-empty label strings in generation order. The
        list may be empty if the model produces nothing usable.
    """
    # '[LABEL]' is plain text to the GPT-2 tokenizer, not a registered
    # special token; it only serves as a prompt marker.
    prompt = text + ' [LABEL]'

    # With return_tensors='pt', encode() returns a (1, seq_len) tensor, so the
    # prompt length is shape[1]; len(input_ids) would be the batch size (1).
    input_ids = gen_tokenizer.encode(prompt, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    outputs = gen_model.generate(
        input_ids,
        attention_mask=attention_mask,
        # max_length counts prompt tokens too, so add the real prompt length.
        max_length=prompt_len + max_new_tokens,
        do_sample=True,
        top_p=0.95,
        # GPT-2 has no pad token; reuse EOS so generate() does not warn.
        pad_token_id=gen_tokenizer.eos_token_id,
    )

    # Decode only the continuation; otherwise the prompt text itself ends up
    # inside the labels. Skip special tokens so '<|endoftext|>' is dropped.
    continuation = outputs[0][prompt_len:]
    generated = gen_tokenizer.decode(continuation, skip_special_tokens=True)

    # Split on commas, remove any echoed marker, and trim whitespace.
    candidates = [
        piece.replace('[LABEL]', '').strip() for piece in generated.split(',')
    ]

    # dict.fromkeys removes duplicates while preserving order; filter(None, ...)
    # drops empty strings.
    return list(dict.fromkeys(filter(None, candidates)))
# Create a title and a text input for the app
st.title('Thematic Analysis with GPT-2 Large')
text = st.text_input('Enter some text to classify')

# Only run the models once the user has entered something other than whitespace.
if text.strip():
    # Generate candidate labels from the text
    labels = generate_labels(text)

    if not labels:
        # The zero-shot pipeline raises when given no candidate labels, so
        # report the situation to the user instead of crashing the app.
        st.warning('No labels could be generated for this text. Please try again.')
    else:
        # Display the generated labels
        st.write(f'The generated labels are: {", ".join(labels)}')

        # Classify the text using the generated labels
        result = classifier(text, labels)

        # The top prediction is read from index 0 of 'labels' and 'scores'.
        label = result['labels'][0]
        score = result['scores'][0]

        # Display the label and the score
        st.write(f'The predicted label is: {label}')
        st.write(f'The probability is: {score:.4f}')