aapot committed
Commit 947b4f4 · Parent: 5f6eb94

Update optimizers

Files changed:
- EasyLM/data.py (+7 -0)
- EasyLM/optimizers.py (+47 -3)
- pretrain_llama_7b.sh (+2 -1)

EasyLM/data.py CHANGED

@@ -153,6 +153,7 @@ class HuggingfaceDataset(object):
         config.start_seek_loc = 0
         config.tokens_count_at_start = 0
         config.batch_token_dtype = 'i4'
+        config.reset_dataset_loc = False
 
         if updates is not None:
             config.update(ConfigDict(updates).copy_and_resolve_references())
@@ -173,6 +174,8 @@ class HuggingfaceDataset(object):
         self._dataset_loc = self.config.start_seek_loc
         self._total_tokens = self.config.tokens_count_at_start
         self._index = 0
+        self.reset_dataset_loc = self.config.reset_dataset_loc
+
 
     def __iter__(self):
         if not self._eval_dataset and self._train_epochs > 0:
@@ -236,6 +239,10 @@ class HuggingfaceDataset(object):
         self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
         self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
         self._train_epochs = state_dict.get('epochs', 0)
+        if self.reset_dataset_loc:
+            self._dataset_loc = 0
+            self._train_epochs = 0
+
 
     @property
     def seq_length(self):
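
The effect of the new reset_dataset_loc option can be illustrated with a small self-contained sketch (a hypothetical demo class that only mirrors the load_state_dict logic above; it is not part of EasyLM): when the flag is set, a dataset position and epoch count restored from a training-state checkpoint are discarded, so iteration restarts from the beginning of the dataset.

# Hypothetical demo class mirroring the new HuggingfaceDataset.load_state_dict logic.
class DatasetStateDemo:
    def __init__(self, reset_dataset_loc=False, start_seek_loc=0, tokens_count_at_start=0):
        self.reset_dataset_loc = reset_dataset_loc
        self.start_seek_loc = start_seek_loc
        self.tokens_count_at_start = tokens_count_at_start
        self._dataset_loc = start_seek_loc
        self._total_tokens = tokens_count_at_start
        self._train_epochs = 0

    def load_state_dict(self, state_dict):
        self._dataset_loc = state_dict.get('dataset_loc', self.start_seek_loc)
        self._total_tokens = state_dict.get('total_tokens', self.tokens_count_at_start)
        self._train_epochs = state_dict.get('epochs', 0)
        if self.reset_dataset_loc:
            # Added in this commit: ignore the saved position and start over.
            self._dataset_loc = 0
            self._train_epochs = 0


demo = DatasetStateDemo(reset_dataset_loc=True)
demo.load_state_dict({'dataset_loc': 123456, 'total_tokens': 10_000_000, 'epochs': 2})
print(demo._dataset_loc, demo._train_epochs)  # prints: 0 0
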
    	
EasyLM/optimizers.py CHANGED

@@ -205,8 +205,9 @@ class LionOptimizerFactory(object):
         config.init_lr = 0.0
         config.end_lr = 0.0001
         config.lr = 0.001
-        config.lr_warmup_steps = 
-        config.
+        config.lr_warmup_steps = 60000
+        config.lr_constant_steps = 840000
+        config.lr_decay_steps = 100000
         config.b1 = 0.9
         config.b2 = 0.98
         config.clip_gradient = 1.0
@@ -243,6 +244,43 @@ class LionOptimizerFactory(object):
                 ],
                 [config.lr_warmup_steps],
             )
+        elif config.lr_schedule_type == "warmup_constant_linear_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.linear_schedule(
+                        init_value=config.lr,
+                        end_value=config.end_lr,
+                        transition_steps=config.lr_decay_steps,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
+        elif config.lr_schedule_type == "warmup_constant_exponential_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.exponential_decay(
+                        init_value=config.lr,
+                        transition_steps=config.lr_decay_steps,
+                        decay_rate=config.lr_decay_rate,
+                        transition_begin=0,
+                        staircase=False,
+                        end_value=config.end_lr,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
         elif config.lr_schedule_type == "exponential_decay":
             learning_rate_schedule = optax.exponential_decay(
                         init_value=config.lr,
@@ -252,8 +290,14 @@ class LionOptimizerFactory(object):
                         staircase=False,
                         end_value=config.end_lr,
             )
+        elif config.lr_schedule_type == "linear_decay":
+            learning_rate_schedule = optax.linear_schedule(
+                        init_value=config.lr,
+                        end_value=config.end_lr,
+                        transition_steps=config.lr_decay_steps,
+            )
         else:
-            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", 
+            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')
 
         optimizer_info = dict(
             learning_rate_schedule=learning_rate_schedule,
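
A short stand-alone sketch (assuming optax is available; the step counts mirror the new config defaults) of how the added "warmup_constant_linear_decay" schedule behaves. Note that optax.join_schedules interprets the boundary list as absolute step counts at which the next schedule takes over, so lr_constant_steps marks where the decay phase begins rather than the length of the constant phase.

import optax

init_lr, lr, end_lr = 0.0, 0.001, 0.0001
lr_warmup_steps = 60000       # linear warmup ends here
lr_constant_steps = 840000    # constant phase ends here (absolute step count)
lr_decay_steps = 100000       # length of the final linear decay

schedule = optax.join_schedules(
    [
        optax.linear_schedule(init_value=init_lr, end_value=lr,
                              transition_steps=lr_warmup_steps),
        optax.constant_schedule(lr),
        optax.linear_schedule(init_value=lr, end_value=end_lr,
                              transition_steps=lr_decay_steps),
    ],
    [lr_warmup_steps, lr_constant_steps],
)

for step in (0, 30000, 60000, 500000, 840000, 890000, 940000, 1000000):
    print(step, float(schedule(step)))
# Expected shape: 0 -> lr over the first 60k steps, flat at lr until step 840k,
# linear decay to end_lr over the next 100k steps, then flat at end_lr.
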
    	
pretrain_llama_7b.sh CHANGED

@@ -23,10 +23,11 @@ python3 -m EasyLM.models.llama.llama_train \
     --tokenizer.vocab_file='tokenizer.model' \
     --optimizer.type='lion' \
     --optimizer.lion_optimizer.weight_decay=1.0 \
-    --optimizer.lion_optimizer.lr_schedule_type='
+    --optimizer.lion_optimizer.lr_schedule_type='warmup_constant_linear_decay' \
     --optimizer.lion_optimizer.lr=1e-4 \
     --optimizer.lion_optimizer.end_lr=1e-5 \
     --optimizer.lion_optimizer.lr_warmup_steps=60000 \
+    --optimizer.lion_optimizer.lr_constant_steps=900000 \
     --optimizer.lion_optimizer.lr_decay_steps=100000 \
     --optimizer.lion_optimizer.bf16_momentum=True \
     --train_dataset.type='huggingface' \
