Spaces:
Sleeping
Sleeping
| import torch | |
def generate_text(input_text, model, tokenizer, *, max_new_tokens=5, return_labels=False):
    """Generate a short continuation of ``input_text`` after a ``[LABEL]`` marker.

    The marker ``' [LABEL]'`` is appended to the prompt. The prompt is encoded
    with ``tokenizer``, and ``model.generate`` is called with nucleus sampling
    (``do_sample=True``, ``top_p=0.95``) for at most ``max_new_tokens`` tokens
    beyond the prompt.

    Args:
        input_text: Prompt text.
        model: Object exposing a Hugging Face-style ``generate`` method
            (presumably a causal LM; only ``generate`` is called here).
        tokenizer: Object exposing ``encode(..., return_tensors='pt')`` and
            ``decode(..., skip_special_tokens=...)``.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt. Defaults to 5.
        return_labels: If True, return the decoded text split on commas,
            with ``[LABEL]`` markers removed and each piece stripped, instead
            of the raw decoded string.

    Returns:
        The full decoded first sequence (prompt plus continuation, special
        tokens kept) as a string. When ``return_labels`` is True, a list of
        label strings is returned instead.
    """
    prompt = input_text + ' [LABEL]'
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # One unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    # input_ids is assumed to be (batch, seq_len), as HF tokenizers return for
    # return_tensors='pt'. len(input_ids) would be the batch size (1), not the
    # prompt length, which capped generation below the prompt itself.
    prompt_length = input_ids.shape[-1]
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=prompt_length + max_new_tokens,
        do_sample=True,
        top_p=0.95,
    )
    generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
    if not return_labels:
        return generated
    # NOTE(review): the decoded text still contains the prompt, so the first
    # "label" includes the input text. Confirm whether only the continuation
    # should be parsed.
    return [part.replace('[LABEL]', '').strip() for part in generated.split(',')]
def classify_text(input_text, model, tokenizer):
    """Run ``model`` over ``input_text`` and post-process its raw output.

    The text is encoded with ``tokenizer.encode(..., return_tensors='pt')``.
    The model is called with an all-ones attention mask of the same shape as
    the encoded ids, and its result goes through ``post_process_labels``.

    Args:
        input_text: Text to classify.
        model: Callable model accepting ``(input_ids, attention_mask=...)``.
        tokenizer: Object exposing ``encode(..., return_tensors='pt')``.

    Returns:
        Whatever ``post_process_labels`` returns for the model output.
    """
    encoded = tokenizer.encode(input_text, return_tensors='pt')
    # Single unpadded input, so no position is masked out.
    mask = torch.ones_like(encoded)
    raw_output = model(encoded, attention_mask=mask)
    return post_process_labels(raw_output)
def post_process_labels(results):
    """Turn raw classification output into a list of selected labels.

    Placeholder: no filtering is applied yet. The raw ``results`` object is
    returned unchanged as the single element of a one-item list.

    TODO: replace with model-specific logic, e.g. keep only labels whose
    score exceeds a threshold.
    """
    return [results]