Upload sd_token_similarity_calculator.ipynb
Browse files
sd_token_similarity_calculator.ipynb
CHANGED
|
@@ -132,7 +132,7 @@
|
|
| 132 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
| 133 |
"\n",
|
| 134 |
"# @markdown Write name of token to match against\n",
|
| 135 |
-
"token_name = \"
|
| 136 |
"\n",
|
| 137 |
"prompt = token_name\n",
|
| 138 |
"# @markdown (optional) Mix the token with something else\n",
|
|
@@ -361,14 +361,15 @@
|
|
| 361 |
"#-----#\n",
|
| 362 |
"# @markdown # The output...\n",
|
| 363 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 364 |
-
"must_contain = \"
|
| 365 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 366 |
"# @markdown -----\n",
|
| 367 |
"# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
|
| 368 |
-
"start_search_at_index =
|
| 369 |
"# @markdown The lower the start_index, the more similar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell. If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
|
| 370 |
"start_search_at_ID = start_search_at_index\n",
|
| 371 |
-
"search_range = 100 # @param {type:\"slider\", min:
|
|
|
|
| 372 |
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
| 373 |
"#markdown Limit char size of included token <----- Disabled\n",
|
| 374 |
"min_char_size = 0 #param {type:\"slider\", min:0, max: 20, step:1}\n",
|
|
@@ -383,15 +384,14 @@
|
|
| 383 |
"RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
|
| 384 |
"#-----#\n",
|
| 385 |
"import math, random\n",
|
| 386 |
-
"CHUNK = math.floor(NUM_TOKENS/
|
| 387 |
"\n",
|
| 388 |
-
"ITERS =
|
| 389 |
"#-----#\n",
|
| 390 |
"#LOOP START\n",
|
| 391 |
"#-----#\n",
|
| 392 |
"\n",
|
| 393 |
-
"
|
| 394 |
-
"results_name = {}\n",
|
| 395 |
"\n",
|
| 396 |
"# Check if original solution is best\n",
|
| 397 |
"best_sim = 0\n",
|
|
@@ -409,7 +409,11 @@
|
|
| 409 |
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 410 |
"#-----#\n",
|
| 411 |
"best_sim = sim\n",
|
|
|
|
| 412 |
"name_B = must_contain\n",
|
|
|
|
|
|
|
|
|
|
| 413 |
"#-----#\n",
|
| 414 |
"for iter in range(ITERS):\n",
|
| 415 |
" dots = torch.zeros(RANGE)\n",
|
|
@@ -418,8 +422,12 @@
|
|
| 418 |
" #-----#\n",
|
| 419 |
"\n",
|
| 420 |
" _start = START + iter*CHUNK + iter*random.randint(1,CHUNK)\n",
|
| 421 |
-
" results_name[iter] =
|
| 422 |
" results_sim[iter] = best_sim\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
"\n",
|
| 424 |
" for index in range(RANGE):\n",
|
| 425 |
" id_C = min(_start + index, NUM_TOKENS)\n",
|
|
@@ -510,7 +518,7 @@
|
|
| 510 |
" used_reference = f'the text_encoding for {prompt_A}'\n",
|
| 511 |
" if(use == '🖼️image_encoding from image'):\n",
|
| 512 |
" used_reference = 'the image input'\n",
|
| 513 |
-
" print(f'These token pairings within the range ID = {
|
| 514 |
" print('')\n",
|
| 515 |
" #----#\n",
|
| 516 |
" aheads = \"{\"\n",
|
|
@@ -556,16 +564,17 @@
|
|
| 556 |
" print(\"\")\n",
|
| 557 |
"\n",
|
| 558 |
" tmp = must_start_with + ' ' + max_name_ahead + name_B + ' ' + must_end_with\n",
|
| 559 |
-
" tmp = tmp.strip()\n",
|
| 560 |
" print(f\"max_similarity_ahead = {round(max_sim_ahead,2)} % when using '{tmp}' \")\n",
|
| 561 |
" print(\"\")\n",
|
| 562 |
" tmp = must_start_with + ' ' + name_B + max_name_trail + ' ' + must_end_with\n",
|
| 563 |
-
" tmp = tmp.strip()\n",
|
| 564 |
" print(f\"max_similarity_trail = {round(max_sim_trail,2)} % when using '{tmp}' \")\n",
|
| 565 |
" #-----#\n",
|
| 566 |
" #STEP 2\n",
|
| 567 |
" import random\n",
|
| 568 |
" names = {}\n",
|
|
|
|
| 569 |
" NUM_PERMUTATIONS = 4\n",
|
| 570 |
" #-----#\n",
|
| 571 |
" dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
|
@@ -593,16 +602,19 @@
|
|
| 593 |
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 594 |
" #-----#\n",
|
| 595 |
" dots[index] = sim\n",
|
| 596 |
-
" names[index] =
|
|
|
|
| 597 |
" #------#\n",
|
| 598 |
" sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
| 599 |
" #------#\n",
|
| 600 |
" best_sim = dots[indices[0].item()]\n",
|
| 601 |
-
"
|
|
|
|
| 602 |
"#--------#\n",
|
| 603 |
"#store the final value\n",
|
| 604 |
-
"results_name[iter] =
|
| 605 |
-
"results_sim[iter] = best_sim\n",
|
|
|
|
| 606 |
"\n",
|
| 607 |
"sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
|
| 608 |
"\n",
|
|
|
|
| 132 |
"tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
|
| 133 |
"\n",
|
| 134 |
"# @markdown Write name of token to match against\n",
|
| 135 |
+
"token_name = \"banana \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
|
| 136 |
"\n",
|
| 137 |
"prompt = token_name\n",
|
| 138 |
"# @markdown (optional) Mix the token with something else\n",
|
|
|
|
| 361 |
"#-----#\n",
|
| 362 |
"# @markdown # The output...\n",
|
| 363 |
"must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 364 |
+
"must_contain = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 365 |
"must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
|
| 366 |
"# @markdown -----\n",
|
| 367 |
"# @markdown # Use a range of tokens from the vocab.json (slow method)\n",
|
| 368 |
+
"start_search_at_index = 0 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
|
| 369 |
"# @markdown The lower the start_index, the more similar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell. If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
|
| 370 |
"start_search_at_ID = start_search_at_index\n",
|
| 371 |
+
"search_range = 100 # @param {type:\"slider\", min:10, max: 200, step:0}\n",
|
| 372 |
+
"iterations = 5 # @param {type:\"slider\", min:1, max: 20, step:0}\n",
|
| 373 |
"restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
|
| 374 |
"#markdown Limit char size of included token <----- Disabled\n",
|
| 375 |
"min_char_size = 0 #param {type:\"slider\", min:0, max: 20, step:1}\n",
|
|
|
|
| 384 |
"RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
|
| 385 |
"#-----#\n",
|
| 386 |
"import math, random\n",
|
| 387 |
+
"CHUNK = math.floor(NUM_TOKENS/RANGE)\n",
|
| 388 |
"\n",
|
| 389 |
+
"ITERS = iterations\n",
|
| 390 |
"#-----#\n",
|
| 391 |
"#LOOP START\n",
|
| 392 |
"#-----#\n",
|
| 393 |
"\n",
|
| 394 |
+
"\n",
|
|
|
|
| 395 |
"\n",
|
| 396 |
"# Check if original solution is best\n",
|
| 397 |
"best_sim = 0\n",
|
|
|
|
| 409 |
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 410 |
"#-----#\n",
|
| 411 |
"best_sim = sim\n",
|
| 412 |
+
"best_name = name\n",
|
| 413 |
"name_B = must_contain\n",
|
| 414 |
+
"results_sim = torch.zeros(ITERS+1)\n",
|
| 415 |
+
"results_name_B = {}\n",
|
| 416 |
+
"results_name = {}\n",
|
| 417 |
"#-----#\n",
|
| 418 |
"for iter in range(ITERS):\n",
|
| 419 |
" dots = torch.zeros(RANGE)\n",
|
|
|
|
| 422 |
" #-----#\n",
|
| 423 |
"\n",
|
| 424 |
" _start = START + iter*CHUNK + iter*random.randint(1,CHUNK)\n",
|
| 425 |
+
" results_name[iter] = best_name\n",
|
| 426 |
" results_sim[iter] = best_sim\n",
|
| 427 |
+
" results_name_B[iter] = name_B\n",
|
| 428 |
+
" #-----#\n",
|
| 429 |
+
" sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
|
| 430 |
+
" name_B = results_name_B[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
|
| 431 |
"\n",
|
| 432 |
" for index in range(RANGE):\n",
|
| 433 |
" id_C = min(_start + index, NUM_TOKENS)\n",
|
|
|
|
| 518 |
" used_reference = f'the text_encoding for {prompt_A}'\n",
|
| 519 |
" if(use == '🖼️image_encoding from image'):\n",
|
| 520 |
" used_reference = 'the image input'\n",
|
| 521 |
+
" print(f'These token pairings within the range ID = {_start} to ID = {_start + RANGE} most closely match {used_reference}: ')\n",
|
| 522 |
" print('')\n",
|
| 523 |
" #----#\n",
|
| 524 |
" aheads = \"{\"\n",
|
|
|
|
| 564 |
" print(\"\")\n",
|
| 565 |
"\n",
|
| 566 |
" tmp = must_start_with + ' ' + max_name_ahead + name_B + ' ' + must_end_with\n",
|
| 567 |
+
" tmp = tmp.strip().replace('</w>', ' ')\n",
|
| 568 |
" print(f\"max_similarity_ahead = {round(max_sim_ahead,2)} % when using '{tmp}' \")\n",
|
| 569 |
" print(\"\")\n",
|
| 570 |
" tmp = must_start_with + ' ' + name_B + max_name_trail + ' ' + must_end_with\n",
|
| 571 |
+
" tmp = tmp.strip().replace('</w>', ' ')\n",
|
| 572 |
" print(f\"max_similarity_trail = {round(max_sim_trail,2)} % when using '{tmp}' \")\n",
|
| 573 |
" #-----#\n",
|
| 574 |
" #STEP 2\n",
|
| 575 |
" import random\n",
|
| 576 |
" names = {}\n",
|
| 577 |
+
" name_inners = {}\n",
|
| 578 |
" NUM_PERMUTATIONS = 4\n",
|
| 579 |
" #-----#\n",
|
| 580 |
" dots = torch.zeros(NUM_PERMUTATIONS)\n",
|
|
|
|
| 602 |
" sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
|
| 603 |
" #-----#\n",
|
| 604 |
" dots[index] = sim\n",
|
| 605 |
+
" names[index] = name\n",
|
| 606 |
+
" name_inners[index] = name_inner\n",
|
| 607 |
" #------#\n",
|
| 608 |
" sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
|
| 609 |
" #------#\n",
|
| 610 |
" best_sim = dots[indices[0].item()]\n",
|
| 611 |
+
" best_name = names[indices[0].item()]\n",
|
| 612 |
+
" name_B = name_inners[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
|
| 613 |
"#--------#\n",
|
| 614 |
"#store the final value\n",
|
| 615 |
+
"results_name[iter+1] = best_name\n",
|
| 616 |
+
"results_sim[iter+1] = best_sim\n",
|
| 617 |
+
"results_name_B[iter+1] = name_B\n",
|
| 618 |
"\n",
|
| 619 |
"sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
|
| 620 |
"\n",
|