# -*- coding: utf-8 -*-
"""Command-line CAPTCHA recognizer built on an encoder-decoder model.

Loads a ``.safetensors`` checkpoint into either ``CaformerEncoderDecoder`` or
``ResnetEncoderDecoder`` and decodes the text shown in one or more images.
"""
import torch
import argparse
from models import ResnetEncoderDecoder, CaformerEncoderDecoder
from utils import remove_rptch
from safetensors import safe_open
from torchvision import transforms as T
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Index 0 ('_') is skipped when decoding (see ``Predictor.pred``).
char_dict = '_0123456789abcdefghijklmnopqrstuvwxyz'
char_dict_pp = '_0123456789abcdefghijklmnopqrstuvwxyz()+-*/='


class Predictor:
    """Wrap a trained encoder-decoder model for single-image inference."""

    def __init__(self, model_path: str, ckpt_name: str, char_dict: str = char_dict_pp):
        """Build the model, load its weights and prepare the preprocessing.

        Args:
            model_path: Path to the ``.safetensors`` checkpoint file.
            ckpt_name: Name used to pick the architecture; if it contains
                ``'caformer'`` a ``CaformerEncoderDecoder`` is built,
                otherwise a ``ResnetEncoderDecoder``.
            char_dict: Alphabet string; class index ``i`` maps to
                ``char_dict[i]``.
        """
        if 'caformer' in ckpt_name:
            self.model = CaformerEncoderDecoder(char_dict).to(device)
        else:
            self.model = ResnetEncoderDecoder(char_dict).to(device)
        self.model.eval()

        # Tensors are always read onto the CPU (the loader's default);
        # ``load_state_dict`` then copies them into the model's parameters
        # on whatever device the model lives on.
        check_point = self.load_safetensor(model_path)
        self.model.load_state_dict(check_point)

        self.char_dict = char_dict
        self.trans = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    # >>>>> from RainbowNeko Engine >>>>>
    @staticmethod
    def fold_dict(safe_f, split_key=':'):
        """Read every tensor from ``safe_f`` into a (possibly nested) dict.

        Each key is split on ``split_key``; all parts but the last become
        nested dictionary levels and the last part maps to the tensor. Keys
        that do not contain ``split_key`` stay at the top level.

        Args:
            safe_f: Object exposing ``keys()`` and ``get_tensor(key)``.
            split_key: Separator that marks nesting levels inside a key.

        Returns:
            The folded dictionary.
        """
        dict_fold = {}
        for key in safe_f.keys():
            *parents, leaf = key.split(split_key)
            node = dict_fold
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = safe_f.get_tensor(key)
        return dict_fold

    def load_safetensor(self, ckpt_f, map_location='cpu'):
        """Open the safetensors file ``ckpt_f`` and return its folded contents.

        Args:
            ckpt_f: Path to the safetensors file.
            map_location: Device passed to ``safe_open`` for the tensors.

        Returns:
            The dictionary produced by ``fold_dict``.
        """
        with safe_open(ckpt_f, framework="pt", device=map_location) as f:
            sd_fold = self.fold_dict(f)
        return sd_fold
    # <<<<< from RainbowNeko Engine <<<<<

    def pred(self, input):
        """Run the model on a batch and decode its first element.

        Args:
            input: Batched image tensor; it is moved to ``device`` before the
                forward pass.

        Returns:
            A tuple ``(pred_cls, final_str, (H, W))`` where ``pred_cls`` is a
            1-D NumPy array of ``H * W`` class indices for batch element 0,
            ``final_str`` is the decoded text and ``H``/``W`` are the second
            and third dimensions of the model output.
        """
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            pred = self.model(input.to(device))
            B, H, W, C = pred.size()
            T_ = H * W
            pred = pred.view(B, T_, -1)
            pred = pred + 1e-10
            # Arg-max over the class axis; only batch element 0 is decoded.
            pred_cls = torch.max(pred, 2)[1].cpu().numpy()[0]

        # Reorder the flattened (H, W) grid so that it is read column by column.
        pred_cls = pred_cls.reshape((H, W)).T.reshape((H * W,))
        # Index 0 is dropped before the string is passed to ``remove_rptch``
        # (imported helper; presumably collapses repeated characters - confirm).
        final_str = remove_rptch(''.join(self.char_dict[x] for x in pred_cls if x))
        return pred_cls, final_str, (H, W)

    def pred_img(self, image, show: bool = True) -> str:
        """Recognize the text in a single image.

        Args:
            image: Either a path string (opened with PIL and converted to
                RGB) or an already loaded image accepted by ``self.trans``.
            show: When true, print the raw per-position prediction and the
                final result.

        Returns:
            The decoded string.
        """
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        image = self.trans(image)
        pred_cls, final_str, (H, W) = self.pred(image.unsqueeze(0))

        if show:
            # Every symbol is padded to two characters, so W symbols take
            # W * 2 characters per printed row.
            # NOTE(review): ``pred_cls`` was flattened after a transpose, so
            # these rows seem to match the model's grid only when H == 1 -
            # confirm against the model's output shape.
            pred_string = ''.join(['%2s' % self.char_dict[pn] for pn in pred_cls])
            pred_string_set = [pred_string[i:i + W * 2] for i in range(0, len(pred_string), W * 2)]
            print('Prediction: ')
            for pre_str in pred_string_set:
                print(pre_str)
            print('Result:', final_str)

        return final_str


def main():
    """Parse the command line and print the recognized text of each image."""
    parser = argparse.ArgumentParser(description='CAPTCHA Recognizer')
    parser.add_argument('--model_path', type=str, default='exps/captcha/ckpts/model-2000.safetensors',
                        help='Path to the model file')
    parser.add_argument('--ckpt_name', type=str, default=None,
                        help="Name used to select the architecture ('caformer' in the name selects "
                             "the Caformer model); defaults to the model path")
    parser.add_argument('--image_path', type=str, default=[
        '/data1/dzy/CAPTCHA_recognize/data3/test/2.jpg',
        '/data1/dzy/Verification_Code_CV_v1.1/imgs/00097.png',
        '/data1/dzy/Verification_Code_CV_v1.1/imgs/00098.png',
        '/data1/dzy/Verification_Code_CV_v1.1/imgs/00099.png',
    ], nargs='+', help='Path to the image file')
    args = parser.parse_args()

    # ``Predictor`` requires a checkpoint name; fall back to the model path so
    # the architecture can still be inferred when none is given.
    ckpt_name = args.ckpt_name if args.ckpt_name is not None else args.model_path
    predictor = Predictor(args.model_path, ckpt_name)
    for path in args.image_path:
        result = predictor.pred_img(path)
        print(f'Recognized CAPTCHA: {result}')


if __name__ == "__main__":
    main()