File size: 2,522 Bytes
2b67076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0

import torch
import torchvision.transforms as T

import numpy as np

from insightface.utils import face_align
from insightface.app import FaceAnalysis
from facexlib.recognition import init_recognition_model


# Public names of this module; this list is what a star-import exposes.
__all__ = [
    "FaceEncoderArcFace",
    "get_landmarks_from_image",
]


# Lazily created insightface detector, shared by every call to
# get_landmarks_from_image() so the models are only loaded once.
detector = None


def get_landmarks_from_image(image):
    """
    Detect landmarks with insightface.

    The ``FaceAnalysis`` detector is built and prepared on the first call
    (``ctx_id=0``, ``det_size=(640, 640)``) and then cached in the
    module-level ``detector`` variable.

    Args:
        image (np.ndarray or PIL.Image):
            The input image in RGB format.

    Returns:
        The ``kps`` attribute (5 2D keypoints) of the detected face whose
        bounding box has the largest area; only one face will be returned.

    Raises:
        ValueError: If the detector reports no face for the image.
    """
    global detector
    if detector is None:
        detector = FaceAnalysis()
        detector.prepare(ctx_id=0, det_size=(640, 640))

    # np.array() copies its input by default, so the detector already works
    # on a private buffer; a second explicit .copy() would only duplicate the
    # whole image again.
    # NOTE(review): insightface examples usually feed cv2-style BGR images,
    # while this docstring asks for RGB -- confirm the intended channel order.
    in_image = np.array(image)

    faces = detector.get(in_image)
    if len(faces) == 0:
        raise ValueError("No face detected in the image")

    def _bbox_area(candidate):
        # bbox is indexed as (x0, y0, x1, y1): width times height.
        box = candidate.bbox
        return (box[2] - box[0]) * (box[3] - box[1])

    # Keep only the largest face and hand back its keypoints directly.
    largest_face = max(faces, key=_bbox_area)
    return largest_face.kps  # 5 x 2


class FaceEncoderArcFace:
    """
    Official ArcFace, no_grad-only.

    Wraps the facexlib ``'arcface'`` recognition model. Call
    ``init_encoder_model`` before using the instance; both the preprocessing
    step and the forward pass run under ``torch.no_grad()``.
    """

    def __repr__(self):
        return "ArcFace"


    def init_encoder_model(self, device, eval_mode=True):
        """
        Load the ArcFace recognition model and remember the target device.

        Args:
            device: Passed to facexlib's ``init_recognition_model`` and used
                later to place preprocessed inputs.
            eval_mode (bool): When true, switch the loaded model to
                ``eval()`` mode.
        """
        self.device = device
        self.encoder_model = init_recognition_model('arcface', device=device)
        if not eval_mode:
            return
        self.encoder_model.eval()


    @torch.no_grad()
    def input_preprocessing(self, in_image, landmarks, image_size=112):
        """
        Align a face crop from landmarks and convert it to a batched tensor.

        Args:
            in_image: Image convertible with ``np.array``.
            landmarks: Face keypoints convertible with ``np.array``; must not
                be ``None``.
            image_size (int): Side length handed to ``face_align.norm_crop``.

        Returns:
            The aligned crop after ``ToTensor`` and ``Normalize([0.5], [0.5])``,
            with a leading batch dimension, moved to ``self.device``.

        Raises:
            AssertionError: If ``landmarks`` is ``None``. The check is an
                ``assert`` statement, so it is skipped under ``python -O``.
        """
        assert landmarks is not None, "landmarks are not provided!"

        pixels = np.array(in_image)
        points = np.array(landmarks)
        aligned_crop = face_align.norm_crop(pixels, landmark=points, image_size=image_size)

        to_normalized_tensor = T.Compose([
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])
        batched = to_normalized_tensor(aligned_crop).unsqueeze(0)

        return batched.to(self.device)


    @torch.no_grad()
    def __call__(self, in_image, need_proc=False, landmarks=None, image_size=112):
        """
        Compute ArcFace embeddings for an image or a ready-made tensor batch.

        Args:
            in_image: A raw image when ``need_proc`` is true, otherwise a
                ``torch.Tensor`` indexed as (batch, channel, ...) with three
                channels.
            need_proc (bool): Run ``input_preprocessing`` on ``in_image``
                first.
            landmarks: Keypoints for alignment; only used when ``need_proc``
                is true.
            image_size (int): Crop size forwarded to ``input_preprocessing``.

        Returns:
            The output of ``self.encoder_model`` (per the original note:
            ``[B, 512]``, normalized).

        Raises:
            AssertionError: If ``need_proc`` is false and ``in_image`` is not
                a ``torch.Tensor`` (the check is skipped under ``python -O``).
        """
        if not need_proc:
            assert isinstance(in_image, torch.Tensor)
        else:
            in_image = self.input_preprocessing(in_image, landmarks, image_size)

        # Reverse the channel order along dim 1 before encoding
        # (presumably RGB -> BGR for the ArcFace weights -- confirm).
        channel_swapped = in_image[:, [2, 1, 0], :, :].contiguous()

        return self.encoder_model(channel_swapped)  # [B, 512], normalized