import torch

# Module-level alphabet caches, populated by get_alphabet(): the character list
# and the char -> index lookup table shared by the converter functions below.
alp2num_character = None
alphabet_character = None

def converter(label):
    """Encode a batch of label strings into the tensors used for training.

    Returns per-sample lengths, a START-padded decoder input, a flat target
    sequence whose final position per sample is the END token, three unused
    placeholders, and the original strings.
    """
    string_label = label
    label = [i for i in label]
    alp2num = alp2num_character  # populated by get_alphabet()

    batch = len(label)
    length = torch.Tensor([len(i) for i in label]).long().cuda()
    max_length = max(length)

    # Decoder input: characters are shifted one position to the right, so
    # index 0 keeps the value 0, which is the index of 'START'.
    text_input = torch.zeros(batch, max_length).long().cuda()
    for i in range(batch):
        for j in range(len(label[i]) - 1):
            text_input[i][j + 1] = alp2num[label[i][j]]

    # Flat target: all labels concatenated, with the last position of each
    # label written as the END token.
    sum_length = sum(length)
    text_all = torch.zeros(sum_length).long().cuda()
    start = 0
    for i in range(batch):
        for j in range(len(label[i])):
            if j == (len(label[i]) - 1):
                text_all[start + j] = alp2num['END']
            else:
                text_all[start + j] = alp2num[label[i][j]]
        start += len(label[i])

    return length, text_input, text_all, None, None, None, string_label





def converter_ocr_bbox(label, loc_str_tuple):
    """Same encoding as converter(), plus per-character location targets.

    loc_str_tuple holds one underscore-separated string of floats per sample;
    each float is written into loc_gt at the corresponding character position.
    """
    string_label = label
    label = [i for i in label]
    alp2num = alp2num_character
    loc_str = [i for i in loc_str_tuple]

    batch = len(label)
    length = torch.Tensor([len(i) for i in label]).long().cuda()
    max_length = max(length)

    text_input = torch.zeros(batch, max_length).long().cuda()
    loc_gt = torch.zeros(batch, max_length).cuda()
    for i in range(batch):
        # Parse the location values for this sample.
        loc_tmps = [float(s) for s in loc_str[i].split('_')]
        for j in range(len(label[i]) - 1):
            text_input[i][j + 1] = alp2num[label[i][j]]
            loc_gt[i][j + 1] = loc_tmps[j]

    sum_length = sum(length)
    text_all = torch.zeros(sum_length).long().cuda()
    start = 0
    for i in range(batch):
        for j in range(len(label[i])):
            if j == (len(label[i]) - 1):
                text_all[start + j] = alp2num['END']
            else:
                text_all[start + j] = alp2num[label[i][j]]
        start += len(label[i])

    return length, text_input, text_all, None, None, None, string_label, loc_gt

def converter_ocr(label):
    """Same encoding as converter(); see that function for per-step comments."""
    string_label = label
    label = [i for i in label]
    alp2num = alp2num_character

    batch = len(label)
    length = torch.Tensor([len(i) for i in label]).long().cuda()
    max_length = max(length)

    text_input = torch.zeros(batch, max_length).long().cuda()
    for i in range(batch):
        for j in range(len(label[i]) - 1):
            text_input[i][j + 1] = alp2num[label[i][j]]

    sum_length = sum(length)
    text_all = torch.zeros(sum_length).long().cuda()
    start = 0
    for i in range(batch):
        for j in range(len(label[i])):
            if j == (len(label[i]) - 1):
                text_all[start + j] = alp2num['END']
            else:
                text_all[start + j] = alp2num[label[i][j]]
        start += len(label[i])

    return length, text_input, text_all, None, None, None, string_label


def get_alphabet(alpha_path):
    """Load the alphabet file once and build the global char -> index table."""
    global alp2num_character, alphabet_character
    if alp2num_character is None:
        with open(alpha_path) as alphabet_character_file:
            alphabet_character = list(alphabet_character_file.read().strip())
        # Final layout: ['START', '\xad', <characters from the file>, 'END'],
        # so 'START' is index 0 and 'END' is the last index.
        alphabet_character_raw = ['START', '\xad']
        for item in alphabet_character:
            alphabet_character_raw.append(item)
        alphabet_character_raw.append('END')
        alphabet_character = alphabet_character_raw
        alp2num = {}
        for index, char in enumerate(alphabet_character):
            alp2num[char] = index
        alp2num_character = alp2num

    return alphabet_character

def tensor2str(tensor, alpha_path):
    """Decode a 1-D tensor of alphabet indices back to a string, skipping END."""
    alphabet = get_alphabet(alpha_path)
    string = ""
    for i in tensor:
        if i == (len(alphabet) - 1):  # last index is the 'END' token
            continue
        string += alphabet[i]
    return string
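

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes a CUDA
    # device is available (the converters call .cuda()) and that 'alphabet.txt'
    # is a hypothetical file listing the recognizable characters.
    alphabet = get_alphabet("alphabet.txt")
    print("alphabet size:", len(alphabet))

    # Labels are plain strings; the final position of each label is rewritten
    # as the END token in the flat target tensor.
    length, text_input, text_all, _, _, _, strings = converter(["hello", "ocr"])
    print(length, text_input.shape, text_all.shape)

    # Map the first sample's target indices back to a string (END is skipped).
    print(tensor2str(text_all[:int(length[0])], "alphabet.txt"))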