Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| import streamlit as st | |
| import io | |
| import csv | |
| import concurrent.futures | |
| from segments import SegmentsClient | |
| from datetime import datetime | |
| import sys | |
| import os | |
| from get_labels_from_samples import ( | |
| get_samples as get_samples_objects, | |
| export_frames_and_annotations, | |
| export_sensor_frames_and_annotations, | |
| export_all_sensor_frames_and_annotations | |
| ) | |
def init_session_state():
    """Ensure the session-state keys this app relies on exist (default None)."""
    for key in ('csv_content', 'error'):
        if key not in st.session_state:
            st.session_state[key] = None
def init_client(api_key: str) -> SegmentsClient:
    """Create a Segments.ai API client authenticated with the given API key."""
    client = SegmentsClient(api_key)
    return client
def parse_classes(input_str: str) -> list:
    """
    Parse user input for class IDs (comma-separated values and ranges).

    Accepts tokens like "1,2,5" and ranges like "1-3". Reversed ranges
    ("5-3") are normalized instead of silently producing nothing, which
    was the previous behavior. Blank or non-numeric tokens are skipped.

    Returns a unique, sorted list of ints.
    """
    classes = set()
    for token in input_str.split(','):
        token = token.strip()
        if not token:
            # Skip empty tokens (e.g. trailing commas) without exceptions.
            continue
        if '-' in token:
            try:
                start, end = map(int, token.split('-'))
            except ValueError:
                continue
            if start > end:
                # Normalize reversed ranges such as "5-3".
                start, end = end, start
            classes.update(range(start, end + 1))
        else:
            try:
                classes.add(int(token))
            except ValueError:
                continue
    return sorted(classes)
| def _count_from_frames(frames, target_set): | |
| """Helper to count frames, total annotations, and matching annotations directly.""" | |
| if not frames: | |
| return 0, 0, 0 | |
| num_frames = len(frames) | |
| total_annotations = 0 | |
| matching_annotations = 0 | |
| for f in frames: | |
| anns = getattr(f, 'annotations', []) | |
| total_annotations += len(anns) | |
| if target_set: | |
| for ann in anns: | |
| if getattr(ann, 'category_id', None) in target_set: | |
| matching_annotations += 1 | |
| return num_frames, total_annotations, matching_annotations | |
def _metric_row(sample, labelset, sensor_name, frames, target_set,
                labeled_by, reviewed_by):
    """Build one CSV-ready metrics dict for a (sample, sensor) pair."""
    num_frames, total_annotations, matching_annotations = \
        _count_from_frames(frames, target_set)
    return {
        'name': getattr(sample, 'name', sample.uuid),
        'uuid': sample.uuid,
        'labelset': labelset,
        'sensor': sensor_name,
        'num_frames': num_frames,
        'total_annotations': total_annotations,
        'matching_annotations': matching_annotations,
        'labeled_by': labeled_by,
        'reviewed_by': reviewed_by,
    }


def compute_metrics_for_sample(sample, api_key, target_set, is_multisensor, sensor_select):
    """
    Fetch the label for a single sample and compute annotation metrics.

    Args:
        sample: sample object exposing at least `.uuid` (and optionally `.name`).
        api_key: Segments.ai API key; each call builds its own client so the
            function is safe to run from worker threads.
        target_set: set of class IDs counted as "matching".
        is_multisensor: whether the dataset's labels carry per-sensor frames.
        sensor_select: chosen sensor name, or 'All sensors'/None for all.

    Returns:
        A list of metric dicts (one per sensor for 'All sensors', otherwise
        at most one). Returns [] on any API/parse error so that one bad
        sample does not abort the whole batch.
    """
    try:
        client = init_client(api_key)
        label = client.get_label(sample.uuid)
        labelset = getattr(label, 'labelset', '') or ''
        labeled_by = getattr(label, 'created_by', '') or ''
        reviewed_by = getattr(label, 'reviewed_by', '') or ''
        metrics_rows = []
        if is_multisensor:
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None) or []
            # None means "emit a row for every sensor"; otherwise only the
            # sensor whose name matches sensor_select.
            wanted = sensor_select if sensor_select and sensor_select != 'All sensors' else None
            for sensor in sensors:
                sensor_name = getattr(sensor, 'name', 'Unknown')
                if wanted is not None and sensor_name != wanted:
                    continue
                frames = getattr(getattr(sensor, 'attributes', None), 'frames', [])
                metrics_rows.append(_metric_row(
                    sample, labelset, sensor_name, frames, target_set,
                    labeled_by, reviewed_by))
                if wanted is not None:
                    break  # only one sensor can match a specific selection
        else:
            # Single-sensor dataset: frames hang directly off the label.
            frames = getattr(getattr(label, 'attributes', None), 'frames', [])
            metrics_rows.append(_metric_row(
                sample, labelset, '', frames, target_set,
                labeled_by, reviewed_by))
        return metrics_rows
    except Exception:
        # Deliberate best-effort: a failed/missing label is skipped rather
        # than failing the whole run; the caller treats [] as "no data".
        return []
