Spaces:
Running
Running
| # Copyright 2024 The YourMT3 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Please see the details in the LICENSE file. | |
| """ event_codec.py: Encodes and decodes events to/from indices | |
| 🚀 Improvements: | |
| • Encoding uses a precomputed dictionary in Python. This achieves a time | |
| complexity of O(1). | |
| • Decoding has time complexity of O(1), while the original code from MT3 | |
| (Gardner et al.) has a time complexity of O(n). | |
| In practice, the performance of this optimized code was 4x faster for encoding | |
| and decoding compared to the original code. | |
| """ | |
| from typing import List, Dict, Tuple, Optional | |
| from utils.note_event_dataclasses import Event, EventRange | |
| # from bisect import bisect_right | |
| class FastCodec: | |
| """ Fast Encoding and decoding Event. """ | |
| def __init__(self, | |
| special_tokens: List[str], | |
| max_shift_steps: int, | |
| event_ranges: List[EventRange], | |
| program_vocabulary: Optional[Dict] = None, | |
| drum_vocabulary: Optional[Dict] = None, | |
| extra_tokens: List[str] = [], | |
| name: Optional[str] = None): | |
| """ Initializes the FastCodec object. | |
| :param special_tokens: List of special tokens to include in the vocabulary. | |
| :param max_shift_steps: The maximum number of steps to shift. | |
| :param event_ranges: List of EventRange objects. | |
| :param instr_vocabulary: A dictionary of instrument groups. Please see config/vocabulary.py | |
| We apply vocabulary only for encoding in training. | |
| :param drum_vocabulary: A dictionary of drum mapping. Please see config/vocabulary.py | |
| We apply vocabulary only for encoding in training. | |
| :param name: Name of the codec. | |
| """ | |
| # Store the special tokens and event ranges. | |
| self.special_tokens = special_tokens | |
| self._special_token_ranges = [] | |
| self._extra_token_ranges = [] | |
| for token in special_tokens: | |
| self._special_token_ranges.append(EventRange(token, 0, 0)) | |
| for token in extra_tokens: | |
| self._extra_token_ranges.append(EventRange(token, 0, 0)) | |
| self._shift_range = EventRange(type='shift', min_value=0, max_value=max_shift_steps - 1) | |
| self._event_ranges = self._special_token_ranges + [self._shift_range] + event_ranges + self._extra_token_ranges | |
| # Ensure all event types have unique names. | |
| assert len(self._event_ranges) == len(set([er.type for er in self._event_ranges])) | |
| # Store the name of the codec, so that we can identify it in tokenizer. | |
| self._name = name | |
| # Create dictionary for decoding | |
| self._decode_dict = {} | |
| self._encode_dict = {} | |
| self._event_type_range_dict = {} | |
| idx = 0 | |
| for er in self._event_ranges: | |
| start_idx = idx | |
| for value in range(er.min_value, er.max_value + 1): | |
| self._decode_dict[idx] = Event(type=er.type, value=value) | |
| self._encode_dict[(er.type, value)] = idx | |
| idx += 1 | |
| end_idx = idx - 1 | |
| self._event_type_range_dict[er.type] = (start_idx, end_idx) | |
| self._num_classes = idx | |
| # Create inverse vocabulary for instrument groups | |
| if program_vocabulary is not None: | |
| self.inverse_vocab_program = {} | |
| self._create_inverse_vocab_program(program_vocabulary) | |
| else: | |
| self.inverse_vocab_program = None | |
| # Create inverse vocabulary for drum mapping | |
| if drum_vocabulary is not None: | |
| self.inverse_vocab_drum = {} | |
| self._create_inverse_vocab_drum(drum_vocabulary) | |
| else: | |
| self.inverse_vocab_drum = None | |
| def num_classes(self) -> int: | |
| return self._num_classes | |
| def _create_inverse_vocab_program(self, vocab): | |
| for key, values in vocab.items(): | |
| for value in values: | |
| self.inverse_vocab_program[value] = values[0] | |
| def _create_inverse_vocab_drum(self, vocab): | |
| for key, values in vocab.items(): | |
| for value in values: | |
| self.inverse_vocab_drum[value] = values[0] | |
| def encode_event(self, event: Event) -> int: | |
| """Encode an event to an index.""" | |
| if (event.type, event.value) not in self._encode_dict: | |
| raise ValueError(f'Unknown event type: {event.type} or value: {event.value}') | |
| if event.type == 'program' and self.inverse_vocab_program is not None: | |
| # If the event value is not in the vocabulary, use the original value | |
| _event_value = self.inverse_vocab_program.get(event.value, event.value) | |
| return self._encode_dict[(event.type, _event_value)] | |
| elif event.type == 'drum' and self.inverse_vocab_drum is not None: | |
| _event_value = self.inverse_vocab_drum.get(event.value, event.value) | |
| return self._encode_dict[(event.type, _event_value)] | |
| else: | |
| return self._encode_dict[(event.type, event.value)] | |
| def event_type_range(self, event_type: str) -> Tuple[int, int]: | |
| """Return [min_id, max_id] for an event type.""" | |
| if event_type not in self._event_type_range_dict: | |
| raise ValueError(f'Unknown event type: {event_type}') | |
| return self._event_type_range_dict[event_type] | |
| def decode_event_index(self, index: int) -> Event: | |
| """Decode an event index to an Event.""" | |
| if index < 0 or index >= self.num_classes: | |
| raise ValueError(f'Unknown event index: {index}') | |
| decoded_event = self._decode_dict[index] | |
| # Create a new event with the same type and value | |
| return Event(type=decoded_event.type, value=decoded_event.value) | |
| # class FastCodec: | |
| # """ Fast Encoding and decoding Event. """ | |
| # def __init__(self, | |
| # special_tokens: List[str], | |
| # max_shift_steps: int, | |
| # event_ranges: List[EventRange], | |
| # name: Optional[str] = None): | |
| # """ Initializes the FastCodec object. | |
| # :param special_tokens: List of special tokens to include in the vocabulary. | |
| # :param max_shift_steps: The maximum number of steps to shift. | |
| # :param event_ranges: List of EventRange objects. | |
| # """ | |
| # # Store the special tokens and event ranges. | |
| # self.special_tokens = special_tokens | |
| # self._special_token_ranges = [] | |
| # for token in special_tokens: | |
| # self._special_token_ranges.append(EventRange(token, 0, 0)) | |
| # self._shift_range = EventRange( | |
| # type='shift', min_value=0, max_value=max_shift_steps - 1) | |
| # self._event_ranges = self._special_token_ranges + [self._shift_range | |
| # ] + event_ranges | |
| # # Ensure all event types have unique names. | |
| # assert len(self._event_ranges) == len( | |
| # set([er.type for er in self._event_ranges])) | |
| # # Precompute cumulative offsets. | |
| # self._cumulative_offsets = [0] | |
| # for er in self._event_ranges: | |
| # self._cumulative_offsets.append(self._cumulative_offsets[-1] + | |
| # er.max_value - er.min_value + 1) | |
| # # Create event type to range and offset mapping. | |
| # self._event_type_to_range_offset = {} | |
| # for er, offset in zip(self._event_ranges, self._cumulative_offsets): | |
| # self._event_type_to_range_offset[er.type] = (er, offset) | |
| # # Store the name of the codec, so that we can identify it in tokenizer. | |
| # self._name = name | |
| # @property | |
| # def num_classes(self) -> int: | |
| # return self._cumulative_offsets[-1] | |
| # def encode_event(self, event: Event) -> int: | |
| # """Encode an event to an index.""" | |
| # if event.type not in self._event_type_to_range_offset: | |
| # raise ValueError(f'Unknown event type: {event.type}') | |
| # er, offset = self._event_type_to_range_offset[event.type] | |
| # if not er.min_value <= event.value <= er.max_value: | |
| # raise ValueError( | |
| # f'Event value {event.value} is not within valid range ' | |
| # f'[{er.min_value}, {er.max_value}] for type {event.type}') | |
| # return offset + event.value - er.min_value | |
| # def event_type_range(self, event_type: str) -> Tuple[int, int]: | |
| # """Return [min_id, max_id] for an event type.""" | |
| # offset = 0 | |
| # for er in self._event_ranges: | |
| # if event_type == er.type: | |
| # return offset, offset + (er.max_value - er.min_value) | |
| # offset += er.max_value - er.min_value + 1 | |
| # raise ValueError(f'Unknown event type: {event_type}') | |
| # def decode_event_index(self, index: int) -> Event: | |
| # """Decode an event index to an Event.""" | |
| # if index < 0 or index >= self.num_classes: | |
| # raise ValueError(f'Unknown event index: {index}') | |
| # # Find the event range using binary search. | |
| # range_idx = bisect_right(self._cumulative_offsets, index) - 1 | |
| # er = self._event_ranges[range_idx] | |
| # offset = self._cumulative_offsets[range_idx] | |
| # return Event(type=er.type, value=er.min_value + index - offset) | |
| # Original code | |
| # | |
| # https://github.com/magenta/mt3/blob/main/mt3/event_codec.py | |
| # Copyright 2022 The MT3 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # class Codec: | |
| # """Encode and decode events.""" | |
| # | |
| # def __init__(self, special_tokens: List[str], max_shift_steps: int, | |
| # event_ranges: List[EventRange]): | |
| # """Define Codec. | |
| # """ | |
| # self._special_token_ranges = [] | |
| # for token in special_tokens: | |
| # self._special_token_ranges.append(EventRange(token, 0, 0)) | |
| # self._shift_range = EventRange( | |
| # type='shift', min_value=0, max_value=max_shift_steps - 1) | |
| # self._event_ranges = self._special_token_ranges + [self._shift_range | |
| # ] + event_ranges | |
| # # Ensure all event types have unique names. | |
| # assert len(self._event_ranges) == len( | |
| # set([er.type for er in self._event_ranges])) | |
| # @property | |
| # def num_classes(self) -> int: | |
| # return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) | |
| # def encode_event(self, event: Event) -> int: | |
| # """Encode an event to an index.""" | |
| # offset = 0 | |
| # for er in self._event_ranges: | |
| # if event.type == er.type: | |
| # if not er.min_value <= event.value <= er.max_value: | |
| # raise ValueError( | |
| # f'Event value {event.value} is not within valid range ' | |
| # f'[{er.min_value}, {er.max_value}] for type {event.type}' | |
| # ) | |
| # return offset + event.value - er.min_value | |
| # offset += er.max_value - er.min_value + 1 | |
| # raise ValueError(f'Unknown event type: {event.type}') | |
| # def event_type_range(self, event_type: str) -> Tuple[int, int]: | |
| # """Return [min_id, max_id] for an event type.""" | |
| # offset = 0 | |
| # for er in self._event_ranges: | |
| # if event_type == er.type: | |
| # return offset, offset + (er.max_value - er.min_value) | |
| # offset += er.max_value - er.min_value + 1 | |
| # raise ValueError(f'Unknown event type: {event_type}') | |
| # def decode_event_index(self, index: int) -> Event: | |
| # """Decode an event index to an Event.""" | |
| # offset = 0 | |
| # for er in self._event_ranges: | |
| # if offset <= index <= offset + er.max_value - er.min_value: | |
| # return Event(type=er.type, value=er.min_value + index - offset) | |
| # offset += er.max_value - er.min_value + 1 | |
| # raise ValueError(f'Unknown event index: {index}') | |