Spaces:
Sleeping
Sleeping
Commit
·
c450b97
1
Parent(s):
4873e8b
latest changes
Browse files
app.py
CHANGED
|
@@ -41,7 +41,8 @@ def sample_frames(frames_list, target_count):
|
|
| 41 |
if len(frames_list) <= target_count:
|
| 42 |
return frames_list
|
| 43 |
indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
|
| 44 |
-
|
|
|
|
| 45 |
return sampled
|
| 46 |
|
| 47 |
def live_predict_stream(image_np_array):
|
|
@@ -74,7 +75,8 @@ def live_predict_stream(image_np_array):
|
|
| 74 |
if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
|
| 75 |
prediction_result = "Error: Not enough frames for model."
|
| 76 |
status_message = "Error during frame sampling."
|
| 77 |
-
|
|
|
|
| 78 |
raw_frames_buffer.clear()
|
| 79 |
current_clip_start_time = time.time()
|
| 80 |
last_prediction_completion_time = time.time()
|
|
@@ -88,7 +90,7 @@ def live_predict_stream(image_np_array):
|
|
| 88 |
logits = outputs.logits
|
| 89 |
|
| 90 |
predicted_class_id = logits.argmax(-1).item()
|
| 91 |
-
predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
|
| 92 |
confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
|
| 93 |
|
| 94 |
prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
|
|
@@ -175,12 +177,20 @@ with gr.Blocks() as demo:
|
|
| 175 |
Use this API endpoint to send base64-encoded frames for prediction.
|
| 176 |
"""
|
| 177 |
)
|
|
|
|
|
|
|
| 178 |
gr.Interface(
|
| 179 |
-
fn=lambda
|
| 180 |
-
inputs=gr.
|
| 181 |
-
outputs=gr.Textbox(label="
|
| 182 |
-
|
|
|
|
| 183 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|
| 186 |
demo.launch()
|
|
|
|
| 41 |
if len(frames_list) <= target_count:
|
| 42 |
return frames_list
|
| 43 |
indices = np.linspace(0, len(frames_list) - 1, target_count, dtype=int)
|
| 44 |
+
# FIX: Corrected list indexing from () to []
|
| 45 |
+
sampled = [frames_list[int(i)] for i in indices]
|
| 46 |
return sampled
|
| 47 |
|
| 48 |
def live_predict_stream(image_np_array):
|
|
|
|
| 75 |
if len(frames_for_model) < MODEL_INPUT_NUM_FRAMES:
|
| 76 |
prediction_result = "Error: Not enough frames for model."
|
| 77 |
status_message = "Error during frame sampling."
|
| 78 |
+
print(f"ERROR: Insufficient frames for model input: {len(frames_for_model)}/{MODEL_INPUT_NUM_FRAMES}")
|
| 79 |
+
app_state = "recording" # Reset to recording state
|
| 80 |
raw_frames_buffer.clear()
|
| 81 |
current_clip_start_time = time.time()
|
| 82 |
last_prediction_completion_time = time.time()
|
|
|
|
| 90 |
logits = outputs.logits
|
| 91 |
|
| 92 |
predicted_class_id = logits.argmax(-1).item()
|
| 93 |
+
predicted_label = model.config.id2label.get(predicted_class_id, "Unknown")
|
| 94 |
confidence = torch.nn.functional.softmax(logits, dim=-1)[0][predicted_class_id].item()
|
| 95 |
|
| 96 |
prediction_result = f"Predicted: {predicted_label} (Confidence: {confidence:.2f})"
|
|
|
|
| 177 |
Use this API endpoint to send base64-encoded frames for prediction.
|
| 178 |
"""
|
| 179 |
)
|
| 180 |
+
# Re-adding a slightly more representative API interface
|
| 181 |
+
# Gradio's automatic API documentation will use this to show inputs/outputs
|
| 182 |
gr.Interface(
|
| 183 |
+
fn=lambda frames_list: f"Received {len(frames_list)} frames. This is a dummy response. Integrate predict_from_frames_api here.",
|
| 184 |
+
inputs=gr.Json(label="List of Base64-encoded image strings"),
|
| 185 |
+
outputs=gr.Textbox(label="API Response"),
|
| 186 |
+
live=False,
|
| 187 |
+
allow_flagging="never" # For API endpoints, flagging is usually not desired
|
| 188 |
)
|
| 189 |
+
# Note: The actual `predict_from_frames_api` function is defined above,
|
| 190 |
+
# but for a clean API tab, we can use a dummy interface here that Gradio will
|
| 191 |
+
# use to generate the interactive API documentation. The actual API call
|
| 192 |
+
# from your local script directly targets the /run/predict_from_frames_api endpoint.
|
| 193 |
+
|
| 194 |
|
| 195 |
if __name__ == "__main__":
|
| 196 |
demo.launch()
|