def generate_csv(metrics: list, dataset_identifier: str) -> str:
    """
    Render the per-sample metrics as CSV text.

    Columns: name, sample_url, sensor, num_frames, total_annotations,
    matching_annotations, labeled_by, reviewed_by. The sample_url points
    at the sample's labelset page on app.segments.ai.
    """
    header = [
        'name', 'sample_url', 'sensor', 'num_frames',
        'total_annotations', 'matching_annotations',
        'labeled_by', 'reviewed_by'
    ]
    buffer = io.StringIO()
    writer = csv.writer(buffer)
    writer.writerow(header)
    writer.writerows(
        [
            m['name'],
            f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}",
            m['sensor'],
            m['num_frames'],
            m['total_annotations'],
            m['matching_annotations'],
            m['labeled_by'],
            m['reviewed_by'],
        ]
        for m in metrics
    )
    return buffer.getvalue()
# ----------------------
# Streamlit UI
# ----------------------
init_session_state()
st.title("Per-Sample Annotation Counts by Class")
# Credential / dataset / class-filter inputs; explicit keys keep widget
# state stable across reruns.
api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")
# Probe the dataset (via its first sample's label) to decide whether it is
# multi-sensor; this runs on every rerun once credentials are present.
sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []
if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
            label = client.get_label(samples_objects[0].uuid)
            # A 'sensors' attribute on the label marks a multi-sensor dataset.
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        # Probe failures are non-fatal: warn and fall back to single-sensor UI.
        st.warning(f"Could not inspect dataset sensors: {e}")
# Sensor picker only appears for multi-sensor datasets.
if is_multisensor:
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)
# Concurrency control
parallel_workers = st.slider("Parallel requests", min_value=1, max_value=32, value=8, help="Increase to speed up processing; lower if you hit API limits.")
if run_button:
    # Reset outputs from any previous run before validating inputs.
    st.session_state.csv_content = None
    st.session_state.error = None
    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        # Show loader/status message while checking dataset type and generating CSV
        status_ctx = None
        try:
            # st.status exists only on newer Streamlit versions; fall back
            # to a plain info message on older ones.
            status_ctx = st.status("Checking dataset type...", expanded=True)
        except AttributeError:
            st.info("Checking dataset type...")
        try:
            target_classes = parse_classes(classes_input)
            target_set = set(target_classes)
            metrics = []
            # Update loader after dataset type check
            if status_ctx is not None:
                status_ctx.update(label="Dataset type checked. Processing samples...", state="running")
            progress = st.progress(0)
            total = len(samples_objects)
            done = 0
            # Fan out one API fetch per sample; each worker builds its own
            # client (see compute_metrics_for_sample).
            with concurrent.futures.ThreadPoolExecutor(max_workers=parallel_workers) as executor:
                futures = [
                    executor.submit(
                        compute_metrics_for_sample,
                        sample,
                        api_key,
                        target_set,
                        is_multisensor,
                        sensor_select,
                    )
                    for sample in samples_objects
                ]
                # Advance the progress bar as results arrive (completion
                # order, not submission order).
                for future in concurrent.futures.as_completed(futures):
                    rows = future.result()
                    if rows:
                        metrics.extend(rows)
                    done += 1
                    if total:
                        progress.progress(min(done / total, 1.0))
            if not metrics:
                st.session_state.error = "No metrics could be generated for the dataset."
            else:
                st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
                if status_ctx is not None:
                    status_ctx.update(label="CSV generated!", state="complete")
        except Exception as e:
            # Record the error in session state; it is rendered below.
            st.session_state.error = f"An error occurred: {e}"
            if status_ctx is not None:
                status_ctx.update(label="Error occurred.", state="error")
# Surface any error recorded during validation or the run.
if st.session_state.error:
    st.error(st.session_state.error)
# Offer the generated CSV for download once it is available.
if st.session_state.csv_content:
    today_str = datetime.now().strftime("%Y%m%d")
    # Dataset identifiers look like "username/dataset"; replace the path
    # separator so the suggested download filename is valid everywhere.
    safe_identifier = dataset_identifier.replace('/', '_')
    filename = f"{today_str}_{safe_identifier}_count-by-class.csv"
    st.download_button(
        "Download CSV",
        data=st.session_state.csv_content,
        file_name=filename,
        mime="text/csv"
    )