# ThematicAnalysis / text_processor.py
# Last change: "Update text_processor.py" (commit 713110c) by noequal; file size 1.55 kB.
import torch
# Function for generating text based on input
def generate_text(input_text, model, tokenizer):
# Append the special token to the input
input_text = input_text + ' [LABEL]'
input_ids = tokenizer.encode(input_text, return_tensors='pt')
attention_mask = torch.ones_like(input_ids)
outputs = model.generate(input_ids, attention_mask=attention_mask, max_length=len(input_ids) + 5, do_sample=True, top_p=0.95)
generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
labels = generated.split(',')
labels = [label.replace('[LABEL]', '').strip() for label in labels]
return generated
# Function for sequence classification
def classify_text(input_text, model, tokenizer):
# Tokenize the input text
input_ids = tokenizer.encode(input_text, return_tensors='pt')
attention_mask = torch.ones_like(input_ids)
# Perform sequence classification
result = model(input_ids, attention_mask=attention_mask)
# Post-process the results (e.g., select labels based on a threshold)
labels = post_process_labels(result)
return labels
# Post-process labels based on a threshold or confidence score
def post_process_labels(results):
# Implement your logic to extract and filter labels
# based on your sequence classification model's output
# For example, you might use a threshold for each label's score
# to determine whether it should be considered a valid theme.
# Return the selected labels as a list.
selected_labels = [results]
return selected_labels