Upload sd_token_similarity_calculator.ipynb
Browse files- sd_token_similarity_calculator.ipynb +172 -69
sd_token_similarity_calculator.ipynb
CHANGED
|
@@ -116,10 +116,28 @@
|
|
| 116 |
"metadata": {
|
| 117 |
"id": "Ch9puvwKH1s3",
|
| 118 |
"collapsed": true,
|
| 119 |
-
"cellView": "form"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
},
|
| 121 |
-
"execution_count":
|
| 122 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
},
|
| 124 |
{
|
| 125 |
"cell_type": "code",
|
|
@@ -272,56 +290,23 @@
|
|
| 272 |
"outputs": []
|
| 273 |
},
|
| 274 |
{
|
| 275 |
-
"cell_type": "
|
| 276 |
"source": [
|
| 277 |
-
"
|
| 278 |
-
"\n",
|
| 279 |
-
"prompt_A = \"banana\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
| 280 |
-
"prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
| 281 |
-
"use_token_padding = True # @param {type:\"boolean\"}\n",
|
| 282 |
"\n",
|
| 283 |
-
"
|
| 284 |
-
"\n",
|
| 285 |
-
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\" , clean_up_tokenization_spaces = True)\n",
|
| 286 |
"\n",
|
| 287 |
-
"
|
| 288 |
-
"\n"
|
| 289 |
-
"ids_A = processor.tokenizer(text=prompt_A, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 290 |
-
"text_encoding_A = model.get_text_features(**ids_A)\n",
|
| 291 |
-
"\n",
|
| 292 |
-
"\n",
|
| 293 |
-
"ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 294 |
-
"text_encoding_B = model.get_text_features(**ids_B)\n",
|
| 295 |
-
"\n",
|
| 296 |
-
"similarity_str = 'The similarity between the text_encoding for A:\"' + prompt_A + '\" and B: \"' + prompt_B +'\" is ' + token_similarity(text_encoding_A[0] , text_encoding_B[0])\n",
|
| 297 |
-
"\n",
|
| 298 |
-
"\n",
|
| 299 |
-
"print(similarity_str)\n",
|
| 300 |
-
"#outputs = model(**inputs)\n",
|
| 301 |
-
"#logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n",
|
| 302 |
-
"#probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities"
|
| 303 |
-
],
|
| 304 |
-
"metadata": {
|
| 305 |
-
"id": "QQOjh5BvnG8M",
|
| 306 |
-
"collapsed": true,
|
| 307 |
-
"cellView": "form"
|
| 308 |
-
},
|
| 309 |
-
"execution_count": null,
|
| 310 |
-
"outputs": []
|
| 311 |
-
},
|
| 312 |
-
{
|
| 313 |
-
"cell_type": "markdown",
|
| 314 |
-
"source": [
|
| 315 |
-
"You can write an url or upload a file locally from your device to use as reference. The image will by saved in the 'sd_tokens' folder. Note that the 'sd_tokens' folder will be deleted upon exiting this runtime."
|
| 316 |
],
|
| 317 |
"metadata": {
|
| 318 |
-
"id": "
|
| 319 |
}
|
| 320 |
},
|
| 321 |
{
|
| 322 |
"cell_type": "code",
|
| 323 |
"source": [
|
| 324 |
-
"# @title 🪐🖼️ -> 📝 Image to prompt :
|
| 325 |
"from google.colab import files\n",
|
| 326 |
"def upload_files():\n",
|
| 327 |
" from google.colab import files\n",
|
|
@@ -331,7 +316,7 @@
|
|
| 331 |
" return list(uploaded.keys())\n",
|
| 332 |
"#Get image\n",
|
| 333 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
| 334 |
-
"url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
| 335 |
"\n",
|
| 336 |
"colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\":\"(optional) Write colab image path to load from\"}\n",
|
| 337 |
"from PIL import Image\n",
|
|
@@ -369,19 +354,19 @@
|
|
| 369 |
"\n",
|
| 370 |
"# @markdown Set conditions for the output\n",
|
| 371 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 372 |
-
"must_contain = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 373 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 374 |
"token_B = must_contain\n",
|
| 375 |
"\n",
|
| 376 |
"# @markdown Limit the search\n",
|
| 377 |
"use_token_padding = True # @param {type:\"boolean\"}\n",
|
| 378 |
-
"start_search_at_ID =
|
| 379 |
-
"search_range =
|
| 380 |
-
"restrictions = '
|
| 381 |
"\n",
|
| 382 |
"# @markdown Limit char size of included token\n",
|
| 383 |
-
"min_char_size = 3 # @param {type:\"slider\", min:0, max:
|
| 384 |
-
"char_range =
|
| 385 |
"\n",
|
| 386 |
"#Tokenize input B\n",
|
| 387 |
"from transformers import AutoTokenizer\n",
|
|
@@ -397,14 +382,26 @@
|
|
| 397 |
"\n",
|
| 398 |
"dots = torch.zeros(RANGE)\n",
|
| 399 |
"is_BC = torch.zeros(RANGE)\n",
|
|
|
|
|
|
|
|
|
|
| 400 |
"for index in range(RANGE):\n",
|
| 401 |
" id_C = START + index\n",
|
| 402 |
" C = token[id_C]\n",
|
| 403 |
" _C = LA.vector_norm(C, ord=2)\n",
|
| 404 |
" name_C = vocab[id_C]\n",
|
| 405 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
" # Decide if we should process prefix/suffix tokens\n",
|
| 407 |
" if name_C.find('</w>')<=-1:\n",
|
|
|
|
| 408 |
" if restrictions != \"Prefix only\":\n",
|
| 409 |
" continue\n",
|
| 410 |
" else:\n",
|
|
@@ -420,8 +417,8 @@
|
|
| 420 |
" #-----#\n",
|
| 421 |
"\n",
|
| 422 |
" name_CB = must_start_with + name_C + name_B + must_end_with\n",
|
| 423 |
-
" if
|
| 424 |
-
" name_CB = must_start_with +
|
| 425 |
" #-----#\n",
|
| 426 |
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 427 |
" text_encoding_CB = model.get_text_features(**ids_CB)\n",
|
|
@@ -469,37 +466,143 @@
|
|
| 469 |
"print('')\n",
|
| 470 |
"print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for {prompt_A} : ')\n",
|
| 471 |
"print('')\n",
|
| 472 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
"for index in range(min(list_size,RANGE)):\n",
|
| 474 |
" id = START + indices[index].item()\n",
|
| 475 |
-
"
|
| 476 |
-
"
|
| 477 |
-
"
|
| 478 |
-
"
|
| 479 |
-
"
|
| 480 |
-
"
|
| 481 |
-
"
|
| 482 |
-
"
|
| 483 |
-
"
|
| 484 |
-
"
|
| 485 |
-
"
|
| 486 |
-
"
|
| 487 |
-
" if (print_Divider):\n",
|
| 488 |
-
" print('--------')\n",
|
| 489 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
"\n",
|
|
|
|
| 492 |
"\n",
|
| 493 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
],
|
| 495 |
"metadata": {
|
| 496 |
"collapsed": true,
|
| 497 |
-
"cellView": "form",
|
| 498 |
"id": "fi0jRruI0-tu"
|
| 499 |
},
|
| 500 |
"execution_count": null,
|
| 501 |
"outputs": []
|
| 502 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
{
|
| 504 |
"cell_type": "code",
|
| 505 |
"source": [
|
|
|
|
| 116 |
"metadata": {
|
| 117 |
"id": "Ch9puvwKH1s3",
|
| 118 |
"collapsed": true,
|
| 119 |
+
"cellView": "form",
|
| 120 |
+
"outputId": "aa58503f-8e68-43bf-d73b-3eb877ae10e4",
|
| 121 |
+
"colab": {
|
| 122 |
+
"base_uri": "https://localhost:8080/"
|
| 123 |
+
}
|
| 124 |
},
|
| 125 |
+
"execution_count": 1,
|
| 126 |
+
"outputs": [
|
| 127 |
+
{
|
| 128 |
+
"output_type": "stream",
|
| 129 |
+
"name": "stdout",
|
| 130 |
+
"text": [
|
| 131 |
+
"Cloning into 'sd_tokens'...\n",
|
| 132 |
+
"remote: Enumerating objects: 10, done.\u001b[K\n",
|
| 133 |
+
"remote: Counting objects: 100% (7/7), done.\u001b[K\n",
|
| 134 |
+
"remote: Compressing objects: 100% (7/7), done.\u001b[K\n",
|
| 135 |
+
"remote: Total 10 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
|
| 136 |
+
"Unpacking objects: 100% (10/10), 306.93 KiB | 5.48 MiB/s, done.\n",
|
| 137 |
+
"/content/sd_tokens\n"
|
| 138 |
+
]
|
| 139 |
+
}
|
| 140 |
+
]
|
| 141 |
},
|
| 142 |
{
|
| 143 |
"cell_type": "code",
|
|
|
|
| 290 |
"outputs": []
|
| 291 |
},
|
| 292 |
{
|
| 293 |
+
"cell_type": "markdown",
|
| 294 |
"source": [
|
| 295 |
+
"Below image interrogator appends CLIP tokens to either end of the 'must_contain' text , and seeks to maximize similarity with the image encoding.\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
"\n",
|
| 297 |
+
"It takes a long while to check all the tokens (too long!) so this cell only samples a range of the 49K available tokens.\n",
|
|
|
|
|
|
|
| 298 |
"\n",
|
| 299 |
+
"You can run this cell, then paste the result into the 'must_contain' box , and then run the cell again.\n",
|
| 300 |
+
"\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
],
|
| 302 |
"metadata": {
|
| 303 |
+
"id": "IUCuV9RtQpBn"
|
| 304 |
}
|
| 305 |
},
|
| 306 |
{
|
| 307 |
"cell_type": "code",
|
| 308 |
"source": [
|
| 309 |
+
"# @title 🪐🖼️ -> 📝 Image to prompt : Create suggestions of things to add to prompt to match image\n",
|
| 310 |
"from google.colab import files\n",
|
| 311 |
"def upload_files():\n",
|
| 312 |
" from google.colab import files\n",
|
|
|
|
| 316 |
" return list(uploaded.keys())\n",
|
| 317 |
"#Get image\n",
|
| 318 |
"# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
|
| 319 |
+
"url = \"http://images.cocodataset.org/val2017/000000039769.jpg\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
|
| 320 |
"\n",
|
| 321 |
"colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\":\"(optional) Write colab image path to load from\"}\n",
|
| 322 |
"from PIL import Image\n",
|
|
|
|
| 354 |
"\n",
|
| 355 |
"# @markdown Set conditions for the output\n",
|
| 356 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 357 |
+
"must_contain = \"banana \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 358 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 359 |
"token_B = must_contain\n",
|
| 360 |
"\n",
|
| 361 |
"# @markdown Limit the search\n",
|
| 362 |
"use_token_padding = True # @param {type:\"boolean\"}\n",
|
| 363 |
+
"start_search_at_ID = 27700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
|
| 364 |
+
"search_range = 288 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
|
| 365 |
+
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
| 366 |
"\n",
|
| 367 |
"# @markdown Limit char size of included token\n",
|
| 368 |
+
"min_char_size = 3 # @param {type:\"slider\", min:0, max: 20, step:1}\n",
|
| 369 |
+
"char_range = 14 # @param {type:\"slider\", min:0, max: 20, step:1}\n",
|
| 370 |
"\n",
|
| 371 |
"#Tokenize input B\n",
|
| 372 |
"from transformers import AutoTokenizer\n",
|
|
|
|
| 382 |
"\n",
|
| 383 |
"dots = torch.zeros(RANGE)\n",
|
| 384 |
"is_BC = torch.zeros(RANGE)\n",
|
| 385 |
+
"\n",
|
| 386 |
+
"import re\n",
|
| 387 |
+
"\n",
|
| 388 |
"for index in range(RANGE):\n",
|
| 389 |
" id_C = START + index\n",
|
| 390 |
" C = token[id_C]\n",
|
| 391 |
" _C = LA.vector_norm(C, ord=2)\n",
|
| 392 |
" name_C = vocab[id_C]\n",
|
| 393 |
"\n",
|
| 394 |
+
" is_Prefix = 0\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"\n",
|
| 397 |
+
" #Skip if non-AZ characters are found\n",
|
| 398 |
+
" if re.search(\"\\W/g\" , name_C.replace('</w>', '')):\n",
|
| 399 |
+
" continue\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"\n",
|
| 402 |
" # Decide if we should process prefix/suffix tokens\n",
|
| 403 |
" if name_C.find('</w>')<=-1:\n",
|
| 404 |
+
" is_Prefix = 1\n",
|
| 405 |
" if restrictions != \"Prefix only\":\n",
|
| 406 |
" continue\n",
|
| 407 |
" else:\n",
|
|
|
|
| 417 |
" #-----#\n",
|
| 418 |
"\n",
|
| 419 |
" name_CB = must_start_with + name_C + name_B + must_end_with\n",
|
| 420 |
+
" if is_Prefix>0:\n",
|
| 421 |
+
" name_CB = must_start_with + ' ' + name_C.strip() + '-' + name_B.strip() + ' ' + must_end_with\n",
|
| 422 |
" #-----#\n",
|
| 423 |
" ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 424 |
" text_encoding_CB = model.get_text_features(**ids_CB)\n",
|
|
|
|
| 466 |
"print('')\n",
|
| 467 |
"print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for {prompt_A} : ')\n",
|
| 468 |
"print('')\n",
|
| 469 |
+
"#----#\n",
|
| 470 |
+
"aheads = \"{\"\n",
|
| 471 |
+
"trails = \"{\"\n",
|
| 472 |
+
"tmp = \"\"\n",
|
| 473 |
+
"#----#\n",
|
| 474 |
+
"max_sim_ahead = 0\n",
|
| 475 |
+
"max_sim_trail = 0\n",
|
| 476 |
+
"sim = 0\n",
|
| 477 |
+
"max_name_ahead = ''\n",
|
| 478 |
+
"max_name_trail = ''\n",
|
| 479 |
+
"#----#\n",
|
| 480 |
"for index in range(min(list_size,RANGE)):\n",
|
| 481 |
" id = START + indices[index].item()\n",
|
| 482 |
+
" name = vocab[id]\n",
|
| 483 |
+
" #-----#\n",
|
| 484 |
+
" if (name.find('</w>')<=-1):\n",
|
| 485 |
+
" name = name + '-'\n",
|
| 486 |
+
" else:\n",
|
| 487 |
+
" name = name.replace('</w>', ' ')\n",
|
| 488 |
+
" if(is_BC[index]>0):\n",
|
| 489 |
+
" trails = trails + name + \"|\"\n",
|
| 490 |
+
" else:\n",
|
| 491 |
+
" aheads = aheads + name + \"|\"\n",
|
| 492 |
+
" #----#\n",
|
| 493 |
+
" sim = sorted[index].item()\n",
|
|
|
|
|
|
|
| 494 |
"\n",
|
| 495 |
+
" if(is_BC[index]>0):\n",
|
| 496 |
+
" if sim>max_sim_ahead:\n",
|
| 497 |
+
" max_sim_ahead = sim\n",
|
| 498 |
+
" max_name_ahead = name\n",
|
| 499 |
+
" else:\n",
|
| 500 |
+
" if sim>max_sim_trail:\n",
|
| 501 |
+
" max_sim_trail = sim\n",
|
| 502 |
+
" max_name_trail = name\n",
|
| 503 |
"\n",
|
| 504 |
+
"#------#\n",
|
| 505 |
+
"trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 506 |
+
"aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
|
| 507 |
+
"max_sim_ahead=max_sim_ahead*100\n",
|
| 508 |
+
"max_sim_ahead=max_sim_trail*100\n",
|
| 509 |
+
"#-----#\n",
|
| 510 |
+
"print(f\"place these items ahead of prompt : {aheads}\")\n",
|
| 511 |
+
"print(\"\")\n",
|
| 512 |
+
"print(f\"place these items behind the prompt : {trails}\")\n",
|
| 513 |
+
"print(\"\")\n",
|
| 514 |
+
"print(f\"max_similarity = {max_sim_ahead} % when using '{max_name_ahead + must_contain}' \")\n",
|
| 515 |
+
"print(\"\")\n",
|
| 516 |
+
"print(f\"max_similarity = {max_sim_trail} % when using '{must_contain + max_name_trail}' \")\n",
|
| 517 |
+
"#-----#\n",
|
| 518 |
+
"#STEP 2\n",
|
| 519 |
+
"import random\n",
|
| 520 |
+
"\n",
|
| 521 |
+
"names = {}\n",
|
| 522 |
+
"\n",
|
| 523 |
+
"NUM_PERMUTATIONS = 4 # 0 1 2 3\n",
|
| 524 |
+
"dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
| 525 |
+
"for index in range(NUM_PERMUTATIONS):\n",
|
| 526 |
+
" name = must_start_with\n",
|
| 527 |
+
" if index == 0 : name = name + must_contain\n",
|
| 528 |
+
" if index == 1 : name = name + max_name_ahead + must_contain\n",
|
| 529 |
+
" if index == 2 : name = name + must_contain + max_name_trail\n",
|
| 530 |
+
" if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
|
| 531 |
+
" name = name + must_end_with\n",
|
| 532 |
+
" #----#\n",
|
| 533 |
+
" ids_B = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 534 |
+
" text_encoding_B = model.get_text_features(**ids_B)\n",
|
| 535 |
+
" B = text_encoding_B[0]\n",
|
| 536 |
+
" _B = LA.vector_norm(B, ord=2)\n",
|
| 537 |
+
" dots[index] = torch.dot(A,B)/(_A*_B)\n",
|
| 538 |
+
" names[index] = name\n",
|
| 539 |
+
"#------#\n",
|
| 540 |
"\n",
|
| 541 |
+
"sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
| 542 |
"\n",
|
| 543 |
+
"for index in range(NUM_PERMUTATIONS):\n",
|
| 544 |
+
" print(names[indices[index].item()])\n",
|
| 545 |
+
" print(f'similiarity = {round(sorted[index].item()*100,2)} %')\n",
|
| 546 |
+
" print('------')\n",
|
| 547 |
+
"\n",
|
| 548 |
+
"\n",
|
| 549 |
+
"\n",
|
| 550 |
+
""
|
| 551 |
],
|
| 552 |
"metadata": {
|
| 553 |
"collapsed": true,
|
|
|
|
| 554 |
"id": "fi0jRruI0-tu"
|
| 555 |
},
|
| 556 |
"execution_count": null,
|
| 557 |
"outputs": []
|
| 558 |
},
|
| 559 |
+
{
|
| 560 |
+
"cell_type": "code",
|
| 561 |
+
"source": [
|
| 562 |
+
"# @title 💫 Compare Text encodings\n",
|
| 563 |
+
"\n",
|
| 564 |
+
"prompt_A = \"banana\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
| 565 |
+
"prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
|
| 566 |
+
"use_token_padding = True # @param {type:\"boolean\"}\n",
|
| 567 |
+
"\n",
|
| 568 |
+
"from transformers import CLIPProcessor, CLIPModel\n",
|
| 569 |
+
"\n",
|
| 570 |
+
"processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\" , clean_up_tokenization_spaces = True)\n",
|
| 571 |
+
"\n",
|
| 572 |
+
"model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
|
| 573 |
+
"\n",
|
| 574 |
+
"ids_A = processor.tokenizer(text=prompt_A, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 575 |
+
"text_encoding_A = model.get_text_features(**ids_A)\n",
|
| 576 |
+
"\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
|
| 579 |
+
"text_encoding_B = model.get_text_features(**ids_B)\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"similarity_str = 'The similarity between the text_encoding for A:\"' + prompt_A + '\" and B: \"' + prompt_B +'\" is ' + token_similarity(text_encoding_A[0] , text_encoding_B[0])\n",
|
| 582 |
+
"\n",
|
| 583 |
+
"\n",
|
| 584 |
+
"print(similarity_str)\n",
|
| 585 |
+
"#outputs = model(**inputs)\n",
|
| 586 |
+
"#logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n",
|
| 587 |
+
"#probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities"
|
| 588 |
+
],
|
| 589 |
+
"metadata": {
|
| 590 |
+
"id": "QQOjh5BvnG8M",
|
| 591 |
+
"collapsed": true,
|
| 592 |
+
"cellView": "form"
|
| 593 |
+
},
|
| 594 |
+
"execution_count": null,
|
| 595 |
+
"outputs": []
|
| 596 |
+
},
|
| 597 |
+
{
|
| 598 |
+
"cell_type": "markdown",
|
| 599 |
+
"source": [
|
| 600 |
+
"You can write an url or upload a file locally from your device to use as reference. The image will by saved in the 'sd_tokens' folder. Note that the 'sd_tokens' folder will be deleted upon exiting this runtime."
|
| 601 |
+
],
|
| 602 |
+
"metadata": {
|
| 603 |
+
"id": "hyK423TQCRup"
|
| 604 |
+
}
|
| 605 |
+
},
|
| 606 |
{
|
| 607 |
"cell_type": "code",
|
| 608 |
"source": [
|