Spaces:
Running
Running
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| from dataclasses import dataclass | |
| from typing import Any, Dict, Sequence | |
| import torch | |
| from transformers import DataCollatorForSeq2Seq | |
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise (chosen/rejected) preference data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        Produces 2 * n examples: the first n are the chosen examples and the
        last n are the rejected examples, then delegates padding to the parent
        seq2seq collator.
        """
        paired_examples = []
        for side in ("chosen", "rejected"):
            for example in features:
                entry = {
                    "input_ids": example[f"{side}_input_ids"],
                    "attention_mask": example[f"{side}_attention_mask"],
                    "labels": example[f"{side}_labels"],
                }
                # Vision inputs are shared between both sides of the pair.
                if "pixel_values" in example:
                    entry["pixel_values"] = example["pixel_values"]
                token_type_key = f"{side}_token_type_ids"
                if token_type_key in example:
                    entry["token_type_ids"] = example[token_type_key]
                paired_examples.append(entry)
        return super().__call__(paired_examples)
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads target and KL (reference) sequences separately, then merges the
        KL tensors into the target batch under ``kl_``-prefixed keys together
        with the per-example KTO tags.
        """
        target_features, kl_features, kto_tags = [], [], []
        for example in features:
            target = {
                "input_ids": example["input_ids"],
                "attention_mask": example["attention_mask"],
                "labels": example["labels"],
            }
            kl = {
                "input_ids": example["kl_input_ids"],
                "attention_mask": example["kl_attention_mask"],
                "labels": example["kl_labels"],
            }
            if "pixel_values" in example:
                target["pixel_values"] = example["pixel_values"]
            if "token_type_ids" in example:
                target["token_type_ids"] = example["token_type_ids"]
                kl["token_type_ids"] = example["kl_token_type_ids"]
            target_features.append(target)
            kl_features.append(kl)
            kto_tags.append(example["kto_tags"])

        # Pad each group independently so target and KL lengths do not
        # constrain each other, then fold the KL tensors into one batch.
        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        for key in ("input_ids", "attention_mask", "labels"):
            batch[f"kl_{key}"] = kl_batch[key]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch