Upload 2 files
- configuration_olmo.py +44 -0
- modeling_olmo.py +145 -0

configuration_olmo.py (ADDED)
@@ -0,0 +1,44 @@
"""
OLMo configuration
"""

from transformers import AutoConfig, PretrainedConfig
from transformers.utils import logging

from olmo.config import ModelConfig

logger = logging.get_logger(__name__)


class OLMoConfig(PretrainedConfig):
    model_type = "olmo"
    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm

    def __init__(self, use_cache: bool = False, **kwargs):
        model_config = ModelConfig()
        all_kwargs = model_config.asdict()
        all_kwargs.update(kwargs)
        all_kwargs.update({"use_cache": use_cache})
        all_kwargs.update(
            {
                "architectures": all_kwargs.get("architectures", ["OlmoModelForCausalLM"])
                or ["OlmoModelForCausalLM"]
            }
        )
        super().__init__(**all_kwargs)

    @property
    def num_attention_heads(self):
        return self.n_heads

    @property
    def num_hidden_layers(self):
        return self.n_layers

    @property
    def hidden_size(self):
        return self.d_model


# Register the config class so that it is available for transformer pipelines, auto-loading etc.
AutoConfig.register("olmo", OLMoConfig)
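For reference, a minimal usage sketch (not part of the uploaded files). It assumes the two files are importable as a package — the package name hf_olmo below is an assumption — and the override values are hypothetical; the field names d_model, n_heads, and n_layers come from olmo.config.ModelConfig, whose defaults fill in everything not overridden:

from hf_olmo.configuration_olmo import OLMoConfig  # assumed package layout

# Defaults come from olmo.config.ModelConfig; keyword args override them.
# The values below are hypothetical.
config = OLMoConfig(d_model=2048, n_heads=16, n_layers=16)

# The @property methods alias OLMo's field names to the standard HF accessors.
assert config.hidden_size == config.d_model
assert config.num_attention_heads == config.n_heads
assert config.num_hidden_layers == config.n_layers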
    	
modeling_olmo.py (ADDED)
@@ -0,0 +1,145 @@
from typing import List, Optional, Tuple, Union

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.auto import AutoModelForCausalLM

from olmo.config import ModelConfig
from olmo.model import Olmo

from .configuration_olmo import OLMoConfig


def create_model_config_from_pretrained_config(config: OLMoConfig):
    """
    Utility function: rebuild an OLMo ``ModelConfig`` from the fields of a HF ``OLMoConfig``.
    """

    kwargs = {}
    for key in ModelConfig.__match_args__:
        kwargs[key] = getattr(config, key)

    model_config = ModelConfig(**kwargs)
    return model_config


class OLMoForCausalLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.
    """

    config_class = OLMoConfig
    base_model_prefix = "model"
    _no_split_modules = ["OLMoBlock"]

    def __init__(self, config: OLMoConfig, model: Optional[Olmo] = None, init_params: bool = False):
        super().__init__(config)

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            model_config.init_device = "cpu"
            self.model = Olmo(model_config, init_params=init_params)
        else:
            self.model = model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if use_cache is None:
            use_cache = self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn).
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.embedding_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.attn_key_values,
        )

    def can_generate(self) -> bool:
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values:
            # With a cache, the model only needs to process the last generated token.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

        model_inputs.update(kwargs)
        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
        return model_inputs

    # TODO: these are required to make the implementation complete.
    # def resize_position_embeddings(self, new_num_position_embeddings: int):
    #     pass
    #
    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    #     pass
    #
    # def _reorder_cache(self, past_key_values, beam_idx):
    #     pass

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def tie_weights(self):
        if self.config.weight_tying:
            self.model.transformer.ff_out = self.model.transformer.wte


# Register the model so that it is available for transformer pipelines, auto-loading, etc.
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
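Likewise, a minimal end-to-end sketch for the wrapper (not part of the uploaded files). It assumes the olmo package is installed, that the assumed hf_olmo package contains these two modules, and that importing them has executed both register() calls; the token IDs below are arbitrary and the weights are randomly initialized rather than loaded from a checkpoint:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from hf_olmo import modeling_olmo  # noqa: F401 — assumed package name; importing runs both register() calls

# Build a random-weight model straight from a default config.
config = AutoConfig.for_model("olmo", use_cache=True)
model = AutoModelForCausalLM.from_config(config)

# Greedy decoding exercises can_generate() and prepare_inputs_for_generation().
input_ids = torch.tensor([[1, 2, 3]])  # arbitrary token IDs
output_ids = model.generate(input_ids, max_new_tokens=8, do_sample=False)
print(output_ids.shape)  # (1, 3 + 8)

A directory saved with model.save_pretrained(...) could instead be passed to AutoModelForCausalLM.from_pretrained(...), which resolves to OLMoForCausalLM through the registration above.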