Spaces:
Sleeping
Sleeping
| from pyvis.network import Network | |
| import os | |
# Colour per heterogeneous-graph node type. visualize_protein_subgraph turns
# this table into the per-type "groups" styling of the pyvis network.
NODE_TYPE_COLORS = dict(
    Disease='#079dbb',
    HPO='#58d0e8',
    Drug='#815ac0',
    Compound='#d2b7e5',
    Domain='#6bbf59',
    # GO sub-ontologies: P = biological process, F = molecular function,
    # C = cellular component (see GO_CATEGORY_MAPPING below).
    GO_term_P='#ff8800',
    GO_term_F='#ffaa00',
    GO_term_C='#ffc300',
    # Both pathway node types deliberately share one colour.
    Pathway='#720026',
    kegg_Pathway='#720026',
    EC_number='#ce4257',
    Protein='#3aa6a4',
)

# Human-readable edge labels. Most entries are keyed by relation name only;
# 'protein_function' edges get a label that depends on the GO sub-ontology of
# the target node, so they are keyed by (relation, target node type).
EDGE_LABEL_TRANSLATION = {
    'Orthology': 'is ortholog to',
    'Pathway': 'takes part in',
    'kegg_path_prot': 'takes part in',
    'protein_domain': 'has',
    'PPI': 'interacts with',
    'HPO': 'is associated with',
    'kegg_dis_prot': 'is related to',
    'Disease': 'is related to',
    'Drug': 'targets',
    'protein_ec': 'catalyzes',
    'Chembl': 'targets',
    # GO relations, keyed by (relation, target node type).
    ('protein_function', 'GO_term_F'): 'enables',
    ('protein_function', 'GO_term_P'): 'is involved in',
    ('protein_function', 'GO_term_C'): 'localizes to',
}

# Maps the 'GO_category' values used in prediction tables to GO node types.
GO_CATEGORY_MAPPING = {
    'Biological Process': 'GO_term_P',
    'Molecular Function': 'GO_term_F',
    'Cellular Component': 'GO_term_C',
}
| def _gather_protein_edges(data, protein_id): | |
| protein_idx = data['Protein']['id_mapping'][protein_id] | |
| reverse_id_mapping = {} | |
| for node_type in data.node_types: | |
| reverse_id_mapping[node_type] = {v:k for k, v in data[node_type]['id_mapping'].items()} | |
| protein_edges = {} | |
| print(f'Gathering edges for {protein_id}...') | |
| for edge_type in data.edge_types: | |
| if 'rev' not in edge_type[1]: | |
| if edge_type not in protein_edges: | |
| protein_edges[edge_type] = [] | |
| if edge_type[0] == 'Protein': | |
| print(f'Gathering edges for {edge_type}...') | |
| # append the edges with protein_idx as source node | |
| edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == protein_idx] | |
| protein_edges[edge_type].extend(edges.T.tolist()) | |
| elif edge_type[2] == 'Protein': | |
| print(f'Gathering edges for {edge_type}...') | |
| # append the edges with protein_idx as target node | |
| edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == protein_idx] | |
| protein_edges[edge_type].extend(edges.T.tolist()) | |
| for edge_type in protein_edges.keys(): | |
| if protein_edges[edge_type]: | |
| mapped_edges = set() | |
| for edge in protein_edges[edge_type]: | |
| # Get source and target node types from edge_type | |
| source_type, _, target_type = edge_type | |
| # Map indices back to original IDs | |
| source_id = reverse_id_mapping[source_type][edge[0]] | |
| target_id = reverse_id_mapping[target_type][edge[1]] | |
| mapped_edges.add((source_id, target_id)) | |
| protein_edges[edge_type] = mapped_edges | |
| return protein_edges | |
def _ordered_edges(edges):
    """Return ``edges`` as a list in a reproducible order (``[]`` for ``None``).

    Sets are sorted because their iteration order over string IDs can change
    between interpreter runs, which would make ``[:limit]`` truncation show
    different edges each time. Any other iterable keeps its existing order.
    """
    if edges is None:
        return []
    if isinstance(edges, (set, frozenset)):
        # Sort on string forms so mixed ID types cannot make sorting fail.
        return sorted(edges, key=lambda edge: tuple(str(node_id) for node_id in edge))
    return list(edges)


