Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| import torch | |
| import argparse | |
| from models import ResnetEncoderDecoder, CaformerEncoderDecoder | |
| from utils import remove_rptch | |
| from safetensors import safe_open | |
| from torchvision import transforms as T | |
| from PIL import Image | |
# Inference runs on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Alphabets indexed by predicted class id. Index 0 ('_') is skipped when decoding.
char_dict = '_0123456789abcdefghijklmnopqrstuvwxyz'
char_dict_pp = '_0123456789abcdefghijklmnopqrstuvwxyz()+-*/='


class Predictor:
    """Load a CAPTCHA recognition checkpoint and decode images into strings.

    The backbone is chosen from ``ckpt_name``: a name containing
    ``'caformer'`` selects ``CaformerEncoderDecoder``, anything else selects
    ``ResnetEncoderDecoder``.
    """

    def __init__(self, model_path, ckpt_name=None, char_dict=char_dict_pp):
        """Build the model, load its weights and prepare the preprocessing.

        Args:
            model_path: Path to a ``.safetensors`` checkpoint.
            ckpt_name: Name used only to pick the backbone. When omitted,
                ``model_path`` itself is inspected instead.
            char_dict: Alphabet string; position ``i`` is the character for
                predicted class ``i``.
        """
        if ckpt_name is None:
            ckpt_name = str(model_path)

        if 'caformer' in ckpt_name:
            self.model = CaformerEncoderDecoder(char_dict).to(device)
        else:
            self.model = ResnetEncoderDecoder(char_dict).to(device)
        self.model.eval()

        # Tensors are always read onto the CPU; load_state_dict then copies
        # them into the parameters that already live on `device`.
        # NOTE(review): the loaded dict is nested on ':' in its keys —
        # confirm the checkpoint keys match what load_state_dict expects.
        check_point = self.load_safetensor(model_path, map_location='cpu')
        self.model.load_state_dict(check_point)

        self.char_dict = char_dict
        # Normalization constants are the standard ImageNet channel statistics.
        self.trans = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # >>>>> from RainbowNeko Engine >>>>>
    @staticmethod
    def fold_dict(safe_f, split_key=':'):
        """Turn a flat tensor store into a nested dict.

        Each key of ``safe_f`` is split on ``split_key``; every part except
        the last becomes a nested dict level and the last part maps to the
        tensor. A key without ``split_key`` stays at the top level.

        Args:
            safe_f: Object exposing ``keys()`` and ``get_tensor(key)``.
            split_key: Separator marking nesting levels inside a key.

        Returns:
            The nested dictionary of tensors.
        """
        folded = {}
        for key in safe_f.keys():
            *parents, leaf = key.split(split_key)
            node = folded
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = safe_f.get_tensor(key)
        return folded

    def load_safetensor(self, ckpt_f, map_location='cpu'):
        """Read a safetensors file and return its tensors as a folded dict.

        Args:
            ckpt_f: Path to the safetensors file.
            map_location: Device passed to ``safe_open`` for the tensors.
        """
        with safe_open(ckpt_f, framework="pt", device=map_location) as f:
            return self.fold_dict(f)
    # <<<<< from RainbowNeko Engine <<<<<

    def pred(self, input):
        """Run the model on a batch and decode the first sample.

        Args:
            input: Batched image tensor; it is moved to ``device`` here.

        Returns:
            A tuple ``(pred_cls, final_str, (H, W))`` where ``pred_cls`` is a
            1-D numpy array of class ids for the first sample, read
            column-by-column from the ``H x W`` grid, ``final_str`` is the
            decoded text and ``(H, W)`` is the grid size of the model output.
        """
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            out = self.model(input.to(device))

        # The model output is unpacked as (B, H, W, C).
        B, H, W, _ = out.size()
        out = out.view(B, H * W, -1)
        out = out + 1e-10  # small constant offset kept from the original code

        # Best class per grid cell; only the first sample of the batch is used.
        pred_cls = torch.max(out, 2)[1].cpu().numpy()[0]
        # Transpose so the flat sequence walks the grid column by column.
        pred_cls = pred_cls.reshape((H, W)).T.reshape((H * W,))

        # Class 0 is dropped before the repeat-removal helper runs
        # (presumably collapsing repeated characters — see utils.remove_rptch).
        raw = ''.join(self.char_dict[x] for x in pred_cls if x)
        final_str = remove_rptch(raw)
        return pred_cls, final_str, (H, W)

    def pred_img(self, image, show=True):
        """Recognize the text in a single image.

        Args:
            image: Either a file path (``str``) or an already loaded image
                accepted by the preprocessing transform.
            show: When true, print the per-cell prediction grid and the
                final decoded string.

        Returns:
            The decoded string.
        """
        if isinstance(image, str):
            # Close the file handle as soon as the pixels are converted.
            with Image.open(image) as img_file:
                image = img_file.convert('RGB')
        image = self.trans(image)

        pred_cls, final_str, (H, W) = self.pred(image.unsqueeze(0))

        if show:
            # Each cell is rendered right-aligned in two characters.
            pred_string = ''.join(f'{self.char_dict[pn]:>2}' for pn in pred_cls)
            # NOTE(review): pred_cls is laid out as W runs of H cells, but the
            # lines below are cut every W cells; identical when H == W —
            # confirm the intended layout for non-square grids.
            row_width = W * 2
            print('Prediction: ')
            for start in range(0, len(pred_string), row_width):
                print(pred_string[start:start + row_width])
            print('Result:', final_str)
        return final_str
def build_arg_parser():
    """Build the command-line parser for the CAPTCHA recognizer script.

    Returns:
        An ``argparse.ArgumentParser`` with ``--model_path``, ``--ckpt_name``
        and ``--image_path`` (one or more paths) options.
    """
    parser = argparse.ArgumentParser(description='CAPTCHA Recognizer')
    parser.add_argument('--model_path', type=str, default='exps/captcha/ckpts/model-2000.safetensors', help='Path to the model file')
    parser.add_argument('--ckpt_name', type=str, default=None,
                        help='Name used to select the backbone (defaults to the model path)')
    parser.add_argument('--image_path', type=str, default=[
        '/data1/dzy/CAPTCHA_recognize/data3/test/2.jpg',
        '/data1/dzy/Verification_Code_CV_v1.1/imgs/00097.png',
        '/data1/dzy/Verification_Code_CV_v1.1/imgs/00098.png',
        '/data1/dzy/Verification_Code_CV_v1.1/imgs/00099.png',
    ], nargs='+', help='Path to the image file')
    return parser


def main(argv=None):
    """Recognize every image given on the command line and print the results.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv``.
    """
    args = build_arg_parser().parse_args(argv)

    # Predictor takes the checkpoint name as its second argument; when the
    # user gives none, the model path is passed in its place.
    ckpt_name = args.model_path if args.ckpt_name is None else args.ckpt_name
    predictor = Predictor(args.model_path, ckpt_name)

    for path in args.image_path:
        result = predictor.pred_img(path)
        print(f'Recognized CAPTCHA: {result}')


if __name__ == "__main__":
    main()