marme committed
Commit eb5fab4 · 1 parent: bf12662
app.py ADDED
@@ -0,0 +1,23 @@
+ import gradio as gr
+
+ from chonky import ParagraphSplitter
+
+ splitter = ParagraphSplitter()
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Semantic Chunking Demo\n **Note**: This Space runs on CPU only, so input is limited to max. 50000 characters.")
+     gr.HTML("""<footer>
+     <p>Powered by <a href="https://huggingface.co/mamei16/chonky_distilbert_base_uncased_1.1">mamei16/chonky_distilbert_base_uncased_1.1</a></p>
+     </footer>""")
+     button = gr.Button("Run", variant="primary")
+     text = gr.Textbox(label='Input Text', max_length=50000)
+     gr.Markdown("## Result chunks:")
+     chunks = gr.Markdown()
+
+     button.click(lambda x: "\n\n---\n\n".join(splitter(x)), text, chunks)
+
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=20)
+     demo.launch()
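The click handler joins the chunks yielded by `splitter` with a Markdown horizontal rule, which the `chunks` component renders as visually separated blocks. A minimal sketch of the same wiring outside of Gradio, using invented sample text:

```python
from chonky import ParagraphSplitter

splitter = ParagraphSplitter()

# Invented sample input; any text under the 50000-character limit works.
sample = (
    "Solar panels convert sunlight into electricity, and their efficiency "
    "has improved steadily for decades. On a different note, the museum's "
    "new exhibit opens next week and features early photography."
)

# Mirrors the lambda passed to button.click: join chunks with a rule.
print("\n\n---\n\n".join(splitter(sample)))
```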
chonky/__init__.py ADDED
@@ -0,0 +1,133 @@
+ from typing import Iterator, List
+ from dataclasses import dataclass
+
+ import torch
+ import numpy as np
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+
+
+ def batchify(lst, batch_size):
+     last_item_shorter = False
+     if len(lst[-1]) < len(lst[0]):
+         last_item_shorter = True
+         max_index = len(lst)-1
+     else:
+         max_index = len(lst)
+
+     for i in range(0, max_index, batch_size):
+         yield lst[i : min(i + batch_size, max_index)]
+
+     if last_item_shorter:
+         yield lst[-1:]  # a shorter final window is yielded alone so batches stay rectangular
+
+
+ @dataclass
+ class Token:
+     index: int
+     start: int
+     end: int
+     length: int
+     decoded_str: str
+
+
+ class ParagraphSplitter:
+     def __init__(self, model_id="mamei16/chonky_distilbert_base_uncased_1.1", device="cpu", model_cache_dir: str = None):
+         super().__init__()
+         self.device = device
+         self.is_modernbert = model_id.startswith("mirth/chonky_modernbert")
+
+         id2label = {
+             0: "O",
+             1: "separator",
+         }
+         label2id = {
+             "O": 0,
+             "separator": 1,
+         }
+
+         if self.is_modernbert:
+             tokenizer_kwargs = {"model_max_length": 1024}
+         else:
+             tokenizer_kwargs = {}
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=model_cache_dir, **tokenizer_kwargs)
+         self.model = AutoModelForTokenClassification.from_pretrained(
+             model_id,
+             num_labels=2,
+             id2label=id2label,
+             label2id=label2id,
+             cache_dir=model_cache_dir,
+             torch_dtype=torch.float32 if device == "cpu" else torch.float16
+         )
+         self.model.eval()
+         self.model.to(device)
+
+     def split_into_semantic_chunks(self, text, separator_indices: List[int]):
+         start_index = 0
+
+         for idx in separator_indices:
+             yield text[start_index:idx].strip()
+             start_index = idx
+
+         if start_index < len(text):
+             yield text[start_index:].strip()
+
+     def __call__(self, text: str) -> Iterator[str]:
+         max_seq_len = self.tokenizer.model_max_length
+         window_step_size = max_seq_len // 2
+         ids_plus = self.tokenizer(text, truncation=True, add_special_tokens=True, return_offsets_mapping=True,
+                                   return_overflowing_tokens=True, stride=window_step_size)
+
+         tokens = [[Token(i*max_seq_len+j,
+                          offset_tup[0], offset_tup[1],
+                          offset_tup[1]-offset_tup[0],
+                          text[offset_tup[0]:offset_tup[1]]) for j, offset_tup in enumerate(offset_list)]
+                   for i, offset_list in enumerate(ids_plus["offset_mapping"])]
+
+         input_ids = ids_plus["input_ids"]
+         all_separator_tokens = []
+
+         batch_size = 4
+         for input_id_batch, token_batch in zip(batchify(input_ids, batch_size),
+                                                batchify(tokens, batch_size)):
+             with torch.no_grad():
+                 output = self.model(torch.tensor(input_id_batch).to(self.device))
+
+             logits = output.logits.cpu().numpy()
+             maxes = np.max(logits, axis=-1, keepdims=True)  # numerically stable softmax
+             shifted_exp = np.exp(logits - maxes)
+             scores = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+             token_classes = scores.argmax(axis=-1)
+             # Find last index of each sequence of ones in token class sequence
+             separator_token_idx_tup = ((token_classes[:, :-1] - token_classes[:, 1:]) > 0).nonzero()
+
+             separator_tokens = [token_batch[i][j] for i, j in zip(*separator_token_idx_tup)]
+             all_separator_tokens.extend(separator_tokens)
+
+         flat_tokens = [token for window in tokens for token in window]
+         sorted_separator_tokens = sorted(all_separator_tokens, key=lambda x: x.start)
+         separator_indices = []
+         for i in range(len(sorted_separator_tokens)-1):
+             current_sep_token = sorted_separator_tokens[i]
+             if current_sep_token.end == 0:
+                 continue
+             next_sep_token = sorted_separator_tokens[i+1]
+             # next_token is the token succeeding current_sep_token in the original text
+             next_token = flat_tokens[current_sep_token.index+1]
+
+             # If current separator token is part of a bigger contiguous token, move to the end of the bigger token
+             while (current_sep_token.end == next_token.start and
+                    (not self.is_modernbert or (current_sep_token.decoded_str != '\n'
+                                                and not next_token.decoded_str.startswith(' ')))):
+                 current_sep_token = next_token
+                 next_token = flat_tokens[current_sep_token.index+1]
+
+             if ((current_sep_token.start + current_sep_token.length) > next_sep_token.start or
+                     ((next_sep_token.end - current_sep_token.end) <= 1)):
+                 continue
+
+             separator_indices.append(current_sep_token.end)
+
+         if sorted_separator_tokens:
+             separator_indices.append(sorted_separator_tokens[-1].end)
+
+         yield from self.split_into_semantic_chunks(text, separator_indices)
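Because `__call__` ends in `yield from`, it returns a lazy generator rather than a list (hence the `Iterator[str]` annotation), and `batchify` yields a shorter trailing window on its own so that `torch.tensor(input_id_batch)` always stacks equal-length rows. A minimal usage sketch, assuming the model weights download successfully; the input text is invented for illustration:

```python
from chonky import ParagraphSplitter

splitter = ParagraphSplitter(device="cpu")

# Invented sample with an obvious topic shift halfway through.
text = (
    "The experiment measured reaction times under three lighting "
    "conditions, and participants responded fastest in daylight. "
    "In other news entirely, the harbor festival drew record crowds "
    "this year, with food stalls lining the waterfront."
)

# Chunks arrive lazily as the generator is consumed.
for i, chunk in enumerate(splitter(text), start=1):
    print(f"--- chunk {i} ---")
    print(chunk)
```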
chonky/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (8.9 kB).
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ transformers
+ numpy
+ torch
+
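gradio itself is not listed here; on Hugging Face Spaces it is installed automatically from the Space's `sdk: gradio` configuration, so only the model-side packages need pinning, but for local runs it must be installed as well. A small, hypothetical sanity check (not part of the repo) that verifies all four imports resolve:

```python
import importlib

# gradio comes from the Spaces SDK config; the rest from requirements.txt.
for pkg in ("transformers", "numpy", "torch", "gradio"):
    try:
        importlib.import_module(pkg)
        print(f"{pkg}: OK")
    except ImportError:
        print(f"{pkg}: missing")
```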