Upload sd_token_similarity_calculator.ipynb
sd_token_similarity_calculator.ipynb
CHANGED
@@ -118,10 +118,29 @@
 ],
 "metadata": {
 "id": "Ch9puvwKH1s3",
-"collapsed": true
+"collapsed": true,
+"outputId": "033c251a-2043-40e7-9500-4da870ffa7fd",
+"colab": {
+"base_uri": "https://localhost:8080/"
+}
 },
-"execution_count":
-"outputs": [
+"execution_count": 1,
+"outputs": [
+{
+"output_type": "stream",
+"name": "stdout",
+"text": [
+"Cloning into 'sd_tokens'...\n",
+"remote: Enumerating objects: 20, done.\u001b[K\n",
+"remote: Counting objects: 100% (17/17), done.\u001b[K\n",
+"remote: Compressing objects: 100% (17/17), done.\u001b[K\n",
+"remote: Total 20 (delta 4), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
+"Unpacking objects: 100% (20/20), 310.37 KiB | 2.10 MiB/s, done.\n",
+"Filtering content: 100% (3/3), 160.82 MiB | 26.64 MiB/s, done.\n",
+"/content/sd_tokens\n"
+]
+}
+]
 },
 {
 "cell_type": "code",
@@ -132,7 +151,7 @@
 "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n",
 "\n",
 "# @markdown Write name of token to match against\n",
-"token_name = \"
+"token_name = \" blanket \" # @param {type:'string',\"placeholder\":\"leave empty for random value token\"}\n",
 "\n",
 "prompt = token_name\n",
 "# @markdown (optional) Mix the token with something else\n",
@@ -368,7 +387,10 @@
 "start_search_at_index = 0 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
 "# @markdown The lower the start_index, the more similiar the sampled tokens will be to the target token assigned in the '⚡ Get similiar tokens' cell\". If the cell was not run, then it will use tokens ordered by similarity to the \"girl\\</w>\" token\n",
 "start_search_at_ID = start_search_at_index\n",
-"search_range =
+"search_range = 1000 # @param {type:\"slider\", min:10, max: 1000, step:10}\n",
+"\n",
+"samples_per_iter = 10 # @param {type:\"slider\", min:10, max: 100, step:10}\n",
+"\n",
 "iterations = 5 # @param {type:\"slider\", min:1, max: 20, step:0}\n",
 "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
 "#markdown Limit char size of included token <----- Disabled\n",
@@ -384,15 +406,11 @@
 "RANGE = min(search_range , max(1,NUM_TOKENS - start_search_at_ID))\n",
 "#-----#\n",
 "import math, random\n",
-"
-"\n",
+"NUM_PERMUTATIONS = 4\n",
 "ITERS = iterations\n",
 "#-----#\n",
 "#LOOP START\n",
 "#-----#\n",
-"\n",
-"\n",
-"\n",
 "# Check if original solution is best\n",
 "best_sim = 0\n",
 "name = must_start_with + must_contain + must_end_with\n",
@@ -400,6 +418,7 @@
 "text_features = model.get_text_features(**ids)\n",
 "text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
 "#------#\n",
+"sim = 0\n",
 "if(use == '🖼️image_encoding from image'):\n",
 " logit_scale = model.logit_scale.exp()\n",
 " torch.matmul(text_features, image_features.t()) * logit_scale\n",
@@ -411,7 +430,8 @@
 "best_sim = sim\n",
 "best_name = name\n",
 "name_B = must_contain\n",
-"
+"#------#\n",
+"results_sim = torch.zeros(ITERS*NUM_PERMUTATIONS)\n",
 "results_name_B = {}\n",
 "results_name = {}\n",
 "#-----#\n",
@@ -420,17 +440,10 @@
 " is_trail = torch.zeros(RANGE)\n",
 " import re\n",
 " #-----#\n",
+" _start = START + iter*RANGE\n",
 "\n",
-"
-"
-" results_sim[iter] = best_sim\n",
-" results_name_B[iter] = name_B\n",
-" #-----#\n",
-" sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
-" name_B = results_name_B[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
-"\n",
-" for index in range(RANGE):\n",
-" id_C = min(_start + index, NUM_TOKENS)\n",
+" for index in range(samples_per_iter):\n",
+" id_C = min(_start + index, NUM_TOKENS) + random.randint(0,RANGE)\n",
 " name_C = db_vocab[f'{id_C}']\n",
 " is_Prefix = 0\n",
 " #Skip if non-AZ characters are found\n",
@@ -573,17 +586,15 @@
 " #-----#\n",
 " #STEP 2\n",
 " import random\n",
-" names = {}\n",
-" name_inners = {}\n",
-" NUM_PERMUTATIONS = 4\n",
 " #-----#\n",
-" dots = torch.zeros(NUM_PERMUTATIONS)\n",
 " for index in range(NUM_PERMUTATIONS):\n",
 " name_inner = ''\n",
 " if index == 0 : name_inner = name_B\n",
-" if index == 1
-" if index == 2
-" if index == 3
+" if index == 1: name_inner = max_name_ahead\n",
+" if index == 2: name_inner = name_B + max_name_trail\n",
+" if index == 3: name_inner = max_name_ahead + name_B + max_name_trail\n",
+" if name_inner == '': name_inner = max_name_ahead + name_B + max_name_trail\n",
+"\n",
 " name = must_start_with + name_inner + must_end_with\n",
 " #----#\n",
 " ids = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
@@ -601,25 +612,17 @@
 " text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)\n",
 " sim = torch.nn.functional.cosine_similarity(text_features, text_features_A)\n",
 " #-----#\n",
-"
-"
-"
+" results_name[iter*NUM_PERMUTATIONS + index] = name\n",
+" results_sim[iter*NUM_PERMUTATIONS + index] = sim\n",
+" results_name_B[iter*NUM_PERMUTATIONS + index] = name_inner.replace('</w>',' ')\n",
 " #------#\n",
-"
-" #------#\n",
-" best_sim = dots[indices[0].item()]\n",
-" best_name = names[indices[0].item()]\n",
-" name_B = name_inners[indices[0].item()].replace('</w>', ' ') #Update name_B with best value\n",
+" name_B = results_name_B[iter*NUM_PERMUTATIONS + random.randint(0,3)]\n",
 "#--------#\n",
-"#store the final value\n",
-"results_name[iter+1] = best_name\n",
-"results_sim[iter+1] = best_sim\n",
-"results_name_B[iter+1] = name_B\n",
 "\n",
+"print('')\n",
 "sorted, indices = torch.sort(results_sim,dim=0 , descending=True)\n",
 "\n",
-"
-"for index in range(ITERS+1):\n",
+"for index in range(ITERS*NUM_PERMUTATIONS):\n",
 " name_inner = results_name[indices[index].item()]\n",
 " print(must_start_with + name_inner + must_end_with)\n",
 " print(f'similiarity = {round(sorted[index].item(),2)} %')\n",