File size: 4,433 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from .iresnet import iresnet100
from .iresnet import iresnet18
from .iresnet import iresnet200
from .iresnet import iresnet34
from .iresnet import iresnet50
from .mobilefacenet import get_mbf


# Hyper-parameters shared by every "vit_*" variant built in get_model.
_VIT_COMMON = dict(img_size=112, patch_size=9, num_heads=8, norm_layer="ln")

# Per-variant ViT hyper-parameters, merged with _VIT_COMMON at build time.
# `using_checkpoint` is only passed for the variants that set it, so the
# remaining variants keep VisionTransformer's own default for it.
_VIT_CONFIGS = {
    "vit_t": dict(embed_dim=256, depth=12, drop_path_rate=0.1, mask_ratio=0.1),
    # For WebFace42M
    "vit_t_dp005_mask0": dict(
        embed_dim=256, depth=12, drop_path_rate=0.05, mask_ratio=0.0
    ),
    "vit_s": dict(embed_dim=512, depth=12, drop_path_rate=0.1, mask_ratio=0.1),
    # For WebFace42M
    "vit_s_dp005_mask_0": dict(
        embed_dim=512, depth=12, drop_path_rate=0.05, mask_ratio=0.0
    ),
    "vit_b": dict(
        embed_dim=512,
        depth=24,
        drop_path_rate=0.1,
        mask_ratio=0.1,
        using_checkpoint=True,
    ),
    # For WebFace42M
    "vit_b_dp005_mask_005": dict(
        embed_dim=512,
        depth=24,
        drop_path_rate=0.05,
        mask_ratio=0.05,
        using_checkpoint=True,
    ),
    # For WebFace42M
    "vit_l_dp005_mask_005": dict(
        embed_dim=768,
        depth=24,
        drop_path_rate=0.05,
        mask_ratio=0.05,
        using_checkpoint=True,
    ),
}

# Names handled by explicit branches in get_model (used for the error message).
_NON_VIT_MODEL_NAMES = ("r18", "r34", "r50", "r100", "r200", "r2060", "mbf", "mbf_large")


def get_model(name, **kwargs):
    """Build a backbone network selected by its short name.

    Args:
        name: Model identifier. Supported values are the iresnet depths
            "r18", "r34", "r50", "r100", "r200", "r2060", the MobileFaceNet
            variants "mbf" and "mbf_large", and every key of ``_VIT_CONFIGS``.
        **kwargs: For the iresnet family, forwarded verbatim to the
            constructor (after a positional ``False``, presumably the
            ``pretrained`` flag -- confirm in .iresnet). The "mbf" variants
            read only ``fp16`` (default False) and ``num_features``
            (default 512); the ViT variants read only ``num_features``
            (default 512), passed as ``num_classes``. Other keys are ignored
            by those families.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        # Imported lazily, only when this variant is requested.
        from .iresnet2060 import iresnet2060

        return iresnet2060(False, **kwargs)

    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)

    elif name == "mbf_large":
        from .mobilefacenet import get_mbf_large

        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf_large(fp16=fp16, num_features=num_features)

    elif name in _VIT_CONFIGS:
        # Imported lazily so the vit module is only loaded for ViT variants.
        from .vit import VisionTransformer

        num_features = kwargs.get("num_features", 512)
        return VisionTransformer(
            num_classes=num_features, **_VIT_COMMON, **_VIT_CONFIGS[name]
        )

    else:
        supported = ", ".join(_NON_VIT_MODEL_NAMES + tuple(_VIT_CONFIGS))
        raise ValueError(
            "Unknown model name {!r}; expected one of: {}".format(name, supported)
        )