Spaces:
Build error
Build error
| import cv2 | |
| import numpy as np | |
| import tensorflow as tf | |
| from huggingface_hub import from_pretrained_keras | |
| from tensorflow.keras.optimizers import Adam | |
| from .constants import LEARNING_RATE | |
def get_model(repo_id="keras-io/video-vision-transformer", learning_rate=None):
    """
    Download the model from the Hugging Face Hub and compile it.

    Args:
        repo_id: Hugging Face Hub repository to load the Keras model from.
            Defaults to the video vision transformer this app was built for.
        learning_rate: Learning rate for the Adam optimizer. ``None`` (the
            default) uses ``LEARNING_RATE`` from the package constants.

    Returns:
        The loaded Keras model, compiled with Adam and sparse categorical
        cross-entropy loss.
    """
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    loaded = from_pretrained_keras(repo_id)
    # Sparse categorical cross-entropy takes integer class ids as targets.
    # If evaluation metrics are needed, pass e.g.
    # metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")].
    loaded.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",
    )
    return loaded
# Built once when the module is imported; predict_label reuses this instance.
model = get_model()

# Class names indexed by model output position: predict_label returns
# labels[argmax(prediction)], so this order must match the model's outputs.
labels = [
    'liver',
    'kidney-right',
    'kidney-left',
    'femur-right',
    'femur-left',
    'bladder',
    'heart',
    'lung-right',
    'lung-left',
    'spleen',
    'pancreas',
]
def predict_label(path):
    """
    Predict the class name for the video stored at ``path``.

    Args:
        path: Path to a video file readable by OpenCV.

    Returns:
        The predicted class name, one of the entries of ``labels``.

    Raises:
        ValueError: If no frames could be read from ``path`` (for example a
            missing file, an unsupported codec, or an empty video).
    """
    frames = load_video(path)
    # load_video returns an empty array when OpenCV yields no frames; fail
    # here with a clear message instead of an opaque shape error in the model.
    if frames.size == 0:
        raise ValueError(
            "Could not read any frames from video: {!r}".format(path)
        )
    dataloader = prepare_dataloader(frames)
    # The dataloader holds one batch with one clip, so take row 0.
    probabilities = model.predict(dataloader)[0]
    class_index = int(np.argmax(probabilities, axis=0))
    return labels[class_index]
def load_video(path):
    """
    Read every frame of the video at ``path`` and return them in grayscale.

    Each frame is converted from OpenCV's BGR colour order to single-channel
    grayscale, because that is the format expected by the model.

    Args:
        path: Video source passed to ``cv2.VideoCapture``.

    Returns:
        ``np.ndarray`` built from the grayscale frames, one entry per frame
        along axis 0. It is empty if OpenCV yields no frames.
    """
    capture = cv2.VideoCapture(path)
    gray_frames = []
    try:
        ok, frame = capture.read()
        while ok:
            gray_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
            ok, frame = capture.read()
    finally:
        # Release the capture handle even if a frame conversion fails.
        capture.release()
    return np.array(gray_frames)
def prepare_dataloader(video):
    """
    Wrap a single video in a batched ``tf.data`` pipeline for ``model.predict``.

    Args:
        video: Array of frames, as returned by ``load_video``.

    Returns:
        A ``tf.data.Dataset`` yielding one batch. The batch holds the video
        after ``preprocess`` plus a placeholder label of ``0``.
    """
    # Leading axis so from_tensor_slices yields the whole video as one element.
    single_clip = tf.expand_dims(video, axis=0)
    # preprocess works on (frames, label) pairs, so pair the clip with a
    # dummy label; prediction only consumes the frames.
    placeholder_labels = np.array([0])
    pipeline = tf.data.Dataset.from_tensor_slices((single_clip, placeholder_labels))
    pipeline = pipeline.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(1)
    return pipeline.prefetch(tf.data.AUTOTUNE)
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """
    Turn one ``(frames, label)`` dataset element into model-ready tensors.

    Args:
        frames: Stack of grayscale frames for a single video.
        label: Class id paired with the video.

    Returns:
        Tuple ``(frames, label)``:
        - ``frames`` gains a trailing channel axis and is converted to
          ``tf.float32`` with ``tf.image.convert_image_dtype``, which scales
          integer pixel values into [0, 1].
        - ``label`` is cast to ``tf.float32``.
    """
    # The trailing channel axis is what Conv3D-style layers consume.
    with_channel = tf.expand_dims(frames, axis=-1)
    float_frames = tf.image.convert_image_dtype(with_channel, tf.float32)
    return float_frames, tf.cast(label, tf.float32)