Upload evaluate_models.ipynb
Browse files- evaluate_models.ipynb +172 -70
    	
        evaluate_models.ipynb
    CHANGED
    
    | @@ -2,92 +2,122 @@ | |
| 2 | 
             
             "cells": [
         | 
| 3 | 
             
              {
         | 
| 4 | 
             
               "cell_type": "code",
         | 
|  | |
| 5 | 
             
               "id": "initial_id",
         | 
| 6 | 
             
               "metadata": {
         | 
| 7 | 
             
                "collapsed": true
         | 
| 8 | 
             
               },
         | 
|  | |
| 9 | 
             
               "source": [
         | 
| 10 | 
             
                "import os\n",
         | 
|  | |
| 11 | 
             
                "\n",
         | 
| 12 | 
            -
                "IS_COLAB = True if  | 
| 13 | 
             
                "if IS_COLAB:\n",
         | 
| 14 | 
             
                "    # this needs to run before all other imports\n",
         | 
| 15 | 
            -
                "    os.environ[ | 
| 16 | 
             
                "\n",
         | 
| 17 | 
             
                "import mteb\n",
         | 
|  | |
|  | |
|  | |
| 18 | 
             
                "from sentence_transformers import SentenceTransformer"
         | 
| 19 | 
            -
               ] | 
| 20 | 
            -
               "outputs": [],
         | 
| 21 | 
            -
               "execution_count": null
         | 
| 22 | 
             
              },
         | 
| 23 | 
             
              {
         | 
|  | |
|  | |
| 24 | 
             
               "metadata": {},
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 25 | 
             
               "cell_type": "code",
         | 
|  | |
|  | |
|  | |
|  | |
| 26 | 
             
               "source": [
         | 
| 27 | 
             
                "MODELS = {\n",
         | 
| 28 | 
            -
                "     | 
| 29 | 
            -
                "         | 
| 30 | 
            -
                "         | 
|  | |
| 31 | 
             
                "    },\n",
         | 
| 32 | 
            -
                "     | 
| 33 | 
            -
                "         | 
| 34 | 
            -
                "         | 
|  | |
| 35 | 
             
                "    },\n",
         | 
| 36 | 
            -
                "     | 
| 37 | 
            -
                "         | 
| 38 | 
            -
                "         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 39 | 
             
                "    },\n",
         | 
| 40 | 
            -
                "    'mt-paper': {\n",
         | 
| 41 | 
            -
                "        'name': 'MongoDB/mdbr-leaf-mt',\n",
         | 
| 42 | 
            -
                "        'revision': 'c342f945a6855346bd5f48d5ee8b7e39120b0ce9',\n",
         | 
| 43 | 
            -
                "    }\n",
         | 
| 44 | 
             
                "}"
         | 
| 45 | 
            -
               ] | 
| 46 | 
            -
               "id": "f0189ff1e7814a5a",
         | 
| 47 | 
            -
               "outputs": [],
         | 
| 48 | 
            -
               "execution_count": null
         | 
| 49 | 
             
              },
         | 
| 50 | 
             
              {
         | 
| 51 | 
            -
               "metadata": {},
         | 
| 52 | 
             
               "cell_type": "markdown",
         | 
|  | |
|  | |
| 53 | 
             
               "source": [
         | 
| 54 | 
            -
                " | 
| 55 | 
             
                "* set the output folder and\n",
         | 
| 56 | 
             
                "* select one of the models defined above\n",
         | 
| 57 | 
             
                "* desired benchmark"
         | 
| 58 | 
            -
               ] | 
| 59 | 
            -
               "id": "371c6122efdf476a"
         | 
| 60 | 
             
              },
         | 
| 61 | 
             
              {
         | 
| 62 | 
            -
               "metadata": {},
         | 
| 63 | 
             
               "cell_type": "code",
         | 
|  | |
|  | |
|  | |
|  | |
| 64 | 
             
               "source": [
         | 
| 65 | 
            -
                "output_folder = f\"../../data/results/publish/\"\n",
         | 
|  | |
| 66 | 
             
                "\n",
         | 
| 67 | 
            -
                "model_selection = MODELS[ | 
| 68 | 
             
                "benchmark_name = \"BEIR\"\n",
         | 
| 69 | 
             
                "\n",
         | 
| 70 | 
             
                "# model_selection = MODELS['mt-prod']\n",
         | 
| 71 | 
             
                "# benchmark_name = \"MTEB(eng, v2)\""
         | 
| 72 | 
            -
               ] | 
| 73 | 
            -
               "id": "58d52a330febb9ac",
         | 
| 74 | 
            -
               "outputs": [],
         | 
| 75 | 
            -
               "execution_count": null
         | 
| 76 | 
             
              },
         | 
| 77 | 
             
              {
         | 
| 78 | 
            -
               "metadata": {},
         | 
| 79 | 
             
               "cell_type": "markdown",
         | 
| 80 | 
            -
               " | 
| 81 | 
            -
               " | 
|  | |
|  | |
|  | |
| 82 | 
             
              },
         | 
| 83 | 
             
              {
         | 
|  | |
|  | |
|  | |
| 84 | 
             
               "metadata": {},
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 85 | 
             
               "cell_type": "code",
         | 
|  | |
|  | |
|  | |
|  | |
| 86 | 
             
               "source": [
         | 
| 87 | 
            -
                "model = SentenceTransformer(\n",
         | 
| 88 | 
            -
                "    model_selection['name'],\n",
         | 
| 89 | 
            -
                "    revision=model_selection['revision']\n",
         | 
| 90 | 
            -
                ")\n",
         | 
| 91 | 
             
                "\n",
         | 
| 92 | 
             
                "# alternative:\n",
         | 
| 93 | 
             
                "# meta = mteb.get_model_meta(\n",
         | 
| @@ -95,25 +125,14 @@ | |
| 95 | 
             
                "#     revision=model_selection['revision']\n",
         | 
| 96 | 
             
                "# )\n",
         | 
| 97 | 
             
                "# model = meta.load_model()"
         | 
| 98 | 
            -
               ] | 
| 99 | 
            -
               "id": "d6f13945a94f7a85",
         | 
| 100 | 
            -
               "outputs": [],
         | 
| 101 | 
            -
               "execution_count": null
         | 
| 102 | 
             
              },
         | 
| 103 | 
             
              {
         | 
| 104 | 
            -
               "metadata": {},
         | 
| 105 | 
             
               "cell_type": "code",
         | 
| 106 | 
            -
               " | 
| 107 | 
            -
             | 
| 108 | 
            -
                "evaluation = mteb.MTEB(tasks=benchmark)"
         | 
| 109 | 
            -
               ],
         | 
| 110 | 
            -
               "id": "c716c6344f9cd939",
         | 
| 111 | 
            -
               "outputs": [],
         | 
| 112 | 
            -
               "execution_count": null
         | 
| 113 | 
            -
              },
         | 
