Upload ContextualDocumentEmbeddingTransformer
- README.md +199 -0
- config.json +28 -0
- misc.py +518 -0
- model.py +622 -0
- model.safetensors +3 -0
README.md
ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
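The "How to Get Started" section above is left as a placeholder by the auto-generated card. A minimal loading sketch, based only on the `auto_map` entries in the config.json below; the repo id is a placeholder, and the tokenizer choice follows the configured `embedder` ("nomic-ai/nomic-bert-2048"), since no tokenizer files are part of this commit:

```python
import transformers

# trust_remote_code is required: auto_map points AutoConfig/AutoModel at misc.py and model.py.
model = transformers.AutoModel.from_pretrained(
    "path/to/this-repo",  # placeholder repo id
    trust_remote_code=True,
)

# Assumption: tokenize with the first-stage embedder's tokenizer named in config.json.
tokenizer = transformers.AutoTokenizer.from_pretrained("nomic-ai/nomic-bert-2048")
inputs = tokenizer(
    ["Contextual document embeddings condition each document on its corpus."],
    padding=True, truncation=True, max_length=512, return_tensors="pt",
)
# The two-stage forward pass (first-stage corpus embeddings, then the contextual
# second stage) is defined further down in model.py and is not shown here.
```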
config.json
ADDED
@@ -0,0 +1,28 @@
{
  "_name_or_path": "/fsx-checkpoints/jxm/cde/2024-08-06-transductive-pretrain-transductive-long-10node-3/checkpoint-7176",
  "architecture": "transductive",
  "architectures": [
    "ContextualDocumentEmbeddingTransformer"
  ],
  "attn_implementation": null,
  "auto_map": {
    "AutoConfig": "misc.ContextualModelConfig",
    "AutoModel": "model.ContextualDocumentEmbeddingTransformer"
  },
  "cache_dir": null,
  "config_name": null,
  "disable_dropout": true,
  "disable_transductive_rotary_embedding": true,
  "embedder": "nomic-ai/nomic-bert-2048",
  "embedder_rerank": "sentence-transformers/gtr-t5-base",
  "embedding_output_dim": null,
  "limit_layers": null,
  "logit_scale": 50.0,
  "max_seq_length": 512,
  "model_revision": "main",
  "tokenizer_name": null,
  "torch_dtype": "float32",
  "transductive_corpus_size": 512,
  "transductive_sequence_dropout_prob": 0.0,
  "transformers_version": "4.48.0.dev0"
}
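A note on the fields above: `ContextualModelConfig` (defined in misc.py below) stores any JSON-serializable kwarg, so each key becomes a plain attribute on the loaded config. A small sketch, assuming the config is fetched through the `auto_map` (the repo id is again a placeholder):

```python
import transformers

# AutoConfig resolves to misc.ContextualModelConfig via the auto_map above.
config = transformers.AutoConfig.from_pretrained("path/to/this-repo", trust_remote_code=True)

print(config.embedder)                  # nomic-ai/nomic-bert-2048 (first-stage backbone)
print(config.max_seq_length)            # 512
print(config.transductive_corpus_size)  # 512

# In model.py, num_corpus_tokens = transductive_corpus_size * transductive_tokens_per_document;
# transductive_tokens_per_document defaults to 1, so the second stage is conditioned
# on 512 contextual (corpus) embeddings.
num_corpus_tokens = config.transductive_corpus_size * getattr(config, "transductive_tokens_per_document", 1)
print(num_corpus_tokens)  # 512
```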
misc.py
ADDED
@@ -0,0 +1,518 @@
from typing import Dict, Iterable, List, Optional, Tuple, Union

import collections
import glob
import json
import hashlib
import itertools
import logging
import multiprocessing
import os
import pickle
import random
import requests
import sys
import zipfile

import datasets
import numpy as np
import torch
import tqdm
import transformers

from cde.lib.dist import get_num_proc, get_rank


def get_cde_cache_dir() -> str:
    script_directory = os.path.normpath(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            os.pardir, os.pardir,
        )
    )
    return os.path.join(script_directory, "data")


def get_cache_location_from_kwargs(**kwargs):
    cache_location = os.path.join(
        get_cde_cache_dir(), "cluster"
    )
    os.makedirs(cache_location, exist_ok=True)
    return os.path.join(cache_location, md5_hash_kwargs(**kwargs))


def process_qrels_uncached(corpus: datasets.Dataset, qrels: datasets.Dataset) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    qrels_idxs = collections.defaultdict(list)
    qrels_scores = collections.defaultdict(list)
    corpus_ids = np.array(corpus['_id'])
    skipped_qrels = 0

    for ex in tqdm.tqdm(qrels, desc='processing qrels', colour='#964B00', leave=False):
        #
        # example:
        # {
        #     'query-id': 1,
        #     'corpus-id': 'b0680508-2019-04-18T13:48:51Z-00002-000',
        #     'score': 2
        # }
        #
        q_id = str(ex['query-id'])
        c_idxs = (corpus_ids == str(ex['corpus-id'])).nonzero()[0]
        #
        assert len(c_idxs) <= 1, f"error - duplicate corpus ID? (found {len(c_idxs)} matches)"
        #
        if len(c_idxs):
            qrels_idxs[q_id].append(c_idxs[0])
            qrels_scores[q_id].append(ex['score'])
        else:
            skipped_qrels += 1
        #

    if skipped_qrels > 0:
        logging.warning(f'Warning: Skipped {skipped_qrels}/{len(qrels)} qrels.')

    return qrels_idxs, qrels_scores


def process_qrels(
    corpus: datasets.Dataset, qrels: datasets.Dataset,
    use_cache: bool = True
) -> Tuple[Dict[str, List[float]], Dict[str, List[str]]]:
    dataset_cache_file = '_'.join(
        (corpus.cache_files[0]['filename'], qrels.cache_files[0]['filename'])
    )
    cache_file = strip_extension(dataset_cache_file) + '_processed_qrels.p'
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)

    if not (use_cache and os.path.exists(cache_file)):
        qrels_idxs, qrels_scores = process_qrels_uncached(
            corpus=corpus, qrels=qrels
        )
        if use_cache:
            pickle.dump((qrels_idxs, qrels_scores), open(cache_file, 'wb'))
    else:
        qrels_idxs, qrels_scores = pickle.load(open(cache_file, 'rb'))

    return qrels_idxs, qrels_scores


def strip_extension(filename: str) -> str:
    """Strips file extension.

    Ex:
        >> strip_extension('/root/dir/sub/file.ext')
        '/root/dir/sub/file'
    """
    return os.path.splitext(filename)[0]


def md5_hash(t: Tuple[str]) -> str:
    return hashlib.md5('__'.join(t).encode()).hexdigest()


def md5_hash_kwargs(**kwargs) -> str:
    # We ignore special hf args that start with _ like '__cached__setup_devices'.
    safe_kwargs = {k: str(v) for k, v in kwargs.items() if not k.startswith('_')}
    s = json.dumps(safe_kwargs, sort_keys=True)
    return hashlib.md5(s.encode()).hexdigest()

def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.
    """
    r = requests.get(url, stream=True)
    total = int(r.headers.get('Content-Length', 0))
    with open(save_path, 'wb') as fd, tqdm.tqdm(
        desc=save_path,
        total=total,
        unit='iB',
        unit_scale=True,
        unit_divisor=chunk_size,
    ) as bar:
        for data in r.iter_content(chunk_size=chunk_size):
            size = fd.write(data)
            bar.update(size)


def unzip(zip_file: str, out_dir: str):
    print("unzipping =>", zip_file)
    zip_ = zipfile.ZipFile(zip_file, "r")
    zip_.extractall(path=out_dir)
    zip_.close()


def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    os.makedirs(out_dir, exist_ok=True)
    dataset = url.split("/")[-1]
    zip_file = os.path.join(out_dir, dataset)

    if not os.path.isfile(zip_file):
        logging.info("Downloading {} ...".format(dataset))
        download_url(url, zip_file, chunk_size)

    if not os.path.isdir(zip_file.replace(".zip", "")):
        logging.info("Unzipping {} ...".format(dataset))
        unzip(zip_file, out_dir)

    return os.path.join(out_dir, dataset.replace(".zip", ""))


def tqdm_if_main_worker(iterable: Iterable, **kwargs) -> Iterable:
    if get_rank() == 0:
        return tqdm.tqdm(iterable, **kwargs)
    else:
        return iterable


class ContextualModelConfig(transformers.configuration_utils.PretrainedConfig):
    """We create a dummy configuration class that will just set properties
    based on whatever kwargs we pass in.

    When this class is initialized (see experiments.py) we pass in the
    union of all data, model, and training args, all of which should
    get saved to the config json.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            try:
                json.dumps(value)
                setattr(self, key, value)
            except TypeError:
                # value was not JSON-serializable, skip
                continue
        super().__init__()


def independent_crop(
    input_ids: torch.Tensor, pad_token_id: int,
    l1: int = 256, l2: int = 256) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns two independent crops from input_ids.

    Assumes input_ids has a beginning and end token, like
        [101, ..., 102, 0, 0, 0].

    Args:
        input_ids: tensor of IDs
        pad_token_id: ID of pad tokens in input_ids
        l1: length of span 1, cropped
        l2: length of span 2, cropped
    Returns:
        span1: first crop (of length l1)
        span2: second crop (of length l2)
    """
    # Count tokens until pad.
    if (input_ids == pad_token_id).sum() == 0:
        N = len(input_ids)
    else:
        N = (input_ids == pad_token_id).int().argmax().item()

    ####
    ###
    ##
    ## Contriever: We use the random cropping data
    ## augmentation, with documents of 256 tokens and span
    ## sizes sampled between 5% and 50% of the document
    ## length
    ##
    ###
    #####
    ####### LaPraDor: The maximum lengths set for queries and
    ####### documents are 64 and 350...
    #####
    # TODO is this divide-by-two a good idea? (Don't want s1=s2 ever..)
    nl1 = min(N // 2, l1)
    nl2 = min(N // 2, l2)

    s1_start = random.randint(1, N - nl1)
    s2_start = random.randint(1, N - nl2)

    s1_idxs = itertools.chain(
        [0], range(s1_start, s1_start + nl1), [N - 1]
    )
    s1 = input_ids[torch.tensor(list(s1_idxs))]
    s2_idxs = itertools.chain(
        [0], range(s2_start, s2_start + nl2), [N - 1]
    )
    s2 = input_ids[torch.tensor(list(s2_idxs))]
    return (s1, s2)


def load_dataset_tables(
    files: Iterable[str], num_workers: int = 16
) -> Iterable[datasets.table.MemoryMappedTable]:
    import concurrent
    from multiprocessing import Pool

    # num_workers = min(num_workers, len(files))
    num_workers = min(32, len(files))

    use_threads = True
    if use_threads:
        pool_cls = concurrent.futures.ThreadPoolExecutor
        pool_kwargs = {"max_workers": num_workers}
    else:
        pool_cls = Pool
        pool_kwargs = {"processes": num_workers}

    with pool_cls(**pool_kwargs) as pool:
        if len(files) > 10:
            files = tqdm_if_main_worker(
                files,
                desc=f"Loading {len(files)} files with {num_workers} workers",
                total=len(files),
                colour="#ffbd88"
            )

        result = list(
            pool.map(datasets.table.MemoryMappedTable.from_file, files)
        )
    return result


def datasets_fast_load_from_disk(cache_path: str) -> datasets.Dataset:
    logging.info(f"fast_load_from_disk called with path: {cache_path}")
    dataset_info_path = os.path.join(cache_path, "dataset_info.json")
    with open(dataset_info_path, encoding="utf-8") as dataset_info_file:
        dataset_info = datasets.DatasetInfo.from_dict(json.load(dataset_info_file))

    dataset_state_path = os.path.join(cache_path, "state.json")
    with open(dataset_state_path, encoding="utf-8") as state_file:
        state = json.load(state_file)

    files = glob.glob(os.path.join(cache_path, "data-*.arrow"))
    files = sorted(files)
    num_workers = get_num_proc()
    ds_tables = load_dataset_tables(
        files=files,
        num_workers=num_workers
    )
    arrow_table = datasets.table.concat_tables(ds_tables)

    split = state["_split"]
    split = datasets.splits.Split(split) if split is not None else split

    # print("returning dataset")
    return datasets.Dataset(
        arrow_table=arrow_table,
        info=dataset_info,
        split=split,
        fingerprint=state["_fingerprint"],
    )


def tokenize_dataset(
    dataset: datasets.Dataset,
    tokenizer: transformers.PreTrainedTokenizer,
    max_length: int,
    text_key: str,
    padding_strategy: str
) -> datasets.Dataset:
    def tokenize_text(ex: Dict) -> Dict:
        tt = tokenizer(
            ex[text_key],
            max_length=max_length,
            truncation=True,
            padding=padding_strategy,
        )
        for k, v in tt.items():
            ex[f"{text_key}_{k}"] = v
        ex["length"] = [len(tt) for tt in ex[f"{text_key}_input_ids"]]
        return ex

    # generate unique hash for tokenizer
    vocab = tokenizer.vocab
    vocab_words = tuple(sorted(vocab.keys(), key=lambda word: vocab[word]))
    vocab_hash = md5_hash(vocab_words)

    data_fingerprint = '__'.join((
        dataset._fingerprint, str(vocab_hash), str(max_length),
        text_key, padding_strategy
    ))
    data_fingerprint = md5_hash(data_fingerprint)
    dataset = dataset.map(
        tokenize_text,
        new_fingerprint=data_fingerprint,
        batched=True,
        load_from_cache_file=True,
    )
    return dataset


class TensorRunningAverages:
    _store_sum: Dict[str, torch.Tensor]
    _store_total: Dict[str, torch.Tensor]

    def __init__(self):
        self._store_sum = {}
        self._store_total = {}

    def __iter__(self) -> Iterable[str]:
        return iter(self._store_sum.keys())

    def update(self, key: str, val: Union[int, float, torch.Tensor]) -> None:
        if key not in self._store_sum:
            self.clear(key)
        if isinstance(val, torch.Tensor):
            val = val.item()  # tensor -> num
        self._store_sum[key] += val
        self._store_total[key] += 1

    def get(self, key: str) -> float:
        total = max(self._store_total.get(key).item(), 1.0)
        return (self._store_sum[key] / float(total)).item() or 0.0

    def clear(self, key: str) -> None:
        self._store_sum[key] = torch.tensor(0.0, dtype=torch.float32)
        self._store_total[key] = torch.tensor(0, dtype=torch.int32)

    def clear_all(self) -> None:
        for key in self._store_sum:
            self.clear(key)

    def get_and_clear_all(self) -> Dict[str, float]:
        metrics = {}
        for key in self:
            metrics[key] = self.get(key)
            self.clear(key)
        return metrics

def load_embedder_and_tokenizer(name: str) -> Tuple[
    transformers.PreTrainedModel,
    transformers.PreTrainedTokenizer
]:
    if name.startswith("nomic") or (name == "bert-base-uncased"):
        from cde.lib.nomic_bert import NomicBertModel
        if name.endswith("--from-scratch"):
            name = name.replace("--from-scratch", "")
            config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
            model = NomicBertModel._from_config(config)
        else:
            model = NomicBertModel.from_pretrained(
                name, add_pooling_layer=False
            )
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    elif name in ["gtr-base", "gtr_base"]:
        model = transformers.AutoModel.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "sentence-transformers/gtr-t5-base"
        )
    elif name == "pile-t5-base-encoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).encoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name == "pile-t5-base-decoder":
        model = transformers.AutoModel.from_pretrained(
            "EleutherAI/pile-t5-base"
        ).decoder
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            "EleutherAI/pile-t5-base"
        )
        tokenizer.pad_token = tokenizer.eos_token
    elif name.startswith("gpt2") or name.startswith("meta-llama") or ("Llama" in name):
        model = transformers.AutoModelForCausalLM.from_pretrained(
            name,
            # torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "sdpa",
            low_cpu_mem_usage=True,
            # device_map="auto",
        )
        model.padding_side = "right"
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True
    elif "Modern" in name:
        print("special loading for ModernBERT!")
        # [1] needed for faster training
        # model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=True)
        # [2] needed for non-breaking inference
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True, reference_compile=False)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    else:
        model = transformers.AutoModel.from_pretrained(name, trust_remote_code=True)
        tokenizer = transformers.AutoTokenizer.from_pretrained(name)
    return model, tokenizer


def inputs_for_key(inputs: Dict[str, torch.Tensor], key: str):
    key += "_"
    return {k.replace(key, ""): v for k, v in inputs.items() if k.startswith(key)}


def count_cpus() -> int:
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        return multiprocessing.cpu_count()


def shuffle_batches(g: torch.Generator, list_of_tensors: List[torch.Tensor]) -> List[int]:
    all_indices = []
    for batch_tensor in tqdm_if_main_worker(list_of_tensors, colour="green", desc="Sampler shuffling per-batch"):
        rand_perm = torch.randperm(len(batch_tensor), generator=g)
        batch_list = batch_tensor[rand_perm].tolist()
        all_indices.extend(batch_list)
    return all_indices


# def shuffle_batches_multiproc(g: torch.Generator, list_of_tensors: List[torch.Tensor], num_processes: int = 8) -> List[int]:
#     all_indices = []
#     print(f"Shuffling {len(list_of_tensors)} tensors with {num_processes} workers.")
#     pbar = tqdm_if_main_worker(list_of_tensors, colour="orange", desc=f"Sampler shuffling per-batch (nproc={num_processes})")
#     pool = multiprocessing.Pool(processes=num_processes)
#     chunk_size = len(list_of_tensors) // num_processes
#     chunks = [list_of_tensors[i:i + chunk_size] for i in range(0, len(list_of_tensors), chunk_size)]
#     worker_func = functools.partial(shuffle_batches, g=g)
#     results = pool.map(worker_func, chunks)
#     all_indices = []
#     for result in results:
#         all_indices.extend(result)
#         pbar.update()
#     return all_indices


def exit_if_running_or_finished_wandb(
    project_name: str,
    exp_group: str, exp_name: str
) -> None:
    print("Checking if experiment is already running...")
    import wandb

    api = wandb.Api()
    running_runs = api.runs(
        path="cde-0",
        filters={
            "display_name": exp_name,
            "state": {"$regex": "Running|Finished"},
            "config.exp_group": exp_group,
        }
    )
    print("Found", len(running_runs), f"runs with name {exp_name} and group {exp_group} in {project_name}.")

    if len(running_runs) > 0:
        print("Exiting because experiment is already running or completed.")
        sys.exit(0)


HN_FILTER_TOKENIZER_MAP = {
    "nomic": "nomic-ai/nomic-embed-text-v1",
    "stella": "dunzhang/stella_en_400M_v5",
    "sbert": "sentence-transformers/all-MiniLM-L6-v2",
    "sentence_t5": "sentence-transformers/sentence-t5-base",
    "gte": "Alibaba-NLP/gte-large-en-v1.5",
}
def load_hn_filter_tokenizer(tokenizer_name: str) -> Optional[transformers.PreTrainedTokenizer]:
    if tokenizer_name in HN_FILTER_TOKENIZER_MAP:
        return transformers.AutoTokenizer.from_pretrained(HN_FILTER_TOKENIZER_MAP[tokenizer_name])
    else:
        return None
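A short usage sketch for two of the helpers above (not part of the upload). It assumes misc.py is importable as-is, which requires the `cde` package it imports (`cde.lib.dist`) to be on the path:

```python
import torch
from misc import TensorRunningAverages, md5_hash_kwargs

# md5_hash_kwargs: stable cache key from keyword arguments
# (keys starting with "_" are ignored, per the comment in the function).
key = md5_hash_kwargs(dataset="nq", max_seq_length=512, _cached_setup="ignored")
print(key)  # 32-character hex digest, identical across runs for the same kwargs

# TensorRunningAverages: accumulate scalar metrics, then read back their means.
avgs = TensorRunningAverages()
for loss in (2.0, 1.5, 1.0):
    avgs.update("loss", torch.tensor(loss))
print(avgs.get("loss"))          # 1.5
print(avgs.get_and_clear_all())  # {'loss': 1.5}
```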
model.py
ADDED
@@ -0,0 +1,622 @@
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import transformers
|
| 7 |
+
|
| 8 |
+
from cde.lib.dist import print0
|
| 9 |
+
from cde.lib.tensor import mean_pool, mean_pool_3d, mean_pool_weighted, last_token_pool
|
| 10 |
+
|
| 11 |
+
from cde.lib import load_embedder_and_tokenizer, ContextualModelConfig
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
gpt_tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
|
| 15 |
+
|
| 16 |
+
def limit_layers(model: transformers.PreTrainedModel, n_layers: int) -> None:
|
| 17 |
+
if hasattr(model, 'transformer'):
|
| 18 |
+
if hasattr(model.transformer, 'h'):
|
| 19 |
+
# gpt2
|
| 20 |
+
model.transformer.h = model.transformer.h[:n_layers]
|
| 21 |
+
else:
|
| 22 |
+
model.transformer.layer = model.transformer.layer[:n_layers]
|
| 23 |
+
elif hasattr(model, 'encoder'):
|
| 24 |
+
if hasattr(model.encoder, 'layers'):
|
| 25 |
+
model.encoder.layers = model.encoder.layers[:n_layers]
|
| 26 |
+
else:
|
| 27 |
+
model.encoder.layer = model.encoder.layer[:n_layers]
|
| 28 |
+
else:
|
| 29 |
+
raise RuntimeError(f"unknown how to limit layers of model {type(model)}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def disable_dropout(model: torch.nn.Module):
|
| 33 |
+
dropout_modules = [m for m in model.modules() if isinstance(m, torch.nn.Dropout)]
|
| 34 |
+
for m in dropout_modules:
|
| 35 |
+
m.p = 0.0
|
| 36 |
+
print0(
|
| 37 |
+
f"Disabled {len(dropout_modules)} dropout modules from model type {type(model)}"
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def disable_causality(model: torch.nn.Module):
|
| 42 |
+
disabled_modules = 0
|
| 43 |
+
for m in model.modules():
|
| 44 |
+
if hasattr(m, "is_causal"):
|
| 45 |
+
m.is_causal = False
|
| 46 |
+
disabled_modules += 1
|
| 47 |
+
print0(
|
| 48 |
+
f"Set is_causal=False in {disabled_modules} modules from model type {type(model)}"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class ContextualModelMixin(nn.Module):
|
| 53 |
+
@property
|
| 54 |
+
def num_corpus_tokens(self) -> int:
|
| 55 |
+
return self.transductive_corpus_size * self.transductive_tokens_per_document
|
| 56 |
+
|
| 57 |
+
def contextual_init(self):
|
| 58 |
+
self.n_soft_prompt = 8
|
| 59 |
+
self.prompt_projection = torch.nn.Sequential(
|
| 60 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
| 61 |
+
torch.nn.ReLU(),
|
| 62 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size * self.n_soft_prompt)
|
| 63 |
+
)
|
| 64 |
+
self.transductive_corpus_size = vars(self.config).get("transductive_corpus_size", 1)
|
| 65 |
+
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
|
| 66 |
+
self.randomize_dataset_sequence_order = True
|
| 67 |
+
self.sequence_dropout_prob = vars(self.config).get("transductive_sequence_dropout_prob", 0.0)
|
| 68 |
+
if self.sequence_dropout_prob > 0.0:
|
| 69 |
+
self.sequence_dropout_null_embedding = torch.nn.Parameter(
|
| 70 |
+
torch.randn(self.hidden_size) * 0.01,
|
| 71 |
+
requires_grad = True
|
| 72 |
+
)
|
| 73 |
+
self.output_projection = torch.nn.Sequential(
|
| 74 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
| 75 |
+
torch.nn.ReLU(),
|
| 76 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size)
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
def _prepare_dataset_embeddings(
|
| 80 |
+
self,
|
| 81 |
+
input_ids: torch.Tensor,
|
| 82 |
+
dataset_embeddings: torch.Tensor,
|
| 83 |
+
null_dataset_embedding: bool = False,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
if not isinstance(dataset_embeddings, torch.Tensor):
|
| 86 |
+
dataset_embeddings = torch.tensor(dataset_embeddings)
|
| 87 |
+
|
| 88 |
+
if len(dataset_embeddings.shape) == 2:
|
| 89 |
+
# Auto-expand for a batch.
|
| 90 |
+
dataset_embeddings = dataset_embeddings[None, :, :] # (b, d) -> (1, b, d)
|
| 91 |
+
dataset_embeddings = dataset_embeddings.to(input_ids.device)
|
| 92 |
+
|
| 93 |
+
batch_size = input_ids.shape[0]
|
| 94 |
+
if (self.transductive_tokens_per_document > 1):
|
| 95 |
+
if self.training:
|
| 96 |
+
# Choose N random documents to fill our context window with.
|
| 97 |
+
# This logic is a little confusing but allows us to sample a
|
| 98 |
+
# different batch *per-document*
|
| 99 |
+
assert dataset_embeddings.shape[1] == self.transductive_tokens_per_document
|
| 100 |
+
R = torch.randint(
|
| 101 |
+
low=0,
|
| 102 |
+
high=len(dataset_embeddings),
|
| 103 |
+
size=(batch_size, self.config.transductive_corpus_size),
|
| 104 |
+
device=dataset_embeddings.device
|
| 105 |
+
)
|
| 106 |
+
# TODO make this deterministic somehow for evaluation?
|
| 107 |
+
dataset_embeddings = dataset_embeddings[R].reshape((batch_size, self.num_corpus_tokens, self.hidden_size))
|
| 108 |
+
else:
|
| 109 |
+
dataset_embeddings = dataset_embeddings.reshape((1, self.num_corpus_tokens, self.hidden_size))
|
| 110 |
+
# print("reshaped to dataset_embeddings.shape =", dataset_embeddings.shape)
|
| 111 |
+
|
| 112 |
+
if dataset_embeddings.shape[1] > self.num_corpus_tokens:
|
| 113 |
+
# If too many dataset embeddings are passed in, just take the first N until
|
| 114 |
+
# we have the proper number.
|
| 115 |
+
dataset_embeddings = dataset_embeddings[:, :self.num_corpus_tokens, :]
|
| 116 |
+
|
| 117 |
+
_, corpus_size, _hidden_size = dataset_embeddings.shape
|
| 118 |
+
if _ == 1:
|
| 119 |
+
# Auto-expand for a batch.
|
| 120 |
+
dataset_embeddings = dataset_embeddings.expand((batch_size, -1, -1))
|
| 121 |
+
|
| 122 |
+
if self.training and self.sequence_dropout_prob > 0.0:
|
| 123 |
+
sequence_dropout_mask = (
|
| 124 |
+
torch.rand((batch_size, corpus_size), device=dataset_embeddings.device) < self.sequence_dropout_prob
|
| 125 |
+
)
|
| 126 |
+
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
|
| 127 |
+
dataset_embeddings = torch.where(
|
| 128 |
+
sequence_dropout_mask[..., None], null_embeddings, dataset_embeddings
|
| 129 |
+
)
|
| 130 |
+
elif null_dataset_embedding:
|
| 131 |
+
null_embeddings = self.sequence_dropout_null_embedding[None, None].expand(batch_size, corpus_size, -1)
|
| 132 |
+
dataset_embeddings = null_embeddings
|
| 133 |
+
|
| 134 |
+
# backbone_max_seq_length = self.backbone.config.max_trained_positions
|
| 135 |
+
# assert batch_size + (2 * self.n_soft_prompt + corpus_size) <= backbone_max_seq_length, "too many hard negatives for backbone model"
|
| 136 |
+
soft_prompt = torch.ones((1, self.hidden_size), device=dataset_embeddings.device, dtype=dataset_embeddings.dtype)
|
| 137 |
+
soft_prompt = self.prompt_projection(soft_prompt).reshape((1, self.n_soft_prompt, self.hidden_size))
|
| 138 |
+
soft_prompt = soft_prompt.expand((len(dataset_embeddings), -1, -1)) # -> (b, 4+b, d) # soft_prompt.repeat((len(input_ids), 1, 1))
|
| 139 |
+
soft_prompt = torch.cat((dataset_embeddings, soft_prompt), dim=1)
|
| 140 |
+
|
| 141 |
+
return soft_prompt
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class BiEncoder(transformers.PreTrainedModel):
|
| 145 |
+
config_class = ContextualModelConfig
|
| 146 |
+
embedder: transformers.PreTrainedModel
|
| 147 |
+
def __init__(
|
| 148 |
+
self,
|
| 149 |
+
config, #: transformers.PreTrainedConfig,
|
| 150 |
+
):
|
| 151 |
+
super().__init__(config=config)
|
| 152 |
+
embedder, _ = load_embedder_and_tokenizer(
|
| 153 |
+
config.embedder,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
if config.limit_layers:
|
| 157 |
+
print0(f"Limiting layers to {config.limit_layers}")
|
| 158 |
+
limit_layers(embedder, config.limit_layers)
|
| 159 |
+
|
| 160 |
+
self.embedder = embedder
|
| 161 |
+
# if ("t5" in embedder.config.model_type):
|
| 162 |
+
# print0(f"using torch.compile() on embedder of type `{embedder.config.model_type}`")
|
| 163 |
+
# self.embedder = torch.compile(self.embedder)
|
| 164 |
+
self.hidden_size = self.embedder.config.hidden_size
|
| 165 |
+
# Allow pooling to multiple tokens per document
|
| 166 |
+
self.transductive_tokens_per_document = vars(self.config).get("transductive_tokens_per_document", 1)
|
| 167 |
+
self.mlp = torch.nn.Sequential(
|
| 168 |
+
torch.nn.Linear(self.hidden_size, self.hidden_size),
|
| 169 |
+
torch.nn.GELU(),
|
| 170 |
+
torch.nn.Linear(self.hidden_size, self.config.embedding_output_dim or self.hidden_size),
|
| 171 |
+
)
|
| 172 |
+
self.temp = config.logit_scale
|
| 173 |
+
|
| 174 |
+
if config.disable_dropout:
|
| 175 |
+
disable_dropout(self)
|
| 176 |
+
self.pooling_strategy = vars(config).get("pooling_strategy", "mean")
|
| 177 |
+
|
| 178 |
+
def forward(
|
| 179 |
+
self,
|
| 180 |
+
input_ids: torch.Tensor,
|
| 181 |
+
attention_mask: torch.Tensor,
|
| 182 |
+
dataset_input_ids: Optional[torch.Tensor] = None,
|
| 183 |
+
dataset_attention_mask: Optional[torch.Tensor] = None,
|
| 184 |
+
token_type_ids = None,
|
| 185 |
+
output_hidden_states: bool = False,
|
| 186 |
+
) -> torch.Tensor:
|
| 187 |
+
"""
|
| 188 |
+
query_embedding (float torch.Tensor) - shape (batch_size, embedding_dim)
|
| 189 |
+
document_embeddings (float torch.Tensor) - shape (corpus_size, embedding_dim)
|
| 190 |
+
where the corpus_size >= batch_size and is structured like this:
|
| 191 |
+
[d1, d2, d3, hn1_1, hn1_2, hn2_1, hn2_2, hn3_1, hn3_2]
|
| 192 |
+
for a corpus with three documents and two hard negatives per document
|
| 193 |
+
"""
|
| 194 |
+
del token_type_ids
|
| 195 |
+
|
| 196 |
+
outputs = (
|
| 197 |
+
self.embedder(
|
| 198 |
+
input_ids=input_ids,
|
| 199 |
+
attention_mask=attention_mask,
|
| 200 |
+
).last_hidden_state
|
| 201 |
+
)
|
| 202 |
+
if self.transductive_tokens_per_document > 1:
|
| 203 |
+
document_embeddings = None
|
| 204 |
+
batch_size, seq_length, output_dim = outputs.shape
|
| 205 |
+
|
| 206 |
+
if seq_length % self.transductive_tokens_per_document != 0:
|
| 207 |
+
# Pad to nearest multiple
|
| 208 |
+
n_extra_embeds = self.transductive_tokens_per_document - (seq_length % self.transductive_tokens_per_document)
|
| 209 |
+
outputs = torch.cat(
|
| 210 |
+
(outputs, torch.zeros((batch_size, n_extra_embeds, output_dim), device=outputs.device)),
|
| 211 |
+
dim=1
|
| 212 |
+
)
|
| 213 |
+
attention_mask = torch.cat(
|
| 214 |
+
(attention_mask, torch.zeros((batch_size, n_extra_embeds), device=attention_mask.device)),
|
| 215 |
+
dim=1
|
| 216 |
+
)
|
| 217 |
+
seq_length += n_extra_embeds
|
| 218 |
+
print(f"Added {n_extra_embeds} padding tokens to input_ids and attention_mask")
|
| 219 |
+
|
| 220 |
+
# print("ftransductive_tokens_per_document {self.transductive_tokens_per_document} outputs.shape =", outputs.shape)
|
| 221 |
+
|
| 222 |
+
outputs = outputs.reshape(
|
| 223 |
+
(batch_size, self.transductive_tokens_per_document, seq_length // self.transductive_tokens_per_document, output_dim)
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
attention_mask = attention_mask.reshape((batch_size, self.transductive_tokens_per_document, -1))
|
| 227 |
+
document_embeddings = mean_pool_3d(outputs, attention_mask)
|
| 228 |
+
|
| 229 |
+
document_embeddings = document_embeddings.reshape((batch_size, self.transductive_tokens_per_document, output_dim))
|
| 230 |
+
else:
|
| 231 |
+
if self.pooling_strategy == "mean":
|
| 232 |
+
document_embeddings = mean_pool(outputs, attention_mask)
|
| 233 |
+
else:
|
| 234 |
+
document_embeddings = document_embeddings.max(dim=1)
|
| 235 |
+
output = self.mlp(document_embeddings)
|
| 236 |
+
# breakpoint()
|
| 237 |
+
|
| 238 |
+
if output_hidden_states:
|
| 239 |
+
return {
|
| 240 |
+
"hidden_states": outputs,
|
| 241 |
+
"pooled": output,
|
| 242 |
+
}
|
| 243 |
+
else:
|
| 244 |
+
return output
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class DatasetConditionedAutoregressive(transformers.PreTrainedModel, ContextualModelMixin):
|
| 248 |
+
def __init__(
|
| 249 |
+
self,
|
| 250 |
+
config,
|
| 251 |
+
dataset_backbone: transformers.PreTrainedModel,
|
| 252 |
+
first_stage_hidden_size: int,
|
| 253 |
+
):
|
| 254 |
+
super().__init__(config=config)
|
| 255 |
+
self.backbone = dataset_backbone
|
| 256 |
+
self.backbone_hidden_size = self.backbone.config.hidden_size
|
| 257 |
+
self.hidden_size = first_stage_hidden_size # Input token size
|
| 258 |
+
self.contextual_init()
|
| 259 |
+
disable_causality(self.backbone)
|
| 260 |
+
|
| 261 |
+
self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", False)
|
| 262 |
+
self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)
|
| 263 |
+
self.pool_instruction_end_id = self.backbone.config.bos_token_id
|
| 264 |
+
|
| 265 |
+
# Override contextual init
|
| 266 |
+
self.output_projection = torch.nn.Sequential(
|
| 267 |
+
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size),
|
| 268 |
+
torch.nn.ReLU(),
|
| 269 |
+
torch.nn.Linear(self.backbone_hidden_size, self.backbone_hidden_size)
|
| 270 |
+
)
|
| 271 |
+
self._shift_rotary_embedding()
|
| 272 |
+
|
| 273 |
+
@property
|
| 274 |
+
def num_corpus_tokens(self) -> int:
|
| 275 |
+
return self.config.transductive_corpus_size * self.transductive_tokens_per_document
|
| 276 |
+
|
| 277 |
+
@property
|
| 278 |
+
def corpus_token_ratio(self) -> float:
|
| 279 |
+
# How many tokens from the first stage make one token in the second
|
| 280 |
+
# stage?
|
| 281 |
+
return self.backbone_hidden_size / self.hidden_size
|
| 282 |
+
|
| 283 |
+
def corpus_token_pad_size(self, n_tokens: int) -> int:
|
| 284 |
+
return self.hidden_size % self.backbone_hidden_size
|
| 285 |
+
|
| 286 |
+
def _shift_rotary_embedding(self) -> None:
|
| 287 |
+
disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
|
| 288 |
+
# TODO: Can we do this for LLAMA?
|
| 289 |
+
print0("Warning: Positional embedding disabling not implemented for LLAMA.")
|
| 290 |
+
|
| 291 |
+
def forward(
|
| 292 |
+
self,
|
| 293 |
+
input_ids: torch.Tensor,
|
| 294 |
+
attention_mask: torch.Tensor,
|
| 295 |
+
dataset_embeddings: torch.Tensor,
|
| 296 |
+
output_hidden_states: bool = False,
|
| 297 |
+
null_dataset_embedding: bool = False,
|
| 298 |
+
) -> torch.Tensor:
|
| 299 |
+
soft_prompt = self._prepare_dataset_embeddings(
|
| 300 |
+
input_ids=input_ids,
|
| 301 |
+
dataset_embeddings=dataset_embeddings,
|
| 302 |
+
null_dataset_embedding=null_dataset_embedding,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Reshape for this model.
|
| 306 |
+
# print("[DatasetConditionedAutoregressive] 1 -> soft_prompt.shape =", soft_prompt.shape)
|
| 307 |
+
num_soft_elements = torch.prod(torch.tensor(soft_prompt.shape[1:])).item()
|
| 308 |
+
soft_prompt = soft_prompt.reshape((soft_prompt.shape[0], num_soft_elements))
|
| 309 |
+
num_padding_elements = self.backbone_hidden_size - (num_soft_elements % self.backbone_hidden_size)
|
| 310 |
+
padding = torch.ones((soft_prompt.shape[0], num_padding_elements), device=soft_prompt.device)
|
| 311 |
+
soft_prompt = torch.cat((soft_prompt, padding), dim=1)
|
| 312 |
+
soft_prompt = soft_prompt.reshape(
|
| 313 |
+
(soft_prompt.shape[0], -1, self.backbone_hidden_size)
|
| 314 |
+
)
|
| 315 |
+
# print("[DatasetConditionedAutoregressive] 2 -> soft_prompt.shape =", soft_prompt.shape)
|
| 316 |
+
|
| 317 |
+
backbone_attention_mask = torch.ones(
|
| 318 |
+
soft_prompt.shape[0:2],
|
| 319 |
+
dtype=torch.long,
|
| 320 |
+
device=soft_prompt.device,
|
| 321 |
+
)
|
| 322 |
+
token_embeddings = self.backbone.get_input_embeddings()
|
| 323 |
+
inputs_embeds = token_embeddings(input_ids) # (b, s) -> (b, s, d)
|
| 324 |
+
# print("[2] inputs_embeds.shape =", inputs_embeds.shape)
|
| 325 |
+
inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1) # (v, 4+b+s, d)
|
| 326 |
+
# print("[3.a] inputs_embeds.shape =", inputs_embeds.shape)
|
| 327 |
+
input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
|
| 328 |
+
# print("[3.b] attention_mask.shape =", attention_mask.shape)
|
| 329 |
+
|
| 330 |
+
output = self.backbone(
|
| 331 |
+
inputs_embeds=inputs_embeds,
|
| 332 |
+
attention_mask=input_attention_mask,
|
| 333 |
+
output_hidden_states=True,
|
| 334 |
+
) # (1, 4 + b + s, d)
|
| 335 |
+
# trim soft prompt
|
| 336 |
+
output_vectors = output.hidden_states[-1]
|
| 337 |
+
n_soft_prompt_tokens = soft_prompt.shape[1]
|
| 338 |
+
|
| 339 |
+
if self.pool_ignore_instruction_tokens:
|
| 340 |
+
# Denote the end of an instruction with an extra BOS token.
|
| 341 |
+
# This is a bit arcane but relies on the fact that there will be a BOS token after the
|
| 342 |
+
# instruction, but also there may or may not be a BOS token at the beginning.
|
| 343 |
+
instruction_end_idx = (
|
| 344 |
+
(input_ids == self.pool_instruction_end_id) &
|
| 345 |
+
attention_mask &
|
| 346 |
+
(torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
|
| 347 |
+
).int().argmax(1)
|
| 348 |
+
is_instruction_token_mask = (
|
| 349 |
+
torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
|
| 350 |
+
)
|
| 351 |
+
# catch edge case where there is no instruction
|
| 352 |
+
is_instruction_token_mask = is_instruction_token_mask.where(
|
| 353 |
+
(instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
|
| 354 |
+
)
|
| 355 |
+
input_attention_mask = torch.cat((
|
| 356 |
+
backbone_attention_mask,
|
| 357 |
+
attention_mask & ~is_instruction_token_mask), dim=1
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
output_attention_mask = input_attention_mask
|
| 361 |
+
if self.pool_ignore_contextual_tokens:
|
| 362 |
+
output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
|
| 363 |
+
output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]
|
| 364 |
+
|
| 365 |
+
# Take last token position
|
| 366 |
+
if vars(self.config).get("pooling_strategy") == "last_token":
|
| 367 |
+
output_pooled = last_token_pool(output_vectors, output_attention_mask)
|
| 368 |
+
elif vars(self.config).get("pooling_strategy") == "mean":
|
| 369 |
+
output_pooled = mean_pool(output_vectors, output_attention_mask)
|
| 370 |
+
else:
|
| 371 |
+
output_pooled = mean_pool_weighted(output_vectors, output_attention_mask)
|
| 372 |
+
|
| 373 |
+
# average with original vectors
|
| 374 |
+
output = self.output_projection(output_pooled) # (b, 2d) -> (b, d)
|
| 375 |
+
|
| 376 |
+
if output_hidden_states:
|
| 377 |
+
return {
|
| 378 |
+
"hidden_states": output_vectors,
|
| 379 |
+
"pooled": output,
|
| 380 |
+
}
|
| 381 |
+
else:
|
| 382 |
+
return output
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class DatasetConditionedBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
        self,
        config,
        dataset_backbone: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.backbone = dataset_backbone
        self.hidden_size = dataset_backbone.config.hidden_size
        self.contextual_init()
        self._shift_rotary_embedding()

        self.pool_ignore_contextual_tokens = vars(self.config).get("pool_ignore_contextual_tokens", True)
        self.pool_ignore_instruction_tokens = vars(self.config).get("pool_ignore_instruction_tokens", False)

        tokenizer = transformers.AutoTokenizer.from_pretrained(self.config.embedder)
        self.pool_instruction_end_id = tokenizer.encode(": ", add_special_tokens=False)[0]  # Hardcoded for colon-ending prefixes.

    @property
    def num_corpus_tokens(self) -> int:
        return self.config.transductive_corpus_size * self.transductive_tokens_per_document

    def _shift_rotary_embedding(self) -> None:
        disable_transductive_rotary_embedding = vars(self.config).get("disable_transductive_rotary_embedding", True)
        if self.backbone.config.model_type.startswith("nomic") and disable_transductive_rotary_embedding:
            # We only want to apply positional embeddings to the
            # *text* portion of the backbone network.
            self.backbone.config.rotary_start_pos = 0.0
            rotary_disabled = 0

            rotary_start_pos = self.num_corpus_tokens
            for module in self.backbone.modules():
                if hasattr(module, "rotary_emb_dim"):
                    module.rotary_start_pos = rotary_start_pos
                    rotary_disabled += 1
            print0(f"modified {rotary_disabled} rotary modules – set rotary_start_pos to {rotary_start_pos}")

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_embeddings: torch.Tensor,
        output_hidden_states: bool = False,
        null_dataset_embedding: bool = False,
    ) -> torch.Tensor:
        soft_prompt = self._prepare_dataset_embeddings(
            input_ids=input_ids,
            dataset_embeddings=dataset_embeddings,
            null_dataset_embedding=null_dataset_embedding,
        )
        backbone_attention_mask = torch.ones(
            soft_prompt.shape[0:2],
            dtype=torch.long,
            device=soft_prompt.device,
        )
        inputs_embeds = self.backbone.embeddings(input_ids)  # (b, s) -> (b, s, d)
        inputs_embeds = torch.cat((soft_prompt, inputs_embeds), dim=1)  # (v, 4+b+s, d)
        input_attention_mask = torch.cat((backbone_attention_mask, attention_mask), dim=1)
        output = self.backbone(
            inputs_embeds=inputs_embeds,
            attention_mask=input_attention_mask,
        )  # (1, 4 + b + s, d)
        # trim soft prompt
        output_vectors = output.last_hidden_state

        # use only these tokens
        n_soft_prompt_tokens = soft_prompt.shape[1]

        if self.pool_ignore_instruction_tokens:
            # Denote the end of an instruction with an extra BOS token.
            # This is a bit arcane but relies on the fact that there will be a BOS token after the
            # instruction, but also there may or may not be a BOS token at the beginning.
            instruction_end_idx = (
                (input_ids == self.pool_instruction_end_id) &
                attention_mask &
                (torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] > 0)
            ).int().argmax(1)
            is_instruction_token_mask = (
                torch.arange(input_ids.shape[1], device=input_ids.device)[None, :] <= instruction_end_idx[:, None]
            )
            # catch edge case where there is no instruction
            is_instruction_token_mask = is_instruction_token_mask.where(
                (instruction_end_idx > 0)[:, None], torch.zeros_like(is_instruction_token_mask)
            )
            output_attention_mask = torch.cat((backbone_attention_mask, attention_mask & ~is_instruction_token_mask), dim=1)
        else:
            output_attention_mask = input_attention_mask

        if self.pool_ignore_contextual_tokens:
            output_vectors = output_vectors[:, n_soft_prompt_tokens:, :]
            output_attention_mask = output_attention_mask[:, n_soft_prompt_tokens:]
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        # average with original vectors
        output = self.output_projection(output_pooled) + output_pooled  # (b, d) -> (b, d) / with residual connection

        if output_hidden_states:
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output

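_shift_rotary_embedding above points every rotary module at rotary_start_pos = num_corpus_tokens, so that rotary position encoding is effectively reserved for the text portion of the input rather than the contextual soft-prompt tokens. A rough illustration of the intended position layout, under the assumption that the nomic backbone starts counting rotary positions at rotary_start_pos (that semantics lives in the backbone code, not in this upload):

import torch

# Illustration only (assumed semantics of rotary_start_pos): the contextual
# tokens in front of the text get no meaningful rotary position, while the text
# tokens keep the positions they would have had without the prefix.
num_corpus_tokens = 4                 # transductive_corpus_size * tokens_per_document
seq_len = 10                          # 4 contextual tokens + 6 text tokens
raw_positions = torch.arange(seq_len)
effective_positions = (raw_positions - num_corpus_tokens).clamp(min=0)
print(effective_positions)            # tensor([0, 0, 0, 0, 0, 1, 2, 3, 4, 5])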
class DatasetPrefixBiencoder(transformers.PreTrainedModel, ContextualModelMixin):
    def __init__(
        self,
        config,  #: transformers.PreTrainedConfig,
        embedder: transformers.PreTrainedModel,
    ):
        super().__init__(config=config)
        self.embedder = embedder
        self.hidden_size = self.embedder.config.hidden_size
        self.contextual_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: torch.Tensor,
        dataset_attention_mask: torch.Tensor,
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        R = torch.randint(low=0, high=len(dataset_input_ids), size=(len(input_ids),), device=dataset_input_ids.device)

        dataset_input_ids = dataset_input_ids[R]
        input_ids = torch.cat((dataset_input_ids, input_ids), dim=1)

        dataset_attention_mask = torch.ones_like(dataset_attention_mask, device=dataset_attention_mask.device)
        input_attention_mask = torch.cat((dataset_attention_mask, attention_mask), dim=1)
        output_attention_mask = torch.cat(
            (torch.zeros_like(dataset_input_ids), attention_mask), dim=1
        )

        output = self.embedder(
            input_ids=input_ids,
            attention_mask=input_attention_mask,
        )

        output_vectors = output.last_hidden_state
        output_pooled = mean_pool(output_vectors, output_attention_mask)
        output = self.output_projection(output_pooled)  # (b, 2d) -> (b, d)

        if output_hidden_states:
            S_d = dataset_attention_mask.shape[1]
            output_vectors = output_vectors[:, S_d:, :]
            return {
                "hidden_states": output_vectors,
                "pooled": output,
            }
        else:
            return output

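mean_pool, mean_pool_weighted, and last_token_pool used throughout these classes are defined elsewhere in this upload (misc.py / earlier in model.py). For convenience while reading the pooling calls, here is a minimal mean_pool sketch that matches how it is called: hidden states of shape (b, s, d) and a 0/1 mask of shape (b, s). It is a sketch of the usual formulation, not necessarily the repo's exact implementation:

import torch

def mean_pool_sketch(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Average token vectors over positions where the mask is 1.
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)   # (b, s, 1)
    summed = (hidden_states * mask).sum(dim=1)                    # (b, d)
    counts = mask.sum(dim=1).clamp(min=1e-9)                      # (b, 1)
    return summed / counts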
class ContextualDocumentEmbeddingTransformer(transformers.PreTrainedModel):
    config_class = ContextualModelConfig

    embedder: transformers.PreTrainedModel
    dataset_backbone: transformers.PreTrainedModel

    def __init__(
        self,
        config,
    ):
        super().__init__(config=config)
        dataset_backbone, _ = load_embedder_and_tokenizer(
            vars(config).get("dataset_backbone") or config.embedder
        )

        if config.limit_layers:
            print0(f"Limiting layers to {config.limit_layers}")
            limit_layers(dataset_backbone, config.limit_layers)

        biencoder_config = copy.deepcopy(config)
        biencoder_config.embedding_output_dim = None
        biencoder_config.limit_layers = vars(self.config).get("limit_layers_first_stage", None)
        self.first_stage_model = BiEncoder(
            config=biencoder_config,
        )

        if vars(config).get("autoregressive_backbone", False):
            self.second_stage_model = DatasetConditionedAutoregressive(
                config=config,
                dataset_backbone=dataset_backbone,
                first_stage_hidden_size=self.first_stage_model.hidden_size,
            )
        else:
            self.second_stage_model = DatasetConditionedBiencoder(
                config=config,
                dataset_backbone=dataset_backbone
            )

        self.temp = config.logit_scale
        if config.disable_dropout:
            disable_dropout(self)

        transductive_tie_token_embeddings = vars(self.config).get("transductive_tie_token_embeddings", False)
        if transductive_tie_token_embeddings:
            self.second_stage_model.backbone.embeddings.word_embeddings.weight = (
                self.first_stage_model.embedder.embeddings.word_embeddings.weight
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        dataset_input_ids: Optional[torch.Tensor],
        dataset_attention_mask: Optional[torch.Tensor],
        output_hidden_states: bool = False,
    ) -> torch.Tensor:
        """
        input_ids (long torch.Tensor) – ids of input tokens
        attention_mask (bool torch.Tensor) – attention mask over input tokens
        dataset_input_ids (long torch.Tensor) – ids of the sampled corpus (dataset) tokens
        dataset_attention_mask (bool torch.Tensor) – attention mask over the dataset tokens
        """
        dataset_embeddings = self.first_stage_model(
            input_ids=dataset_input_ids,
            attention_mask=dataset_attention_mask
        )
        return self.second_stage_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            dataset_embeddings=dataset_embeddings,
            output_hidden_states=output_hidden_states,
        )


def get_model_class(name: str):
    if name == 'transductive':
        return ContextualDocumentEmbeddingTransformer
    elif name == 'biencoder':
        return BiEncoder
    elif name == "biencoder_plus_plus":
        from cde.model_extra import BiEncoderPlusPlus
        return BiEncoderPlusPlus
    elif name == "dataset_prefix_biencoder":
        return DatasetPrefixBiencoder
    else:
        raise ValueError(f'unknown model cls {name}')
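To close out model.py: a hypothetical end-to-end usage sketch for the two-stage model defined above. The repo id and the "search_query: " prefix are placeholders (the colon-ended prefix matches the hardcoded pool_instruction_end_id logic); the exact loading path depends on how this checkpoint is published:

import torch
import transformers

# Placeholder repo id; trust_remote_code is needed because the architecture is custom.
model = transformers.AutoModel.from_pretrained("<this-repo-id>", trust_remote_code=True)
# Assumes config.embedder names a tokenizer available on the Hub.
tokenizer = transformers.AutoTokenizer.from_pretrained(model.config.embedder)

corpus = ["first corpus document", "second corpus document"]
corpus_inputs = tokenizer(corpus, padding=True, truncation=True, return_tensors="pt")
query_inputs = tokenizer(["search_query: example query"], padding=True,
                         truncation=True, return_tensors="pt")

with torch.no_grad():
    # First stage: embed a sample of corpus documents into dataset embeddings.
    dataset_embeddings = model.first_stage_model(
        input_ids=corpus_inputs["input_ids"],
        attention_mask=corpus_inputs["attention_mask"],
    )
    # Second stage: embed the query conditioned on those dataset embeddings.
    query_embedding = model.second_stage_model(
        input_ids=query_inputs["input_ids"],
        attention_mask=query_inputs["attention_mask"],
        dataset_embeddings=dataset_embeddings,
    )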
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7cca261c510de07c012f3019366f1b6c5720761b6966b0388faea6e70398983
size 1124594680