import json
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
import mediapipe as mp
from mediapipe.framework.formats import landmark_pb2
from skimage.transform import resize

# Load selected columns for inference
with open("inference_args.json", "r") as f:
    SEL_COLS = json.load(f)["selected_columns"]

# Load TFLite model
interpreter = tf.lite.Interpreter(model_path="asl_model.tflite")
interpreter.allocate_tensors()
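# Cache tensor metadata so inputs can be fed and outputs read by index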
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Drawing utilities
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands

def load_relevant_data_subset(pq_path):
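    """Read only the selected inference columns from a landmark parquet file."""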
    return pd.read_parquet(pq_path, columns=SEL_COLS)

def draw_hand_landmarks(seq_df):
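    """Draw each frame's right-hand landmarks onto a blank 600x600 canvas and return the list of images."""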
    images = []
    for seq_idx in range(len(seq_df)):
        x_hand = seq_df.iloc[seq_idx].filter(regex="x_right_hand.*").values
        y_hand = seq_df.iloc[seq_idx].filter(regex="y_right_hand.*").values
        z_hand = seq_df.iloc[seq_idx].filter(regex="z_right_hand.*").values

        right_hand_image = np.zeros((600, 600, 3), dtype=np.uint8)  # blank BGR canvas for OpenCV-based drawing
        right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()

        for x, y, z in zip(x_hand, y_hand, z_hand):
            right_hand_landmarks.landmark.add(x=x, y=y, z=z)

        mp_drawing.draw_landmarks(
            right_hand_image,
            right_hand_landmarks,
            mp_hands.HAND_CONNECTIONS,
            landmark_drawing_spec=mp_drawing_styles.get_default_hand_landmarks_style()
        )
        images.append(right_hand_image)
    return images

def preprocess_image(image):
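    """Resize to 64x64, scale pixels to [0, 1], and add a leading batch dimension."""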
    img = resize(image, (64, 64), preserve_range=True).astype(np.float32) / 255.0
    return np.expand_dims(img, axis=0)

def predict_from_parquet(parquet_path):
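    """Render landmark images for the whole sequence and classify the middle frame with the TFLite model."""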
    df = load_relevant_data_subset(parquet_path)
    image_seq = draw_hand_landmarks(df)
    if not image_seq:
        raise ValueError("No hand landmark images were generated from the parquet sequence.")
    img = preprocess_image(image_seq[len(image_seq) // 2])
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    prediction = np.argmax(output)
    return prediction

if __name__ == "__main__":
    if len(sys.argv) < 2:
        sys.exit("Usage: python tflite_inference.py <parquet_file_path>")
    parquet_file = sys.argv[1]
    pred = predict_from_parquet(parquet_file)
    print("Predicted class index:", pred)