| 114 | 
            -
              {
         | 
| 115 | 
             
               "metadata": {},
         | 
| 116 | 
            -
               " | 
| 117 | 
             
               "source": [
         | 
| 118 | 
             
                "%%time\n",
         | 
| 119 | 
             
                "results = evaluation.run(\n",
         | 
| @@ -122,28 +141,32 @@ | |
| 122 | 
             
                "    output_folder=output_folder,\n",
         | 
| 123 | 
             
                "    overwrite_results=True,\n",
         | 
| 124 | 
             
                ")"
         | 
| 125 | 
            -
               ] | 
| 126 | 
            -
               "id": "9bd44e88fc360663",
         | 
| 127 | 
            -
               "outputs": [],
         | 
| 128 | 
            -
               "execution_count": null
         | 
| 129 | 
             
              },
         | 
| 130 | 
             
              {
         | 
| 131 | 
            -
               "metadata": {},
         | 
| 132 | 
             
               "cell_type": "markdown",
         | 
| 133 | 
            -
               " | 
| 134 | 
            -
               " | 
|  | |
|  | |
|  | |
| 135 | 
             
              },
         | 
| 136 | 
             
              {
         | 
| 137 | 
            -
               "metadata": {},
         | 
| 138 | 
             
               "cell_type": "code",
         | 
|  | |
|  | |
|  | |
|  | |
| 139 | 
             
               "source": [
         | 
| 140 | 
            -
                "if model_selection[ | 
| 141 | 
             
                "    # quora is closer to a sentence similarity task than a retrieval one, as queries aren't proper user queries\n",
         | 
| 142 | 
             
                "    # we thus embed them without the typical query prompt\n",
         | 
| 143 | 
             
                "    model.prompts = {}\n",
         | 
| 144 | 
            -
                "    tasks = mteb.get_tasks( | 
| 145 | 
            -
                "        \ | 
| 146 | 
            -
                " | 
|  | |
|  | |
| 147 | 
             
                "\n",
         | 
| 148 | 
             
                "    evaluation = mteb.MTEB(tasks=tasks)\n",
         | 
| 149 | 
             
                "    results = evaluation.run(\n",
         | 
| @@ -152,10 +175,89 @@ | |
| 152 | 
             
                "        output_folder=output_folder,\n",
         | 
| 153 | 
             
                "        overwrite_results=True,\n",
         | 
| 154 | 
             
                "    )"
         | 
| 155 | 
            -
               ] | 
| 156 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 157 | 
             
               "outputs": [],
         | 
| 158 | 
            -
               " | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 159 | 
             
              }
         | 
