# utils/chunking.py
import logging
import re

from pysbd import Segmenter


def chunk_text_with_separators(text, tokenizer, max_tokens, lang):
    """
    Split the input text into chunks with preserved separators, optimized for
    handling lists and tables.

    Args:
        text (str): The input text to be chunked.
        tokenizer: Tokenizer object used to encode text into tokens.
        max_tokens (int): Maximum number of tokens allowed per chunk.
        lang (str): Language of the text, used for sentence segmentation.

    Returns:
        list: A list of tuples, each containing a chunk of text and its
            corresponding separator.
    """
    # Split text into sentences while preserving separators
    sentences_with_seps = _split_technical_sentences(text, lang)
    chunks = []
    current_chunk = []
    current_length = 0
    current_separators = []
    for sentence, sep in sentences_with_seps:
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        sentence_len = len(sentence_tokens)
        if sentence_len == 0:
            continue
        # Handle special cases like lists and tables
        if _is_list_item(sentence) or _is_table_header(sentence):
            if current_chunk:
                # Finalize the current chunk before processing special items
                chunks.append((' '.join(current_chunk), ''.join(current_separators)))
                current_chunk = []
                current_length = 0
                current_separators = []
            # Process list items and table headers as separate chunks
            chunks.extend(_process_special_item(sentence, sep, tokenizer, max_tokens))
            continue
        # Add the sentence to the current chunk if it fits
        if current_length + sentence_len <= max_tokens:
            current_chunk.append(sentence)
            current_length += sentence_len
            current_separators.append(sep)
        else:
            # Finalize the current chunk and start a new one
            if current_chunk:
                chunks.append((' '.join(current_chunk), ''.join(current_separators)))
            current_chunk = [sentence]
            current_length = sentence_len
            current_separators = [sep]
    # Add any remaining text to the final chunk
    if current_chunk:
        chunks.append((' '.join(current_chunk), ''.join(current_separators)))
    return chunks
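

# Example usage (a minimal sketch, assuming a Hugging Face-style tokenizer whose
# `encode(text, add_special_tokens=False)` returns a list of token ids; the
# model name and the `translate()` helper below are illustrative only):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
#     for chunk, sep in chunk_text_with_separators(doc_text, tokenizer,
#                                                  max_tokens=256, lang="russian"):
#         translated = translate(chunk)  # hypothetical downstream translation call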


def _split_technical_sentences(text, lang):
    """Enhanced splitting for technical documents with lists and tables."""
    # Mark numbered list items and colon-terminated headers with a §§§
    # boundary marker so they are split out as their own segments.
    text = re.sub(r'(\n\s*\d+\.)', r'\n§§§\1', text)
    text = re.sub(r'(:\s*\n)', r'\1§§§', text)
    sentences = []
    separators = []
    if lang == 'russian':
        segmenter = Segmenter(language='ru', clean=False)
        raw_sentences = segmenter.segment(text)
    else:
        raw_sentences = re.split(r'([.!?])(\s*)', text)
    buffer = ''
    current_sep = ''
    for part in raw_sentences:
        if '§§§' in part:
            # Prepend any buffered text so it is not lost when this part is
            # split on the §§§ boundary markers.
            part = buffer + part
            buffer = ''
            parts = part.split('§§§')
            for p in parts[:-1]:
                if p.strip():
                    sentences.append(p.strip())
                    separators.append(current_sep)
                    current_sep = ''
            buffer = parts[-1]
        else:
            buffer += part
            # Flush the buffer when we hit a sentence boundary
            if lang == 'russian':
                if buffer.strip() and any(buffer.endswith(c) for c in ['.', '!', '?', ':']):
                    sentences.append(buffer.strip())
                    separators.append(current_sep)
                    buffer = ''
                    current_sep = ''
            else:
                if re.search(r'[.!?:]$', buffer):
                    sentences.append(buffer.strip())
                    separators.append(current_sep)
                    buffer = ''
                    current_sep = ''
    if buffer.strip():
        sentences.append(buffer.strip())
        separators.append(current_sep)
    return list(zip(sentences, separators))


def _is_list_item(text):
    return re.match(r'^\s*(\d+\.|\-|\*)\s', text)


def _is_table_header(text):
    return re.search(r':\s*$', text) and re.search(r'[A-ZА-Я]{3,}', text)
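
# A few illustrative inputs for the two predicates above (behaviour follows
# directly from the regexes; the sample strings are made up):
#   _is_list_item("  1. Check the voltage")   -> truthy  (numbered item)
#   _is_list_item("- torque values")          -> truthy  (bulleted item)
#   _is_table_header("ПАРАМЕТРЫ ДВИГАТЕЛЯ:")  -> truthy  (3+ uppercase letters and a trailing colon)
#   _is_table_header("see the table below:")  -> falsy   (no run of 3+ uppercase letters)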


def _process_special_item(text, separator, tokenizer, max_tokens):
    """Process list items and table headers as atomic units."""
    chunks = []
    current_chunk = []
    current_length = 0
    sentences = re.split(r'(\n+)', text)
    for sentence in sentences:
        if not sentence.strip():
            continue
        tokens = tokenizer.encode(sentence, add_special_tokens=False)
        token_count = len(tokens)
        if token_count > max_tokens:
            # Handle oversized items by splitting on commas and semicolons
            parts = re.split(r'([,;])', sentence)
            for part in parts:
                if not part.strip():
                    continue
                part_tokens = tokenizer.encode(part, add_special_tokens=False)
                part_len = len(part_tokens)
                if current_length + part_len > max_tokens:
                    # Guard against emitting an empty chunk when nothing has
                    # been accumulated yet
                    if current_chunk:
                        chunks.append((' '.join(current_chunk), separator))
                    current_chunk = [part]
                    current_length = part_len
                else:
                    current_chunk.append(part)
                    current_length += part_len
        else:
            if current_length + token_count > max_tokens:
                if current_chunk:
                    chunks.append((' '.join(current_chunk), separator))
                current_chunk = [sentence]
                current_length = token_count
            else:
                current_chunk.append(sentence)
                current_length += token_count
    if current_chunk:
        chunks.append((' '.join(current_chunk), separator))
    return chunks
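

if __name__ == "__main__":
    # Minimal smoke test, intended only as a sketch of how the chunker is
    # driven.  The stub tokenizer below stands in for the project's real
    # Hugging Face tokenizer and simply counts whitespace-separated words;
    # only the `encode(text, add_special_tokens=False)` signature matters here.
    class _WhitespaceTokenizer:
        def encode(self, text, add_special_tokens=False):
            # Crude stand-in: one "token" per whitespace-separated word.
            return text.split()

    sample = (
        "Порядок проверки:\n"
        "1. Отключите питание.\n"
        "2. Проверьте соединения.\n"
        "После проверки включите питание и убедитесь, что индикатор горит."
    )
    for chunk, sep in chunk_text_with_separators(sample, _WhitespaceTokenizer(),
                                                 max_tokens=8, lang="russian"):
        print(repr(chunk), '->', repr(sep))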