import streamlit as st
import io
import csv

from segments import SegmentsClient
from get_labels_from_samples import (
    get_samples as get_samples_objects,
    export_frames_and_annotations,
    export_sensor_frames_and_annotations,
    export_all_sensor_frames_and_annotations
)

def init_session_state():
    """Ensure the session-state keys used by this app exist."""
    if 'csv_content' not in st.session_state:
        st.session_state.csv_content = None
    if 'error' not in st.session_state:
        st.session_state.error = None

def init_client(api_key: str) -> SegmentsClient:
    """Initialize the Segments.ai API client using the provided API key."""
    return SegmentsClient(api_key)

def parse_classes(input_str: str) -> list:
    """
    Parse user input for classes (comma-separated values and ranges such as '1-3').
    Invalid tokens are skipped. Returns a unique, sorted list of ints.
    """
    classes = []
    for token in input_str.split(','):
        token = token.strip()
        if '-' in token:
            try:
                start, end = map(int, token.split('-'))
                classes.extend(range(start, end + 1))
            except ValueError:
                continue
        else:
            try:
                classes.append(int(token))
            except ValueError:
                continue
    return sorted(set(classes))
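
# Illustrative examples of the parser above (not executed; values are hypothetical):
#   parse_classes("1,2,5")       -> [1, 2, 5]
#   parse_classes("1-3")         -> [1, 2, 3]
#   parse_classes("2, 4-5, foo") -> [2, 4, 5]   (non-numeric tokens are skipped)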

def generate_csv(metrics: list, dataset_identifier: str) -> str:
    """
    Generate CSV content from a list of per-sample metrics.
    Columns: name, sample_url, sensor, num_frames, total_annotations,
    matching_annotations, labeled_by, reviewed_by
    """
    output = io.StringIO()
    writer = csv.writer(output)
    writer.writerow([
        'name', 'sample_url', 'sensor', 'num_frames',
        'total_annotations', 'matching_annotations',
        'labeled_by', 'reviewed_by'
    ])
    for m in metrics:
        url = f"https://app.segments.ai/{dataset_identifier}/samples/{m['uuid']}/{m['labelset']}"
        writer.writerow([
            m['name'], url, m['sensor'],
            m['num_frames'], m['total_annotations'],
            m['matching_annotations'], m['labeled_by'],
            m['reviewed_by']
        ])
    content = output.getvalue()
    output.close()
    return content
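
# Illustrative example of generate_csv (hypothetical values, not executed):
#   generate_csv(
#       [{'name': 'sample_01', 'uuid': 'abc123', 'labelset': 'ground-truth',
#         'sensor': '', 'num_frames': 10, 'total_annotations': 42,
#         'matching_annotations': 7, 'labeled_by': 'alice', 'reviewed_by': 'bob'}],
#       'username/dataset'
#   )
# produces a header row plus one data row whose sample_url is
#   https://app.segments.ai/username/dataset/samples/abc123/ground-truth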

# ----------------------
# Streamlit UI
# ----------------------
init_session_state()

st.title("Per-Sample Annotation Counts by Class")

api_key = st.text_input("API Key", type="password", key="api_key_input")
dataset_identifier = st.text_input("Dataset Identifier (e.g., username/dataset)", key="dataset_identifier_input")
classes_input = st.text_input("Classes (e.g., 1,2,5 or 1-3)", key="classes_input")
run_button = st.button("Generate CSV", key="run_button")

sensor_names = []
is_multisensor = False
sensor_select = None
samples_objects = []

# Inspect the first sample's label to detect whether the dataset is multi-sensor.
if api_key and dataset_identifier:
    try:
        client = init_client(api_key)
        samples_objects = get_samples_objects(client, dataset_identifier)
        if samples_objects:
            label = client.get_label(samples_objects[0].uuid)
            sensors = getattr(getattr(label, 'attributes', None), 'sensors', None)
            if sensors is not None:
                is_multisensor = True
                sensor_names = [getattr(sensor, 'name', 'Unknown') for sensor in sensors]
    except Exception as e:
        st.warning(f"Could not inspect dataset sensors: {e}")

if is_multisensor:
    sensor_select = st.selectbox("Choose sensor (optional)", options=['All sensors'] + sensor_names)
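
# Assumed data shapes, inferred from how the export helpers are used below
# (these are assumptions, not documented guarantees of the helper module):
#   - export_frames_and_annotations(label) and
#     export_sensor_frames_and_annotations(label, sensor_name) return a list of
#     frame dicts, each containing an 'annotations' list.
#   - export_all_sensor_frames_and_annotations(label) returns a dict mapping
#     sensor name -> list of such frame dicts.
#   - Each annotation object is expected to expose a 'category_id' attribute.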

if run_button:
    st.session_state.csv_content = None
    st.session_state.error = None

    if not api_key:
        st.session_state.error = "API Key is required."
    elif not dataset_identifier:
        st.session_state.error = "Dataset identifier is required."
    elif not classes_input:
        st.session_state.error = "Please specify at least one class."
    elif is_multisensor and not sensor_select:
        st.session_state.error = "Please select a sensor or 'All sensors' before generating CSV."
    else:
        with st.spinner("Processing samples..."):
            try:
                target_classes = parse_classes(classes_input)
                client = init_client(api_key)
                metrics = []
                for sample in samples_objects:
                    try:
                        label = client.get_label(sample.uuid)
                        labelset = getattr(label, 'labelset', '') or ''
                        labeled_by = getattr(label, 'created_by', '') or ''
                        reviewed_by = getattr(label, 'reviewed_by', '') or ''

                        if is_multisensor and sensor_select and sensor_select != 'All sensors':
                            # A single sensor of a multi-sensor dataset was selected.
                            frames_list = export_sensor_frames_and_annotations(label, sensor_select)
                            sensor_val = sensor_select
                            num_frames = len(frames_list)
                            total_annotations = sum(len(f['annotations']) for f in frames_list)
                            matching_annotations = sum(
                                1
                                for f in frames_list
                                for ann in f['annotations']
                                if getattr(ann, 'category_id', None) in target_classes
                            )
                        elif is_multisensor and (not sensor_select or sensor_select == 'All sensors'):
                            # 'All sensors': emit one metrics row per sensor, then move to the next sample.
                            all_sensor_frames = export_all_sensor_frames_and_annotations(label)
                            for sensor_name, frames_list in all_sensor_frames.items():
                                num_frames = len(frames_list)
                                total_annotations = sum(len(f['annotations']) for f in frames_list)
                                matching_annotations = sum(
                                    1
                                    for f in frames_list
                                    for ann in f['annotations']
                                    if getattr(ann, 'category_id', None) in target_classes
                                )
                                metrics.append({
                                    'name': getattr(sample, 'name', sample.uuid),
                                    'uuid': sample.uuid,
                                    'labelset': labelset,
                                    'sensor': sensor_name,
                                    'num_frames': num_frames,
                                    'total_annotations': total_annotations,
                                    'matching_annotations': matching_annotations,
                                    'labeled_by': labeled_by,
                                    'reviewed_by': reviewed_by
                                })
                            continue
                        else:
                            # Single-sensor dataset.
                            frames_list = export_frames_and_annotations(label)
                            sensor_val = ''
                            num_frames = len(frames_list)
                            total_annotations = sum(len(f['annotations']) for f in frames_list)
                            matching_annotations = sum(
                                1
                                for f in frames_list
                                for ann in f['annotations']
                                if getattr(ann, 'category_id', None) in target_classes
                            )

                        metrics.append({
                            'name': getattr(sample, 'name', sample.uuid),
                            'uuid': sample.uuid,
                            'labelset': labelset,
                            'sensor': sensor_val if is_multisensor else '',
                            'num_frames': num_frames,
                            'total_annotations': total_annotations,
                            'matching_annotations': matching_annotations,
                            'labeled_by': labeled_by,
                            'reviewed_by': reviewed_by
                        })
                    except Exception:
                        # Skip samples whose labels cannot be fetched or parsed.
                        continue

                if not metrics:
                    st.session_state.error = "No metrics could be generated for the dataset."
                else:
                    st.session_state.csv_content = generate_csv(metrics, dataset_identifier)
            except Exception as e:
                st.session_state.error = f"An error occurred: {e}"

if st.session_state.error:
    st.error(st.session_state.error)

if st.session_state.csv_content:
    st.download_button(
        label="Download Metrics CSV",
        data=st.session_state.csv_content,
        file_name="sample_metrics.csv",
        mime="text/csv"
    )