| 160 | 
             
             ],
         | 
| 161 | 
             
             "metadata": {
         | 
|  | |
| 2 | 
             
             "cells": [
         | 
| 3 | 
             
              {
         | 
| 4 | 
             
               "cell_type": "code",
         | 
| 5 | 
            +
               "execution_count": null,
         | 
| 6 | 
             
               "id": "initial_id",
         | 
| 7 | 
             
               "metadata": {
         | 
| 8 | 
             
                "collapsed": true
         | 
| 9 | 
             
               },
         | 
| 10 | 
            +
               "outputs": [],
         | 
| 11 | 
             
               "source": [
         | 
| 12 | 
             
                "import os\n",
         | 
| 13 | 
            +
                "from typing import Dict, List\n",
         | 
| 14 | 
             
                "\n",
         | 
| 15 | 
            +
                "IS_COLAB = True if \"GOOGLE_CLOUD_PROJECT\" in os.environ else False\n",
         | 
| 16 | 
             
                "if IS_COLAB:\n",
         | 
| 17 | 
             
                "    # this needs to run before all other imports\n",
         | 
| 18 | 
            +
                "    os.environ[\"HF_HOME\"] = \"/content/cache/\"  # to avoid running out of disk space\n",
         | 
| 19 | 
             
                "\n",
         | 
| 20 | 
             
                "import mteb\n",
         | 
| 21 | 
            +
                "import numpy as np\n",
         | 
| 22 | 
            +
                "import torch\n",
         | 
| 23 | 
            +
                "from mteb.encoder_interface import PromptType\n",
         | 
| 24 | 
             
                "from sentence_transformers import SentenceTransformer"
         | 
| 25 | 
            +
               ]
         | 
|  | |
|  | |
| 26 | 
             
              },
         | 
| 27 | 
             
              {
         | 
| 28 | 
            +
               "cell_type": "markdown",
         | 
| 29 | 
            +
               "id": "5325acfb",
         | 
| 30 | 
             
               "metadata": {},
         | 
| 31 | 
            +
               "source": [
         | 
| 32 | 
            +
                "### Notebook Configuration"
         | 
| 33 | 
            +
               ]
         | 
| 34 | 
            +
              },
         | 
| 35 | 
            +
              {
         | 
| 36 | 
             
               "cell_type": "code",
         | 
| 37 | 
            +
               "execution_count": null,
         | 
| 38 | 
            +
               "id": "f0189ff1e7814a5a",
         | 
| 39 | 
            +
               "metadata": {},
         | 
| 40 | 
            +
               "outputs": [],
         | 
| 41 | 
             
               "source": [
         | 
| 42 | 
             
                "MODELS = {\n",
         | 
| 43 | 
            +
                "    \"ir-prod\": {\n",
         | 
| 44 | 
            +
                "        \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
         | 
| 45 | 
            +
                "        \"revision\": \"2e46f5aac796e621d51f678c306a66ede4712ecb\",\n",
         | 
| 46 | 
            +
                "        \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
         | 
| 47 | 
             
                "    },\n",
         | 
| 48 | 
            +
                "    \"ir-paper\": {\n",
         | 
| 49 | 
            +
                "        \"name\": \"MongoDB/mdbr-leaf-ir\",\n",
         | 
| 50 | 
            +
                "        \"revision\": \"ea98995e96beac21b820aa8ad9afaa6fd29b243d\",\n",
         | 
| 51 | 
            +
                "        \"teacher\": \"Snowflake/snowflake-arctic-embed-m-v1.5\",\n",
         | 
| 52 | 
             
                "    },\n",
         | 
| 53 | 
            +
                "    \"mt-prod\": {\n",
         | 
| 54 | 
            +
                "        \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
         | 
| 55 | 
            +
                "        \"revision\": \"66c47ba6d753efc208d54412b5af6c744a39a4df\",\n",
         | 
| 56 | 
            +
                "        \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
         | 
| 57 | 
            +
                "    },\n",
         | 
| 58 | 
            +
                "    \"mt-paper\": {\n",
         | 
| 59 | 
            +
                "        \"name\": \"MongoDB/mdbr-leaf-mt\",\n",
         | 
| 60 | 
            +
                "        \"revision\": \"c342f945a6855346bd5f48d5ee8b7e39120b0ce9\",\n",
         | 
| 61 | 
            +
                "        \"teacher\": \"mixedbread-ai/mxbai-embed-large-v1\",\n",
         | 
| 62 | 
             
                "    },\n",
         | 
|  | |
|  | |
|  | |
|  | |
| 63 | 
             
                "}"
         | 
| 64 | 
            +
               ]
         | 
|  | |
|  | |
|  | |
| 65 | 
             
              },
         | 
| 66 | 
             
              {
         | 
|  | |
| 67 | 
             
               "cell_type": "markdown",
         | 
| 68 | 
            +
               "id": "371c6122efdf476a",
         | 
| 69 | 
            +
               "metadata": {},
         | 
| 70 | 
             
               "source": [
         | 
| 71 | 
            +
                "In the cell below:\n",
         | 
| 72 | 
             
                "* set the output folder,\n",
         | 
| 73 | 
             
                "* select one of the models defined above\n",
         | 
| 74 | 
             
                "* choose the desired benchmark"
         | 
| 75 | 
            +
               ]
         | 
|  | |
| 76 | 
             
              },
         | 
| 77 | 
             
              {
         | 
|  | |
| 78 | 
             
               "cell_type": "code",
         | 
| 79 | 
            +
               "execution_count": null,
         | 
| 80 | 
            +
               "id": "58d52a330febb9ac",
         | 
| 81 | 
            +
               "metadata": {},
         | 
| 82 | 
            +
               "outputs": [],
         | 
| 83 | 
             
               "source": [
         | 
| 84 | 
            +
                "# On Colab results go under /content; elsewhere fall back to the relative results folder.\n",
                "output_folder = \"/content/data/results/publish/\" if IS_COLAB else \"../../data/results/publish/\"\n",
                "\n",
                "model_selection = MODELS[\"ir-prod\"]\n",
                "benchmark_name = \"BEIR\"\n",
                "\n",
                "# alternative configuration:\n",
                "# model_selection = MODELS['mt-prod']\n",
                "# benchmark_name = \"MTEB(eng, v2)\""
         | 
| 92 | 
            +
               ]
         | 
|  | |
|  | |
|  | |
| 93 | 
             
              },
         | 
| 94 | 
             
              {
         | 
|  | |
| 95 | 
             
               "cell_type": "markdown",
         | 
| 96 | 
            +
               "id": "1b4367afc1278e",
         | 
| 97 | 
            +
               "metadata": {},
         | 
| 98 | 
            +
               "source": [
         | 
| 99 | 
            +
                "### Run Evals"
         | 
| 100 | 
            +
               ]
         | 
| 101 | 
             
              },
         | 
| 102 | 
             
              {
         | 
| 103 | 
            +
               "cell_type": "code",
         | 
| 104 | 
            +
               "execution_count": null,
         | 
| 105 | 
            +
               "id": "c716c6344f9cd939",
         | 
| 106 | 
             
               "metadata": {},
         | 
| 107 | 
            +
               "outputs": [],
         | 
| 108 | 
            +
               "source": [
         | 
| 109 | 
            +
                "benchmark = mteb.get_benchmark(benchmark_name)\n",
         | 
| 110 | 
            +
                "evaluation = mteb.MTEB(tasks=benchmark)"
         | 
| 111 | 
            +
               ]
         | 
| 112 | 
            +
              },
         | 
| 113 | 
            +
              {
         | 
| 114 | 
             
               "cell_type": "code",
         | 
| 115 | 
            +
               "execution_count": null,
         | 
| 116 | 
            +
               "id": "d6f13945a94f7a85",
         | 
| 117 | 
            +
               "metadata": {},
         | 
| 118 | 
            +
               "outputs": [],
         | 
| 119 | 
             
               "source": [
         | 
| 120 | 
            +
                "model = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
         | 
|  | |
|  | |
|  | |
| 121 | 
             
                "\n",
         | 
| 122 | 
             
                "# alternative:\n",
         | 
| 123 | 
             
                "# meta = mteb.get_model_meta(\n",
         | 
|  | |
| 125 | 
             
                "#     revision=model_selection['revision']\n",
         | 
| 126 | 
             
                "# )\n",
         | 
| 127 | 
             
                "# model = meta.load_model()"
         | 
| 128 | 
            +
               ]
         | 
|  | |
|  | |
|  | |
| 129 | 
             
              },
         | 
| 130 | 
             
              {
         | 
|  | |
| 131 | 
             
               "cell_type": "code",
         | 
| 132 | 
            +
               "execution_count": null,
         | 
| 133 | 
            +
               "id": "9bd44e88fc360663",
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 134 | 
             
               "metadata": {},
         | 
| 135 | 
            +
               "outputs": [],
         | 
| 136 | 
             
               "source": [
         | 
| 137 | 
             
                "%%time\n",
         | 
| 138 | 
             
                "results = evaluation.run(\n",
         | 
|  | |
| 141 | 
             
                "    output_folder=output_folder,\n",
         | 
| 142 | 
             
                "    overwrite_results=True,\n",
         | 
| 143 | 
             
                ")"
         | 
| 144 | 
            +
               ]
         | 
|  | |
|  | |
|  | |
| 145 | 
             
              },
         | 
| 146 | 
             
              {
         | 
|  | |
| 147 | 
             
               "cell_type": "markdown",
         | 
| 148 | 
            +
               "id": "733e52ca41cf92a7",
         | 
| 149 | 
            +
               "metadata": {},
         | 
| 150 | 
            +
               "source": [
         | 
| 151 | 
            +
                "#### Evaluate Quora (IR models only)"
         | 
| 152 | 
            +
               ]
         | 
| 153 | 
             
              },
         | 
| 154 | 
             
              {
         | 
|  | |
| 155 | 
             
               "cell_type": "code",
         | 
| 156 | 
            +
               "execution_count": null,
         | 
| 157 | 
            +
               "id": "61aea9a04468202f",
         | 
| 158 | 
            +
               "metadata": {},
         | 
| 159 | 
            +
               "outputs": [],
         | 
| 160 | 
             
               "source": [
         | 
| 161 | 
            +
                "if model_selection[\"name\"].endswith(\"ir\"):\n",
         | 
| 162 | 
             
                "    # quora is closer to a sentence similarity task than a retrieval one, as queries aren't proper user queries\n",
         | 
| 163 | 
             
                "    # we thus embed them without the typical query prompt\n",
         | 
| 164 | 
             
                "    model.prompts = {}\n",
         | 
| 165 | 
            +
                "    tasks = mteb.get_tasks(\n",
         | 
| 166 | 
            +
                "        tasks=[\n",
         | 
| 167 | 
            +
                "            \"QuoraRetrieval\",\n",
         | 
| 168 | 
            +
                "        ]\n",
         | 
| 169 | 
            +
                "    )\n",
         | 
| 170 | 
             
                "\n",
         | 
| 171 | 
             
                "    evaluation = mteb.MTEB(tasks=tasks)\n",
         | 
| 172 | 
             
                "    results = evaluation.run(\n",
         | 
|  | |
| 175 | 
             
                "        output_folder=output_folder,\n",
         | 
| 176 | 
             
                "        overwrite_results=True,\n",
         | 
| 177 | 
             
                "    )"
         | 
| 178 | 
            +
               ]
         | 
| 179 | 
            +
              },
         | 
| 180 | 
            +
              {
         | 
| 181 | 
            +
               "cell_type": "markdown",
         | 
| 182 | 
            +
               "id": "6a6c164e",
         | 
| 183 | 
            +
               "metadata": {},
         | 
| 184 | 
            +
               "source": [
         | 
| 185 | 
            +
                "### Asymmetric Mode\n",
         | 
| 186 | 
            +
                "\n",
         | 
| 187 | 
            +
                "Compute asymmetric mode scores: queries encoded by `leaf`, documents by the original teacher model."
         | 
| 188 | 
            +
               ]
         | 
| 189 | 
            +
              },
         | 
| 190 | 
            +
              {
         | 
| 191 | 
            +
               "cell_type": "code",
         | 
| 192 | 
            +
               "execution_count": null,
         | 
| 193 | 
            +
               "id": "487ba349",
         | 
| 194 | 
            +
               "metadata": {},
         | 
| 195 | 
             
               "outputs": [],
         | 
| 196 | 
            +
               "source": [
         | 
| 197 | 
            +
                "class AsymmetricModel:\n",
                "    \"\"\"Pair of encoders: one used for queries, the other for documents.\"\"\"\n",
                "\n",
                "    def __init__(\n",
                "        self,\n",
                "        doc_model: SentenceTransformer,\n",
                "        query_model: SentenceTransformer,\n",
                "    ) -> None:\n",
                "        \"\"\"Keep references to the document-side and query-side models.\"\"\"\n",
                "        self.doc_model = doc_model\n",
                "        self.query_model = query_model\n",
                "\n",
                "    def encode(self, sentences: List[str], **kwargs) -> np.ndarray | torch.Tensor:\n",
                "        \"\"\"Encode `sentences` with the model selected by the `prompt_type` keyword.\n",
                "\n",
                "        * `PromptType.query`: query model, called with `prompt_name=\"query\"`\n",
                "        * `PromptType.document`: document model\n",
                "        * missing / `None`: query model without a prompt name (a notice is printed)\n",
                "\n",
                "        All keyword arguments, `prompt_type` included, are forwarded to the\n",
                "        underlying `encode` call. The result is returned as a CPU torch tensor.\n",
                "\n",
                "        Raises:\n",
                "            ValueError: if `prompt_type` is none of the values listed above.\n",
                "        \"\"\"\n",
                "        # an absent prompt type is stored (and forwarded) as None\n",
                "        prompt_type = kwargs.setdefault(\"prompt_type\", None)\n",
                "\n",
                "        if prompt_type == PromptType.query:\n",
                "            embeddings = self.query_model.encode(sentences, prompt_name=\"query\", **kwargs)\n",
                "        elif prompt_type == PromptType.document:\n",
                "            embeddings = self.doc_model.encode(sentences, **kwargs)\n",
                "        elif prompt_type is None:\n",
                "            print(\"No prompt type: using query (leaf) model for encoding\")\n",
                "            embeddings = self.query_model.encode(sentences, **kwargs)\n",
                "        else:\n",
                "            raise ValueError(f\"Encoding unknown type: {kwargs['prompt_type']}\")\n",
                "\n",
                "        # numpy output is wrapped in a tensor; everything ends up on the CPU\n",
                "        if isinstance(embeddings, torch.Tensor):\n",
                "            return embeddings.to(\"cpu\")\n",
                "        return torch.from_numpy(embeddings).to(\"cpu\")"
         | 
| 228 | 
            +
               ]
         | 
| 229 | 
            +
              },
         | 
| 230 | 
            +
              {
         | 
| 231 | 
            +
               "cell_type": "code",
         | 
| 232 | 
            +
               "execution_count": null,
         | 
| 233 | 
            +
               "id": "4162af7f",
         | 
| 234 | 
            +
               "metadata": {},
         | 
| 235 | 
            +
               "outputs": [],
         | 
| 236 | 
            +
               "source": [
         | 
| 237 | 
            +
                "leaf = SentenceTransformer(model_selection[\"name\"], revision=model_selection[\"revision\"])\n",
         | 
| 238 | 
            +
                "teacher = SentenceTransformer(model_selection[\"teacher\"])\n",
         | 
| 239 | 
            +
                "\n",
         | 
| 240 | 
            +
                "asymm_model = AsymmetricModel(\n",
         | 
| 241 | 
            +
                "    query_model=leaf,\n",
         | 
| 242 | 
            +
                "    doc_model=teacher,\n",
         | 
| 243 | 
            +
                ")"
         | 
| 244 | 
            +
               ]
         | 
| 245 | 
            +
              },
         | 
| 246 | 
            +
              {
         | 
| 247 | 
            +
               "cell_type": "code",
         | 
| 248 | 
            +
               "execution_count": null,
         | 
| 249 | 
            +
               "id": "848d8a5f",
         | 
| 250 | 
            +
               "metadata": {},
         | 
| 251 | 
            +
               "outputs": [],
         | 
| 252 | 
            +
               "source": [
         | 
| 253 | 
            +
                "%%time\n",
                "# The Quora cell rebinds `evaluation` to a QuoraRetrieval-only run for IR models,\n",
                "# so rebuild it from the full benchmark before scoring the asymmetric model.\n",
                "evaluation = mteb.MTEB(tasks=benchmark)\n",
                "results = evaluation.run(\n",
                "    model=asymm_model,\n",
                "    verbosity=1,\n",
                "    output_folder=output_folder,\n",
                "    overwrite_results=True,\n",
                ")"
         | 
| 260 | 
            +
               ]
         | 
| 261 | 
             
              }
         | 
| 262 | 
             
             ],
         | 
| 263 | 
             
             "metadata": {
         | 

