File size: 2,643 Bytes
27e248a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import numpy as np
import pandas as pd
import json
import tensorflow as tf
import mediapipe as mp
from skimage.transform import resize
import matplotlib.pyplot as plt
from mediapipe.framework.formats import landmark_pb2
from PIL import Image
# Load the list of parquet columns the model consumes. The file must contain
# a "selected_columns" key (a KeyError is raised otherwise).
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale-dependent default encoding.
with open("inference_args.json", "r", encoding="utf-8") as f:
    SEL_COLS = json.load(f)["selected_columns"]
# Build the TFLite interpreter once, at import time, so every prediction
# reuses it. Tensor buffers have to be allocated before any tensor is set or
# the model is invoked.
interpreter = tf.lite.Interpreter(model_path="asl_model.tflite")
interpreter.allocate_tensors()

# Per-tensor detail records for the model's inputs and outputs; the
# prediction code reads their 'index' entries to address the tensors.
input_details, output_details = (
    interpreter.get_input_details(),
    interpreter.get_output_details(),
)
# Short aliases for the MediaPipe modules used to render hand landmarks:
# the drawing helpers, their default styles, and the hand-connection graph.
mp_drawing, mp_drawing_styles, mp_hands = (
    mp.solutions.drawing_utils,
    mp.solutions.drawing_styles,
    mp.solutions.hands,
)
def load_relevant_data_subset(pq_path):
    """Read the parquet file at *pq_path*, keeping only the model's columns.

    The column list is the module-level ``SEL_COLS`` loaded from
    ``inference_args.json``.

    Returns:
        The pandas DataFrame produced by ``pd.read_parquet``.
    """
    wanted_columns = SEL_COLS
    subset = pd.read_parquet(pq_path, columns=wanted_columns)
    return subset
def draw_hand_landmarks(seq_df):
    """Render the right-hand landmarks of each row of *seq_df* as an image.

    Every row is treated as one frame. The columns whose names match
    ``x_right_hand.*``, ``y_right_hand.*`` and ``z_right_hand.*`` (regex
    search, in the DataFrame's column order) are zipped into a MediaPipe
    ``NormalizedLandmarkList``. That list is drawn with the standard hand
    connections onto a fresh black 600x600x3 float64 canvas.

    Args:
        seq_df: DataFrame with one row per frame.

    Returns:
        A list with one numpy image per row, in row order. The list is empty
        when *seq_df* has no rows.
    """
    # The column selection is the same for every frame, so run each regex
    # filter once on the whole frame rather than three times per row.
    x_rows = seq_df.filter(regex="x_right_hand.*").to_numpy()
    y_rows = seq_df.filter(regex="y_right_hand.*").to_numpy()
    z_rows = seq_df.filter(regex="z_right_hand.*").to_numpy()
    hand_style = mp_drawing_styles.get_default_hand_landmarks_style()

    images = []
    for x_hand, y_hand, z_hand in zip(x_rows, y_rows, z_rows):
        right_hand_image = np.zeros((600, 600, 3))
        right_hand_landmarks = landmark_pb2.NormalizedLandmarkList()
        for x, y, z in zip(x_hand, y_hand, z_hand):
            # Hand protobuf plain Python floats rather than numpy scalars.
            right_hand_landmarks.landmark.add(
                x=float(x), y=float(y), z=float(z)
            )
        # NOTE(review): frames without a detected hand presumably carry NaN
        # coordinates and would render as a blank canvas — confirm upstream.
        mp_drawing.draw_landmarks(
            right_hand_image,
            right_hand_landmarks,
            mp_hands.HAND_CONNECTIONS,
            landmark_drawing_spec=hand_style,
        )
        images.append(right_hand_image)
    return images
def preprocess_image(image):
    """Turn one rendered frame into a single-item float32 batch.

    The frame is resized to 64x64 with ``preserve_range=True``, so the
    original value scale is kept. The explicit division by 255 then does
    the normalisation. Finally a leading batch axis of length 1 is added.
    """
    resized = resize(image, (64, 64), preserve_range=True)
    scaled = resized.astype(np.float32) / 255.0
    return scaled[np.newaxis, ...]
def predict_from_parquet(parquet_path):
    """Predict a class index for the landmark sequence stored in *parquet_path*.

    Only the middle frame (row ``len(df) // 2``) is rendered, preprocessed
    and fed to the module-level TFLite interpreter.

    Args:
        parquet_path: Path of the parquet file to read.

    Returns:
        ``np.argmax`` of the model's first output tensor (taken over the
        flattened output).

    Raises:
        ValueError: If no image could be rendered (the file has no rows).
    """
    df = load_relevant_data_subset(parquet_path)
    # Only the middle frame is used, so render just that row instead of
    # drawing a 600x600 canvas for every frame and discarding the rest.
    # The slice is empty for an empty frame, which yields no image below.
    middle = len(df) // 2
    image_seq = draw_hand_landmarks(df.iloc[middle:middle + 1])
    if not image_seq:
        raise ValueError("No hand image generated.")
    img = preprocess_image(image_seq[0])
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()
    output = interpreter.get_tensor(output_details[0]['index'])
    prediction = np.argmax(output)
    return prediction
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Report misuse on stderr with a non-zero exit status so shells and
        # calling scripts can detect it; sys.exit(str) prints the message
        # to stderr and exits with status 1.
        sys.exit("Usage: python tflite_inference.py <parquet_file_path>")
    parquet_file = sys.argv[1]
    pred = predict_from_parquet(parquet_file)
    print("Predicted class index:", pred)
|