def _rank_go_predictions(protein_id, category_predictions, known_edges, limit):
    """Return the ``limit`` highest-probability GO predictions as tagged edges.

    Each item is ``((protein_id, GO_ID), probability, is_ground_truth)``.
    ``is_ground_truth`` is True when that edge is among ``known_edges``.
    """
    known = set(known_edges)
    top = category_predictions.sort_values(by='Probability', ascending=False).head(limit)
    ranked = []
    for term, probability in zip(top['GO_ID'], top['Probability']):
        edge = (protein_id, term)
        ranked.append((edge, probability, edge in known))
    return ranked


def _filter_edges(protein_id, protein_edges, prediction_df, limit=10):
    """Select at most ``limit`` edges per edge type for display and tag each one.

    Args:
        protein_id: ID of the protein whose neighbourhood is drawn.
        protein_edges: Mapping ``edge_type -> iterable of (source_id, target_id)``,
            as produced by ``_gather_protein_edges``. ``edge_type`` is a
            ``(source_type, relation, target_type)`` triple.
        prediction_df: Prediction table (pandas DataFrame) with at least the
            columns ``UniProt_ID``, ``GO_category``, ``GO_ID`` and ``Probability``.
        limit: Maximum number of edges kept per edge type (non-negative).

    Returns:
        dict ``edge_type -> list of (edge, probability, is_ground_truth)``:

        * GO edge types with predictions for this protein: the ``limit`` most
          probable predictions, ``((protein_id, GO_ID), prob, in_known_edges)``.
          They are kept even if the protein has no known edges of that type.
        * Other GO edge types: up to ``limit`` known edges tagged
          ``('no_pred', True)``.
        * Non-GO edge types: up to ``limit`` known edges tagged ``(None, True)``.

        Edge types with neither known edges nor predictions are omitted.
        ``GO_category`` values missing from ``GO_CATEGORY_MAPPING`` are ignored.
    """
    go_category_reverse_mapping = {v: k for k, v in GO_CATEGORY_MAPPING.items()}
    # Only this protein's predictions are ever used, so filter once up front.
    protein_predictions = prediction_df[prediction_df['UniProt_ID'] == protein_id]

    filtered_edges = {}
    for edge_type, edges in protein_edges.items():
        known_edges = _ordered_edges(edges)
        target_type = edge_type[2]

        if not target_type.startswith('GO_term'):
            if known_edges:
                filtered_edges[edge_type] = [(edge, None, True) for edge in known_edges[:limit]]
            continue

        category = go_category_reverse_mapping.get(target_type)
        if category is None:
            category_predictions = protein_predictions.iloc[:0]
        else:
            category_predictions = protein_predictions[protein_predictions['GO_category'] == category]

        if len(category_predictions) > 0:
            # Rank predictions before checking for known edges. Otherwise a
            # protein with no known annotations would never show its
            # (unvalidated) predictions.
            filtered_edges[edge_type] = _rank_go_predictions(
                protein_id, category_predictions, known_edges, limit
            )
        elif known_edges:
            # No prediction for this category: show the database edges as-is.
            filtered_edges[edge_type] = [(edge, 'no_pred', True) for edge in known_edges[:limit]]
    return filtered_edges
def _network_options_json():
    """Return the vis.js options for the protein subgraph as a JSON string.

    Building a dict and serialising it, rather than concatenating JSON text,
    guarantees well-formed JSON. ``groups`` gives every node type in
    ``NODE_TYPE_COLORS`` a matching background/border colour.
    """
    import json

    options = {
        "physics": {
            "enabled": True,
            "barnesHut": {
                "gravitationalConstant": -1000,
                "springLength": 250,
                "springConstant": 0.001,
                "damping": 0.09,
                "avoidOverlap": 0,
            },
            # Inactive while "solver" is "barnesHut". Presumably kept so the
            # solver can be switched from the configure panel; TODO confirm.
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 100,
                "springConstant": 0.08,
                "damping": 0.4,
                "avoidOverlap": 0,
            },
            "solver": "barnesHut",
            "stabilization": {"enabled": True, "iterations": 1000, "updateInterval": 25},
        },
        "layout": {"improvedLayout": True, "hierarchical": {"enabled": False}},
        "interaction": {"hover": True, "navigationButtons": True, "multiselect": True},
        "configure": {
            "enabled": True,
            "filter": ["physics", "layout", "manipulation"],
            "showButton": True,
        },
        "groups": {
            node_type: {"color": {"background": color, "border": color}}
            for node_type, color in NODE_TYPE_COLORS.items()
        },
    }
    return json.dumps(options)


