# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import tempfile
import unittest

import numpy as np
import torch
from datasets import Dataset, Image, Sequence, load_dataset
from parameterized import parameterized
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    LlavaForConditionalGeneration,
    is_vision_available,
)
from transformers.testing_utils import require_flash_attn, require_peft, require_vision
from transformers.utils import is_peft_available

from trl import SFTConfig, SFTTrainer
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


if is_peft_available():
    from peft import LoraConfig, PeftModel, get_peft_model

if is_vision_available():
    from PIL import Image as PILImage


def formatting_prompts_func(example):
    text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
    return text


def formatting_func_for_pretokenized(example):
    return example["input_ids"]
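

# A quick illustration (not executed anywhere) of what `formatting_prompts_func`
# returns for one row shaped like the dummy dataset used below; the sample
# values are taken from that dataset:
#   {"question": "Does llamas know how to code?", "answer": "Yes, llamas are very good at coding."}
#   -> "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding."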


class TestDataCollatorForLanguageModeling(unittest.TestCase):
    def test_basic_padding(self):
        """Test basic padding functionality without completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_completion_mask(self):
        """Test completion mask functionality."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [4, 5], "completion_mask": [0, 1]},
        ]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

    def test_completion_only_loss_disabled(self):
        """Test behavior when completion_only_loss is disabled."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, completion_only_loss=False)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [4, 5], "completion_mask": [0, 1]},
        ]
        result = collator(examples)
        # Labels should not be masked when completion_only_loss=False
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_padding_free_mode(self):
        """Test padding-free mode where sequences are concatenated."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))

    def test_padding_free_with_completion_mask(self):
        """Test padding-free mode with completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [4, 5], "completion_mask": [1, 1]},
        ]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))

    def test_packing_drops_attention_mask_for_flash_attention(self):
        """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
        # Simulate packed sequences with position_ids that restart (typical of BFD packing)
        examples = [
            {
                "input_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # Packed: [1,2,3] + [4,5] + [6,7,8]
                "seq_lengths": [3, 2, 3],
            }
        ]
        result = collator(examples)
        # Verify that attention_mask is NOT present - this allows FlashAttention to use position_ids
        self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
        # Verify essential keys are present
        self.assertIn("input_ids", result)
        self.assertIn("position_ids", result)
        self.assertIn("labels", result)
        # Verify the data is correctly processed
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))

    def test_padding_free_without_position_ids_keeps_attention_mask(self):
        """Test that padding_free mode without explicit position_ids still creates attention_mask."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
        # Examples without position_ids (not packed)
        examples = [{"input_ids": [1, 2, 3, 4, 5]}]
        result = collator(examples)
        # Should still have attention_mask since no packed position_ids
        self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
        self.assertIn("position_ids", result)
        self.assertIn("input_ids", result)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))

    def test_pad_to_multiple_of(self):
        """Test padding to multiple of specified value."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

    def test_custom_position_ids(self):
        """Test handling of custom position IDs in examples."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_single_example(self):
        """Test collator with a single example."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [{"input_ids": [1, 2, 3, 4]}]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))

    def test_different_pad_token_id(self):
        """Test with different pad token ID."""
        collator = DataCollatorForLanguageModeling(pad_token_id=999)
        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_assistant_masks(self):
        """Test handling of assistant masks in examples."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [
            {"input_ids": [1, 2, 3], "assistant_masks": [0, 1, 1]},
            {"input_ids": [4, 5], "assistant_masks": [0, 1]},
        ]
        result = collator(examples)
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))


class SFTTrainerTester(unittest.TestCase):
    def setUp(self):
        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.dummy_dataset = Dataset.from_dict(
            {
                "question": [
                    "Does llamas know how to code?",
                    "Does llamas know how to fly?",
                    "Does llamas know how to talk?",
                    "Does llamas know how to code?",
                    "Does llamas know how to fly?",
                    "Does llamas know how to talk?",
                    "Does llamas know how to swim?",
                ],
                "answer": [
                    "Yes, llamas are very good at coding.",
                    "No, llamas can't fly.",
                    "Yes, llamas are very good at talking.",
                    "Yes, llamas are very good at coding.",
                    "No, llamas can't fly.",
                    "Yes, llamas are very good at talking.",
                    "No, llamas can't swim.",
                ],
                "text": [
                    "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.",
                    "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.",
                    "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.",
                    "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.",
                    "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.",
                    "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.",
                    "### Question: Does llamas know how to swim?\n ### Answer: No, llamas can't swim.",
                ],
            }
        )
        self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling")
        self.standard_prompt_completion_dataset = load_dataset(
            "trl-internal-testing/zen", "standard_prompt_completion"
        )

        if is_vision_available():
            self.dummy_vsft_instruction_dataset = Dataset.from_dict(
                {
                    "messages": [
                        [
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
                            },
                            {
                                "role": "assistant",
                                "content": [{"type": "text", "text": "It is random noise."}],
                            },
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
                            },
                            {
                                "role": "assistant",
                                "content": [{"type": "text", "text": "2"}],
                            },
                        ],
                        [
                            {
                                "role": "user",
                                "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
                            },
                            {
                                "role": "assistant",
                                "content": [{"type": "text", "text": "It is random noise."}],
                            },
                        ],
                    ],
                    "images": [
                        [PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")],
                        [PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")],
                    ],
                }
            )
            # `cast_column` returns a new dataset, so the result must be reassigned
            self.dummy_vsft_instruction_dataset = self.dummy_vsft_instruction_dataset.cast_column(
                "images", Sequence(Image())
            )

    def test_incorrect_data(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Should work as SFTTrainer natively supports conversational lm datasets
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=32,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
            )

            # Same, but without packing
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                packing=False,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
            )

            # Same, but with a prompt-completion dataset and packing with `max_length`
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.standard_prompt_completion_dataset["train"],
            )

            # Same prompt-completion dataset, but without packing
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                packing=False,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.standard_prompt_completion_dataset["train"],
            )

            # Should work as the dummy dataset is supported with a formatting function
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=32,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            _ = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_dataset,
                formatting_func=formatting_prompts_func,
            )

    def test_with_model(self):
        # packed
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            trainer.train()
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # with formatting_func + packed
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dummy_dataset,
                formatting_func=formatting_prompts_func,
            )
            trainer.train()
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # without packing
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=16,
                report_to="none",
            )
            trainer = SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            trainer.train()
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

    def test_only_train_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                gradient_checkpointing=True,
                packing=True,
                max_length=128,  # make sure there is at least 1 packed sequence
                eval_packing=False,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
                eval_dataset=self.conversational_lm_dataset["test"],
            )
            # with this dataset and max_length=128, the train split packs into 7 sequences
            self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)
            # eval_packing=False, so the eval split is left unpacked
            self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))

    def test_eval_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=128,  # make sure there is at least 1 packed sequence
                packing=True,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
                eval_dataset=self.conversational_lm_dataset["test"],
            )
            # with this dataset and max_length=128, the train split packs into 7 sequences
            self.assertEqual(len(trainer.train_dataset["input_ids"]), 7)
            # the smaller eval split packs into a single sequence
            self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1)

    def test_no_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_length=128,
                packing=False,
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.conversational_lm_dataset["train"],
                eval_dataset=self.conversational_lm_dataset["test"],
            )
            self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"]))
            self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))

    @require_vision
    def test_skip_prepare_dataset(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                remove_unused_columns=False,
                dataset_kwargs={"skip_prepare_dataset": True},
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_vsft_instruction_dataset,
            )
            self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features)

    def test_skip_prepare_dataset_with_no_packing(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                remove_unused_columns=False,
                packing=False,
                dataset_kwargs={"skip_prepare_dataset": True},
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=self.dummy_dataset,
            )
            self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features)

    @require_vision
    def test_llava(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                remove_unused_columns=False,
                dataset_kwargs={"skip_prepare_dataset": True},
                report_to="none",
            )
            tiny_llava = LlavaForConditionalGeneration.from_pretrained(
                "trl-internal-testing/tiny-LlavaForConditionalGeneration"
            )
            processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration")

            processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious
user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's
questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for
item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image'
%}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if
add_generation_prompt %}ASSISTANT: {% endif %}"""

            def collate_fn(examples):
                # Get the texts and images, and apply the chat template
                texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
                images = [example["images"][0] for example in examples]

                # Tokenize the texts and process the images
                batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

                # The labels are the input_ids, and we mask the padding tokens in the loss computation
                labels = batch["input_ids"].clone()
                labels[labels == processor.tokenizer.pad_token_id] = -100
                batch["labels"] = labels
                return batch

            trainer = SFTTrainer(
                model=tiny_llava,
                args=training_args,
                data_collator=collate_fn,
                train_dataset=self.dummy_vsft_instruction_dataset,
            )
            trainer.train()
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])


# This new tester aims to replace the first one at some point
class SFTTrainerTester2(unittest.TestCase):
    # The model list below is an assumption: `test_train` takes `model_id` as a
    # parameter but the original `parameterized.expand` list was lost, so we
    # parameterize over the tiny model used throughout this file.
    @parameterized.expand([("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",)])
    def test_train(self, model_id):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset)
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    # Special case for harmony
    def test_train_gpt_oss(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/harmony", "language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-GptOssForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_model(self):
        # Instantiate the model
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_model_torch_dtype(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir, model_init_kwargs={"torch_dtype": torch.float16}, report_to="none"
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # Check the torch dtype
                self.assertEqual(new_param.dtype, torch.float16)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_dense_with_peft_config(self):
        # Get the base model parameter names
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model=model_id,
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(),
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif (
                    "base_layer" not in n
                ):  # We expect the peft parameters to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_moe_with_peft_config(self):
        # Get the base model parameter names
        model_id = "trl-internal-testing/tiny-GptOssForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model=model_id,
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]),
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif (
                    "base_layer" not in n
                ):  # We expect the peft parameters to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_peft_model(self):
        # Get the base model
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        # Get the base model parameter names
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        # Turn the model into a peft model
        lora_config = LoraConfig()
        model = get_peft_model(model, lora_config)
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif (
                    "base_layer" not in n
                ):  # We expect the peft parameters to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
        # Get the base model parameter names
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, gradient_checkpointing=True, report_to="none")
            trainer = SFTTrainer(
                model=model_id,
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(),
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif (
                    "base_layer" not in n
                ):  # We expect the peft parameters to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
        # Get the base model parameter names
        model_id = "trl-internal-testing/tiny-GptOssForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, gradient_checkpointing=True, report_to="none")
            trainer = SFTTrainer(
                model=model_id,
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(target_parameters=["mlp.experts.down_proj", "mlp.experts.gate_up_proj"]),
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif (
                    "base_layer" not in n
                ):  # We expect the peft parameters to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_peft
    def test_train_with_peft_model_and_gradient_checkpointing(self):
        # Get the base model parameter names
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        model = get_peft_model(model, LoraConfig())
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, gradient_checkpointing=True, report_to="none")
            trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
            # Verify model is a PeftModel
            self.assertIsInstance(trainer.model, PeftModel)
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model parameters to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
                elif (
                    "base_layer" not in n
                ):  # We expect the peft parameters to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_non_chatml_conversational_data(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

        # Rename role/content to from/value to ensure SFT works with non-chatML conversational data
        def rename_fields(example: dict):
            return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}

        dataset = dataset.map(rename_fields, remove_columns="messages")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_pretokenized_data(self):
        # Get the dataset
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

        def tokenize_example(example):
            return tokenizer(example["text"])

        # Apply tokenization
        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_iterable_dataset(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, max_steps=3, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @require_flash_attn
    def test_train_padding_free(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir,
                padding_free=True,
                model_init_kwargs={"attn_implementation": "flash_attention_2"},
                bf16=True,  # flash_attention_2 only supports bf16 and fp16
                report_to="none",
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    # The strategy list is an assumption: the test takes `packing_strategy` as
    # a parameter but the original decorator was lost; "bfd" (mentioned in the
    # collator tests above) and "wrapped" are assumed valid values.
    @parameterized.expand([("bfd",), ("wrapped",)])
    def test_train_packing(self, packing_strategy):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none"
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_chat_template_kwargs(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
            # The following template is a simplified version of the Qwen chat template, where an additional argument
            # `role_capital` is used to control the capitalization of roles.
            tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n<tool_response>\\n" + message.content + "\\n</tool_response>" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}'
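            # Hedged sketch of the mechanism exercised here (the exact plumbing
            # is the trainer's concern and is not asserted by this test):
            # per-example entries in the `chat_template_kwargs` column are
            # forwarded to the template when it is rendered, roughly as
            #   tokenizer.apply_chat_template(messages, role_capital=True)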
            # `add_column` returns a new dataset, so the result must be reassigned
            dataset = dataset.add_column(
                "chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))]
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=dataset,
                processing_class=tokenizer,  # pass the modified tokenizer so the custom template is actually used
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_assistant_only(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, assistant_only_loss=True, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_completion_only(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, completion_only_loss=True, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_completion_only_harmony(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/harmony", "prompt_completion", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, completion_only_loss=True, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-GptOssForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_assistant_only_and_completion_only(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")

        # To test this case, we extend the completion with a user message (which must be masked in the loss)
        # followed by another assistant message
        def add_to_completion(example):
            example["completion"].append(example["prompt"][0])
            example["completion"].append(example["completion"][0])
            return example

        dataset = dataset.map(add_to_completion)
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir, assistant_only_loss=True, completion_only_loss=True, report_to="none"
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_assistant_only_iterable_dataset(self):
        # Get the dataset
        dataset = load_dataset(
            "trl-internal-testing/zen", "conversational_language_modeling", split="train", streaming=True
        )
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, assistant_only_loss=True, max_steps=3, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_set_chat_template_from_model(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
            # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_set_chat_template_from_path(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir,
                chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
                report_to="none",
            )
            # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

            # Check that the template saved in the output directory is the same as the one used for training
            template_path = pathlib.Path(tmp_dir) / "checkpoint-9" / "chat_template.jinja"
            self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}")

            with open(template_path) as f:
                template_content = f.read()
            with open(training_args.chat_template_path) as f:
                original_template_content = f.read()
            self.assertEqual(
                template_content, original_template_content, "Chat template content does not match the original"
            )

    def test_train_toolcall_data(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/toolcall", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_train_with_eval(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
            )
            # Train the model
            trainer.train()
            # Check that the eval loss is not None
            self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])

    def test_train_with_multiple_eval_dataset(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset={"data1": dataset["test"], "data2": dataset["test"]},
            )
            # Train the model
            trainer.train()
            # Check that the eval losses are not None
            self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"])
            self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"])

    def test_train_with_gradient_checkpointing(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(output_dir=tmp_dir, gradient_checkpointing=True, report_to="none")
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    def test_tag_added(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        # Initialize the trainer
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            train_dataset=dataset,
        )
        for tag in ["sft", "trl"]:
            self.assertIn(tag, trainer.model.model_tags)

    @require_peft
    def test_tag_added_peft(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        # Initialize the trainer
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            train_dataset=dataset,
            peft_config=LoraConfig(),
        )
        for tag in ["sft", "trl"]:
            self.assertIn(tag, trainer.model.model_tags)

    def test_train_with_torch_dtype(self):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Initialize the trainer
            training_args = SFTConfig(
                output_dir=tmp_dir, model_init_kwargs={"torch_dtype": torch.float16}, report_to="none"
            )
            trainer = SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
            # Save the initial parameters to compare them later
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            # Train the model
            trainer.train()
            # Check that the training loss is not None
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")