File size: 1,551 Bytes
1da6722
f9b90ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fed313
 
 
 
 
 
 
 
713110c
9fed313
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch

# Function for generating text based on input
def generate_text(input_text, model, tokenizer, *, max_new_tokens=5, top_p=0.95):
    """Append the ``[LABEL]`` marker to *input_text* and sample a continuation.

    The prompt ``input_text + ' [LABEL]'`` is tokenized with
    ``tokenizer.encode(..., return_tensors='pt')`` and passed to
    ``model.generate`` with nucleus sampling (``do_sample=True``).

    Args:
        input_text: Text to prompt the model with.
        model: Object exposing a ``generate`` method (presumably a Hugging Face
            generative model -- TODO confirm against callers).
        tokenizer: Object exposing ``encode`` and ``decode``.
        max_new_tokens: How many tokens the model may add beyond the prompt.
            Defaults to 5, the budget the original ``+ 5`` intended.
        top_p: Nucleus-sampling probability mass. Defaults to 0.95.

    Returns:
        The full decoded sequence (prompt plus continuation) as a string.
        Special tokens are kept (``skip_special_tokens=False``).
    """
    prompt = input_text + ' [LABEL]'
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)
    # ``input_ids`` is a (batch, seq_len) tensor, so ``len(input_ids)`` is the
    # batch size (1), not the prompt length. Read the length from the last dim.
    prompt_length = input_ids.shape[-1]
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=prompt_length + max_new_tokens,
        do_sample=True,
        top_p=top_p,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=False)

# Function for sequence classification
def classify_text(input_text, model, tokenizer):
    """Run *model* on *input_text* and turn its output into labels.

    The text is tokenized with ``tokenizer.encode(..., return_tensors='pt')``
    and fed to ``model`` with an all-ones attention mask. The raw model output
    is then handed to ``post_process_labels``.

    Args:
        input_text: Text to classify.
        model: Callable sequence-classification model (presumably a Hugging
            Face ``...ForSequenceClassification`` -- TODO confirm).
        tokenizer: Object exposing ``encode``.

    Returns:
        Whatever ``post_process_labels`` returns for the model output.
    """
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    attention_mask = torch.ones_like(input_ids)
    # Inference only: disable autograd so the forward pass does not build a
    # gradient graph, which the returned result would otherwise keep alive.
    with torch.no_grad():
        result = model(input_ids, attention_mask=attention_mask)
    # Post-process the results (e.g., select labels based on a threshold).
    return post_process_labels(result)

# Post-process labels based on a threshold or confidence score
def post_process_labels(results):
    """Turn raw classification output into a list of selected labels.

    Placeholder implementation: no filtering is performed yet. The *results*
    object is returned unchanged as the only element of a new list.

    Args:
        results: Raw output of the sequence-classification model.

    Returns:
        A one-element list containing *results*.
    """
    # TODO: replace with real selection logic, e.g. keep each label whose
    # score exceeds a threshold so it counts as a valid theme.
    return [results]