def _edge_relation_label(edge_type):
    """Return the human-readable label for a ``(source, relation, target)`` edge type.

    'protein_function' labels depend on the GO sub-ontology of the target, so
    they are looked up by ``(relation, target_type)``. Relations missing from
    ``EDGE_LABEL_TRANSLATION`` fall back to the raw relation name instead of
    aborting the whole render.
    """
    _, relation, target_type = edge_type
    key = (relation, target_type) if relation == 'protein_function' else relation
    return EDGE_LABEL_TRANSLATION.get(key, relation)


def _edge_style(relation_label, probability, is_ground_truth):
    """Return ``(label, colour)`` for one edge.

    Colours: grey = non-GO edge (no probability), blue = known GO edge without
    a prediction, purple = prediction confirmed by a known edge, red =
    prediction without a known edge.
    """
    if probability is None:
        return relation_label, '#666666'
    if probability == 'no_pred':
        return f"{relation_label} (P=Not generated)", '#219ebc'
    color = '#8338ec' if is_ground_truth else '#c1121f'
    return f"{relation_label} (P={probability:.2f})", color


def _add_neighbor_node(net, node_id, node_type, added_nodes):
    """Add a type-styled neighbour node to ``net`` unless it is already in ``added_nodes``."""
    if node_id in added_nodes:
        return
    net.add_node(node_id,
                 label=f"{node_id}",
                 shape="dot",
                 font={'color': '#000000', 'size': 12},
                 title=f"{node_type}: {node_id}",
                 group=node_type,
                 size=15,
                 mass=1.5)
    added_nodes.add(node_id)


def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
    """Render a protein's directly connected neighbours (and GO predictions) to HTML.

    Args:
        data: Heterogeneous graph passed to ``_gather_protein_edges``.
        protein_id: ID of the protein to centre the view on. It must be a key of
            ``data['Protein']['id_mapping']``.
        prediction_df: GO prediction table passed to ``_filter_edges``.
        limit: Maximum number of edges drawn per edge type.

    Returns:
        ``(file_path, visualized_edges)``: the path of the written HTML file
        (``temp_viz/<protein_id>_graph.html``) and the per-edge-type edge lists
        returned by ``_filter_edges``.

    Raises:
        KeyError: if ``protein_id`` is unknown (raised by ``_gather_protein_edges``).
    """
    protein_edges = _gather_protein_edges(data, protein_id)
    visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
    print(f'Edges to be visualized: {visualized_edges}')

    net = Network(height="600px", width="100%", directed=True, notebook=False)
    net.set_options(_network_options_json())

    # Neighbour IDs are compared as strings below, so the centre node's ID
    # must be a string too, or the protein could be added twice.
    center_id = str(protein_id)
    net.add_node(center_id,
                 label=f"Protein: {protein_id}",
                 color={'background': 'white', 'border': '#c1121f'},
                 borderWidth=4,
                 shape="dot",
                 font={'color': '#000000', 'size': 15},
                 group='Protein',
                 size=30,
                 mass=2.5)
    added_nodes = {center_id}

    for edge_type, edges in visualized_edges.items():
        source_type, _, target_type = edge_type
        relation_label = _edge_relation_label(edge_type)
        for (source, target), probability, is_ground_truth in edges:
            source_str, target_str = str(source), str(target)
            _add_neighbor_node(net, source_str, source_type, added_nodes)
            _add_neighbor_node(net, target_str, target_type, added_nodes)
            edge_label, edge_color = _edge_style(relation_label, probability, is_ground_truth)
            # The label text is hidden (font size 0) and repeated in `title`
            # so it appears as the hover tooltip instead.
            net.add_edge(source_str, target_str,
                         label=edge_label,
                         font={'size': 0},
                         color=edge_color,
                         title=edge_label,
                         length=200,
                         smooth={'type': 'curvedCW', 'roundness': 0.1})

    # protein_id is used in the file name. It has already been validated as a
    # key of the graph's protein ID mapping (KeyError above otherwise).
    os.makedirs('temp_viz', exist_ok=True)
    file_path = os.path.join('temp_viz', f'{protein_id}_graph.html')
    net.save_graph(file_path)
    return file_path, visualized_edges