from typing import Iterator, List, Optional
from dataclasses import dataclass

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForTokenClassification


def batchify(lst, batch_size):
    """Yield equal-length batches; a shorter final item is yielded on its own
    so that each batch can be stacked into a single tensor."""
    last_item_shorter = False
    if len(lst[-1]) < len(lst[0]):
        last_item_shorter = True
        max_index = len(lst) - 1
    else:
        max_index = len(lst)
    for i in range(0, max_index, batch_size):
        yield lst[i : min(i + batch_size, max_index)]
    if last_item_shorter:
        yield lst[-1:]


@dataclass
class Token:
    index: int          # position in the flattened token list across all windows
    start: int          # character offset of the token start in the original text
    end: int            # character offset of the token end in the original text
    length: int         # end - start
    decoded_str: str    # the substring of the original text covered by the token


class ParagraphSplitter:

    def __init__(self, model_id: str = "mamei16/chonky_distilbert_base_uncased_1.1",
                 device: str = "cpu", model_cache_dir: Optional[str] = None):
        super().__init__()
        self.device = device
        self.is_modernbert = (model_id.startswith("mirth/chonky_modernbert")
                              or model_id == "mirth/chonky_mmbert_small_multilingual_1")
        id2label = {
            0: "O",
            1: "separator",
        }
        label2id = {
            "O": 0,
            "separator": 1,
        }
        if self.is_modernbert:
            tokenizer_kwargs = {"model_max_length": 1024}
        else:
            tokenizer_kwargs = {}
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=model_cache_dir,
                                                       **tokenizer_kwargs)
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_id,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
            cache_dir=model_cache_dir,
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
        )
        self.model.eval()
        self.model.to(device)

    def split_into_semantic_chunks(self, text: str, separator_indices: List[int]) -> Iterator[str]:
        """Cut `text` at the given character indices and yield the stripped chunks."""
        start_index = 0
        for idx in separator_indices:
            yield text[start_index:idx].strip()
            start_index = idx
        if start_index < len(text):
            yield text[start_index:].strip()

    def __call__(self, text: str) -> Iterator[str]:
        max_seq_len = self.tokenizer.model_max_length
        # Tokenize into overlapping windows; `stride` is the number of tokens
        # shared between consecutive windows (here: half a window).
        window_step_size = max_seq_len // 2
        ids_plus = self.tokenizer(text, truncation=True, add_special_tokens=True,
                                  return_offsets_mapping=True, return_overflowing_tokens=True,
                                  stride=window_step_size)
        # Every full window holds exactly max_seq_len tokens, so i*max_seq_len + j
        # is the token's position in the flattened token list built below.
        tokens = [[Token(i * max_seq_len + j,
                         offset_tup[0],
                         offset_tup[1],
                         offset_tup[1] - offset_tup[0],
                         text[offset_tup[0]:offset_tup[1]])
                   for j, offset_tup in enumerate(offset_list)]
                  for i, offset_list in enumerate(ids_plus["offset_mapping"])]
        input_ids = ids_plus["input_ids"]

        all_separator_tokens = []
        batch_size = 4
        for input_id_batch, token_batch in zip(batchify(input_ids, batch_size),
                                               batchify(tokens, batch_size)):
            with torch.no_grad():
                output = self.model(torch.tensor(input_id_batch).to(self.device))
            logits = output.logits.cpu().numpy()
            # Numerically stable softmax over the two classes
            maxes = np.max(logits, axis=-1, keepdims=True)
            shifted_exp = np.exp(logits - maxes)
            scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
            token_classes = scores.argmax(axis=-1)
            # Find the last index of each run of ones in the token class sequence,
            # i.e. positions where the prediction drops from "separator" back to "O"
            separator_token_idx_tup = ((token_classes[:, :-1] - token_classes[:, 1:]) > 0).nonzero()
            separator_tokens = [token_batch[i][j] for i, j in zip(*separator_token_idx_tup)]
            all_separator_tokens.extend(separator_tokens)

        flat_tokens = [token for window in tokens for token in window]
        sorted_separator_tokens = sorted(all_separator_tokens, key=lambda x: x.start)
        separator_indices = []
        for i in range(len(sorted_separator_tokens) - 1):
            current_sep_token = sorted_separator_tokens[i]
            # Special tokens map to the offset (0, 0); skip them
            if current_sep_token.end == 0:
                continue
            next_sep_token = sorted_separator_tokens[i + 1]
            # next_token is the token succeeding current_sep_token in the original text
            next_token = flat_tokens[current_sep_token.index + 1]
            # If the current separator token is part of a bigger contiguous token
            # (e.g. a word split into sub-tokens), move to the end of the bigger token
            while (current_sep_token.end == next_token.start
                   and (not self.is_modernbert
                        or (current_sep_token.decoded_str != '\n'
                            and not next_token.decoded_str.startswith(' ')))):
                current_sep_token = next_token
                next_token = flat_tokens[current_sep_token.index + 1]
            # Drop separators that overtake the next one or sit directly adjacent to it
            if ((current_sep_token.start + current_sep_token.length) > next_sep_token.start
                    or (next_sep_token.end - current_sep_token.end) <= 1):
                continue
            separator_indices.append(current_sep_token.end)

        if sorted_separator_tokens:
            separator_indices.append(sorted_separator_tokens[-1].end)

        yield from self.split_into_semantic_chunks(text, separator_indices)
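

# Usage sketch (not part of the original module; the sample text below is
# illustrative). The model id is the constructor default, so the first run
# assumes access to the Hugging Face Hub to download the weights.
if __name__ == "__main__":
    splitter = ParagraphSplitter(device="cpu")
    sample_text = (
        "The neural network was trained for ten epochs and converged quickly. "
        "Evaluation on the held-out set followed. "
        "In other news, the recipe calls for two cups of flour and a pinch of salt."
    )
    # __call__ yields one string per detected semantic chunk
    for chunk in splitter(sample_text):
        print(repr(chunk))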