Spaces:
Sleeping
Sleeping
| from pyvis.network import Network | |
| import os | |
| import json | |
| import gzip | |
# Fill colour of every node type. Insertion order matters: the legend in
# visualize_protein_subgraph lists node types in this order.
NODE_TYPE_COLORS = dict(
    Disease='#079dbb',
    HPO='#58d0e8',
    Drug='#815ac0',
    Compound='#d2b7e5',
    Domain='#6bbf59',
    GO_term_P='#ff8800',
    GO_term_F='#ffaa00',
    GO_term_C='#ffc300',
    Pathway='#720026',
    kegg_Pathway='#720026',  # deliberately the same colour as 'Pathway'
    EC_number='#ce4257',
    Protein='#3aa6a4',
)

# Human-readable verb for each relation name. The generic 'protein_function'
# relation is disambiguated by its GO target type, hence the tuple keys.
EDGE_LABEL_TRANSLATION = {
    **dict(
        Orthology='is ortholog to',
        Pathway='takes part in',
        kegg_path_prot='takes part in',
        protein_domain='has',
        PPI='interacts with',
        HPO='is associated with',
        kegg_dis_prot='is related to',
        Disease='is related to',
        Drug='targets',
        protein_ec='catalyzes',
        Chembl='targets',
    ),
    **{
        ('protein_function', go_type): verb
        for go_type, verb in (
            ('GO_term_F', 'enables'),
            ('GO_term_P', 'is involved in'),
            ('GO_term_C', 'localizes to'),
        )
    },
}

# Display name for node types whose internal name is not user friendly;
# types missing here are shown under their internal name.
NODE_LABEL_TRANSLATION = dict(
    HPO='Phenotype',
    GO_term_P='Biological Process',
    GO_term_F='Molecular Function',
    GO_term_C='Cellular Component',
    kegg_Pathway='Pathway',
    EC_number='EC Number',
)

# GO category name (as used in prediction tables) -> GO node type.
GO_CATEGORY_MAPPING = {
    NODE_LABEL_TRANSLATION[go_type]: go_type
    for go_type in ('GO_term_P', 'GO_term_F', 'GO_term_C')
}
def get_node_url(node_type, node_id):
    """Return the external database URL for a node, or ``None`` if unknown.

    GO terms (any type starting with ``'GO_term'``) link to QuickGO, proteins to
    UniProt, drugs to DrugBank, and so on. Disease IDs are handled specially:
    ``EFO:``, ``MONDO:`` and ``Orphanet:`` prefixed IDs link to the respective
    ontology page. Disease IDs without a colon link to a genome.jp entry page.
    Other prefixed disease IDs, and unknown node types, give ``None``.
    """
    if node_type.startswith('GO_term'):
        return f"https://www.ebi.ac.uk/QuickGO/term/{node_id}"

    if node_type == 'Disease':
        if ':' not in node_id:
            return f"https://www.genome.jp/entry/{node_id}"
        ontology, accession = node_id.split(':')[0], node_id.split(':')[1]
        ontology_prefixes = {
            'EFO': "http://www.ebi.ac.uk/efo/EFO_",
            'MONDO': 'http://purl.obolibrary.org/obo/MONDO_',
            'Orphanet': "http://www.orpha.net/ORDO/Orphanet_",
        }
        prefix = ontology_prefixes.get(ontology)
        return None if prefix is None else prefix + accession

    url_templates = {
        'Protein': "https://www.uniprot.org/uniprotkb/{}/entry",
        'HPO': "https://hpo.jax.org/browse/term/{}",
        'Drug': "https://go.drugbank.com/drugs/{}",
        'Compound': "https://www.ebi.ac.uk/chembl/explore/compound/{}",
        'Domain': "https://www.ebi.ac.uk/interpro/entry/InterPro/{}",
        'Pathway': "https://reactome.org/content/detail/{}",
        'kegg_Pathway': "https://www.genome.jp/pathway/{}",
        'EC_number': "https://enzyme.expasy.org/EC/{}",
    }
    template = url_templates.get(node_type)
    return None if template is None else template.format(node_id)
| def _gather_protein_edges(data, protein_id): | |
| protein_idx = data['Protein']['id_mapping'][protein_id] | |
| reverse_id_mapping = {} | |
| for node_type in data.node_types: | |
| reverse_id_mapping[node_type] = {v:k for k, v in data[node_type]['id_mapping'].items()} | |
| protein_edges = {} | |
| print(f'Gathering edges for {protein_id}...') | |
| for edge_type in data.edge_types: | |
| if 'rev' not in edge_type[1]: | |
| if edge_type not in protein_edges: | |
| protein_edges[edge_type] = [] | |
| if edge_type[0] == 'Protein': | |
| print(f'Gathering edges for {edge_type}...') | |
| # append the edges with protein_idx as source node | |
| edges = data[edge_type].edge_index[:, data[edge_type].edge_index[0] == protein_idx] | |
| protein_edges[edge_type].extend(edges.T.tolist()) | |
| elif edge_type[2] == 'Protein': | |
| print(f'Gathering edges for {edge_type}...') | |
| # append the edges with protein_idx as target node | |
| edges = data[edge_type].edge_index[:, data[edge_type].edge_index[1] == protein_idx] | |
| protein_edges[edge_type].extend(edges.T.tolist()) | |
| for edge_type in protein_edges.keys(): | |
| if protein_edges[edge_type]: | |
| mapped_edges = set() | |
| for edge in protein_edges[edge_type]: | |
| # Get source and target node types from edge_type | |
| source_type, _, target_type = edge_type | |
| # Map indices back to original IDs | |
| source_id = reverse_id_mapping[source_type][edge[0]] | |
| target_id = reverse_id_mapping[target_type][edge[1]] | |
| mapped_edges.add((source_id, target_id)) | |
| protein_edges[edge_type] = mapped_edges | |
| return protein_edges | |
| def _filter_edges(protein_id, protein_edges, prediction_df, limit=10): | |
| filtered_edges = {} | |
| prediction_categories = prediction_df['GO_category'].unique() | |
| prediction_categories = [GO_CATEGORY_MAPPING[category] for category in prediction_categories] | |
| go_category_reverse_mapping = {v:k for k, v in GO_CATEGORY_MAPPING.items()} | |
| for edge_type, edges in protein_edges.items(): | |
| # Skip if edges is empty | |
| if edges is None or len(edges) == 0: | |
| continue | |
| if edge_type[2].startswith('GO_term'): # Check if it's any GO term edge | |
| if edge_type[2] in prediction_categories: | |
| # Handle edges for GO terms that are in prediction_df | |
| category_mask = (prediction_df['GO_category'] == go_category_reverse_mapping[edge_type[2]]) & (prediction_df['UniProt_ID'] == protein_id) | |
| category_predictions = prediction_df[category_mask] | |
| if len(category_predictions) > 0: | |
| category_predictions = category_predictions.sort_values(by='Probability', ascending=False) | |
| edges_set = set(edges) # Convert to set for O(1) lookup | |
| valid_edges = [] | |
| for _, row in category_predictions.iterrows(): | |
| term = row['GO_ID'] | |
| prob = row['Probability'] | |
| edge = (protein_id, term) | |
| is_ground_truth = edge in edges_set | |
| valid_edges.append((edge, prob, is_ground_truth)) | |
| if len(valid_edges) >= limit: | |
| break | |
| filtered_edges[edge_type] = valid_edges | |
| else: | |
| # If no predictions but it's a GO category in prediction_df | |
| filtered_edges[edge_type] = [(edge, 'no_pred', True) for edge in list(edges)[:limit]] | |
| else: | |
| # For GO terms not in prediction_df, mark them as ground truth with blue color | |
| filtered_edges[edge_type] = [(edge, 'no_pred', True) for edge in list(edges)[:limit]] | |
| else: | |
| # For non-GO edges, include all edges up to limit | |
| filtered_edges[edge_type] = [(edge, None, True) for edge in list(edges)[:limit]] | |
| return filtered_edges | |
# Edge colours. The edge legend is generated from the same values, so the two
# cannot drift apart.
_EDGE_COLOR_CONFIRMED = '#8338ec'  # predicted and present in the ground truth
_EDGE_COLOR_NOVEL = '#c1121f'  # predicted but absent from the ground truth
_EDGE_COLOR_EXISTING = '#219ebc'  # GO annotation taken directly from the database
_EDGE_COLOR_OTHER = '#666666'  # any non-GO relationship

_EDGE_LEGEND = (
    (_EDGE_COLOR_CONFIRMED, 'Confirmed Prediction (Found in Ground Truth)'),
    (_EDGE_COLOR_NOVEL, 'Novel Prediction (Not in Ground Truth)'),
    (_EDGE_COLOR_EXISTING, 'Existing GO Term Annotation'),
    (_EDGE_COLOR_OTHER, 'Other Relationships'),
)

# vis.js network options (physics tuned for clustering). The per-type "groups"
# entry is added at call time from NODE_TYPE_COLORS.
_NETWORK_OPTIONS = {
    "physics": {
        "enabled": True,
        "barnesHut": {
            "gravitationalConstant": -1000,
            "springLength": 250,
            "springConstant": 0.001,
            "damping": 0.09,
            "avoidOverlap": 0,
        },
        "forceAtlas2Based": {
            "gravitationalConstant": -50,
            "centralGravity": 0.01,
            "springLength": 100,
            "springConstant": 0.08,
            "damping": 0.4,
            "avoidOverlap": 0,
        },
        "solver": "barnesHut",
        "stabilization": {
            "enabled": True,
            "iterations": 1000,
            "updateInterval": 25,
        },
    },
    "layout": {
        "improvedLayout": True,
        "hierarchical": {"enabled": False},
    },
    "interaction": {
        "hover": True,
        "navigationButtons": True,
        "multiselect": True,
    },
    "configure": {
        "enabled": False,
        "filter": ["physics", "layout", "manipulation"],
        "showButton": True,
    },
}

# Styles plus the opening markup of the legend, up to the node-type grid.
_LEGEND_HEAD = """
<style>
.kg-legend {
    margin-top: 20px;
    padding: 20px;
    border: 1px solid #ddd;
    border-radius: 5px;
    font-family: Arial, sans-serif;
    display: flex;
    gap: 20px;
}
.legend-section-nodes {
    flex: 2; /* Takes up 2/3 of the space */
}
.legend-section-edges {
    flex: 1; /* Takes up 1/3 of the space */
}
.legend-title {
    margin-bottom: 15px;
    color: #333;
    font-size: 16px;
    font-weight: bold;
}
.nodes-grid {
    display: grid;
    grid-template-columns: repeat(2, 1fr);
    gap: 12px;
}
.edges-grid {
    display: grid;
    grid-template-columns: 1fr;
    gap: 12px;
}
.legend-item {
    display: flex;
    align-items: center;
    padding: 4px;
}
.node-indicator {
    width: 15px;
    height: 15px;
    border-radius: 50%;
    margin-right: 10px;
    flex-shrink: 0;
}
.edge-indicator {
    width: 40px;
    height: 3px;
    margin-right: 10px;
    flex-shrink: 0;
}
.legend-label {
    font-size: 14px;
}
</style>
<div class="kg-legend">
<div class="legend-section-nodes">
<div class="legend-title">Node Types</div>
<div class="nodes-grid">"""

# Custom hover popup injected before pyvis' "return network;" so that the
# <a href> links in node titles stay clickable.
_CUSTOM_POPUP_JS = """
// make a custom popup
var popup = document.createElement("div");
popup.className = 'popup';
popupTimeout = null;
popup.addEventListener('mouseover', function () {
    if (popupTimeout !== null) {
        clearTimeout(popupTimeout);
        popupTimeout = null;
    }
});
popup.addEventListener('mouseout', function () {
    if (popupTimeout === null) {
        hidePopup();
    }
});
container.appendChild(popup);
// use the popup event to show
network.on("showPopup", function (params) {
    showPopup(params);
});
// use the hide event to hide it
network.on("hidePopup", function (params) {
    hidePopup();
});
// hiding the popup through css
function hidePopup() {
    popupTimeout = setTimeout(function () { popup.style.display = 'none'; }, 500);
}
// showing the popup
function showPopup(nodeId) {
    // get the data from the vis.DataSet
    var nodeData = nodes.get(nodeId);
    // get the position of the node
    var posCanvas = network.getPositions([nodeId])[nodeId];
    if (!nodeData) {
        var edgeData = edges.get(nodeId);
        var poses = network.getPositions([edgeData.from, edgeData.to]);
        var middle_x = (poses[edgeData.to].x - poses[edgeData.from].x) * 0.5;
        var middle_y = (poses[edgeData.to].y - poses[edgeData.from].y) * 0.5;
        posCanvas = poses[edgeData.from];
        posCanvas.x = posCanvas.x + middle_x;
        posCanvas.y = posCanvas.y + middle_y;
        popup.innerHTML = edgeData.title;
    } else {
        popup.innerHTML = nodeData.title;
        // get the bounding box of the node
        var boundingBox = network.getBoundingBox(nodeId);
        posCanvas.x = posCanvas.x + 0.5 * (boundingBox.right - boundingBox.left);
        posCanvas.y = posCanvas.y + 0.5 * (boundingBox.top - boundingBox.bottom);
    };
    //position tooltip:
    // convert coordinates to the DOM space
    var posDOM = network.canvasToDOM(posCanvas);
    // Give it an offset
    posDOM.x += 10;
    posDOM.y -= 20;
    // show and place the tooltip.
    popup.style.display = 'block';
    popup.style.top = posDOM.y + 'px';
    popup.style.left = posDOM.x + 'px';
}
"""

_CUSTOM_POPUP_CSS = """
/* position absolute is important and the container has to be relative or absolute as well. */
div.popup {
    position: absolute;
    top: 0px;
    left: 0px;
    display: none;
    background-color: white;
    border-radius: 3px;
    border: 1px solid #ddd;
    box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2);
    padding: 5px;
    z-index: 1000;
}
"""

# Text of the "hide the original tooltip" rule that gets stripped from the
# generated page. The match is exact and whitespace-sensitive, so it is a no-op
# if the pyvis template formats the rule differently.
# NOTE(review): confirm against the installed pyvis template.
_ORIGINAL_TOOLTIP_CSS = """
/* hide the original tooltip */
.vis-network-tooltip {
display:none;
}"""


def _node_title(name_info, node_type, node_id, role=None):
    """Build a node's hover title ``"<name> (<role>)"``, linked when a URL exists.

    All GO types share the ``'GO_term'`` bucket of ``name_info``. A missing
    name falls back to the node ID instead of raising ``KeyError``. ``role``
    defaults to the display name of ``node_type``.
    """
    bucket = 'GO_term' if node_type.startswith('GO_term') else node_type
    node_name = name_info.get(bucket, {}).get(node_id, node_id)
    if role is None:
        role = NODE_LABEL_TRANSLATION.get(node_type, node_type)
    title = f"{node_name} ({role})"
    url = get_node_url(node_type, node_id)
    if url:
        title = f'<a href="{url}" target="_blank">{title}</a>'
    return title


def _relation_label(relation, target_type):
    """Translate a relation name into a verb; unknown relations keep their raw name."""
    if relation == 'protein_function':
        # protein_function is disambiguated by the GO target type.
        return EDGE_LABEL_TRANSLATION.get((relation, target_type), relation)
    return EDGE_LABEL_TRANSLATION.get(relation, relation)


def _edge_style(relation_label, probability, is_ground_truth):
    """Return ``(label, colour)`` for an edge given its provenance tag from _filter_edges."""
    if probability is None:  # non-GO relationship
        return relation_label, _EDGE_COLOR_OTHER
    if probability == 'no_pred':  # existing annotation, no prediction made
        return f"{relation_label} (P=Not generated)", _EDGE_COLOR_EXISTING
    color = _EDGE_COLOR_CONFIRMED if is_ground_truth else _EDGE_COLOR_NOVEL
    return f"{relation_label} (P={probability:.2f})", color


def _build_legend_html():
    """Return the legend markup (node-type colours and edge colours) as an HTML string."""
    parts = [_LEGEND_HEAD]
    for node_type, color in NODE_TYPE_COLORS.items():
        if node_type == 'kegg_Pathway':  # same colour and display name as 'Pathway'
            continue
        node_label = NODE_LABEL_TRANSLATION.get(node_type, node_type)
        parts.append(f"""
<div class="legend-item">
<div class="node-indicator" style="background-color: {color};"></div>
<span class="legend-label">{node_label}</span>
</div>""")
    parts.append("""
</div>
</div>
<div class="legend-section-edges">
<div class="legend-title">Edge Colors</div>
<div class="edges-grid">""")
    for color, label in _EDGE_LEGEND:
        parts.append(f"""
<div class="legend-item">
<div class="edge-indicator" style="background-color: {color};"></div>
<span class="legend-label">{label}</span>
</div>""")
    # Close edges-grid, legend-section-edges and the outer kg-legend container.
    parts.append("""
</div>
</div>
</div>
""")
    return ''.join(parts)


def _postprocess_html(file_path):
    """Inject the custom popup and the legend into the pyvis-generated page in place."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    # str.replace hits every occurrence; the legend (which has its own
    # </style>) is inserted last so it is not affected.
    content = content.replace('</style>', f'{_CUSTOM_POPUP_CSS}</style>')
    content = content.replace('return network;', f'{_CUSTOM_POPUP_JS}\nreturn network;')
    content = content.replace(_ORIGINAL_TOOLTIP_CSS, "")
    content = content.replace('</body>', f'{_build_legend_html()}</body>')
    with open(file_path, 'w', encoding='utf-8') as f:
        f.write(content)


def visualize_protein_subgraph(data, protein_id, prediction_df, limit=10):
    """Render a protein's neighbourhood and GO predictions as an interactive HTML graph.

    Args:
        data: graph object accepted by ``_gather_protein_edges``.
        protein_id: ID of the query protein (a key of
            ``data['Protein']['id_mapping']``).
        prediction_df: GO prediction table passed to ``_filter_edges``.
        limit: maximum number of edges drawn per edge type.

    Returns:
        ``(file_path, visualized_edges)``. ``file_path`` is the written page
        ``temp_viz/<protein_id>_graph.html``; ``visualized_edges`` is the
        filtered edge dict that was drawn.

    Node names are read from ``data/name_info.json.gz``. Nodes without a name
    are labelled with their ID.
    """
    with gzip.open('data/name_info.json.gz', 'rt', encoding='utf-8') as file:
        name_info = json.load(file)
    protein_edges = _gather_protein_edges(data, protein_id)
    visualized_edges = _filter_edges(protein_id, protein_edges, prediction_df, limit)
    print(f'Edges to be visualized: {visualized_edges}')

    net = Network(height="600px", width="100%", directed=True, notebook=False)
    groups_config = {
        node_type: {"color": {"background": color, "border": color}}
        for node_type, color in NODE_TYPE_COLORS.items()
    }
    net.set_options(json.dumps({**_NETWORK_OPTIONS, "groups": groups_config}))

    # Query protein: white fill, red border, larger and heavier than neighbours.
    net.add_node(protein_id,
                 label=protein_id,
                 title=_node_title(name_info, 'Protein', protein_id, role='Query Protein'),
                 color={'background': 'white', 'border': '#c1121f'},
                 borderWidth=4,
                 shape="dot",
                 font={'color': '#000000', 'size': 15},
                 group='Protein',
                 size=30,
                 mass=2.5)
    added_nodes = {protein_id}  # avoid adding a node twice

    for edge_type, edges in visualized_edges.items():
        source_type, relation, target_type = edge_type
        relation_label = _relation_label(relation, target_type)
        for edge, probability, is_ground_truth in edges:
            source_str, target_str = str(edge[0]), str(edge[1])
            for node_id, node_type in ((source_str, source_type), (target_str, target_type)):
                if node_id in added_nodes:
                    continue
                net.add_node(node_id,
                             label=node_id,
                             shape="dot",
                             font={'color': '#000000', 'size': 12},
                             title=_node_title(name_info, node_type, node_id),
                             group=node_type,
                             size=15,
                             mass=1.5)
                added_nodes.add(node_id)
            edge_label, edge_color = _edge_style(relation_label, probability, is_ground_truth)
            net.add_edge(source_str, target_str,
                         label=edge_label,
                         font={'size': 0},  # label only shown on hover via title
                         color=edge_color,
                         title=edge_label,
                         length=200,
                         smooth={'type': 'curvedCW', 'roundness': 0.1})

    # Save graph to a protein-specific file in a temporary directory.
    os.makedirs('temp_viz', exist_ok=True)
    file_path = os.path.join('temp_viz', f'{protein_id}_graph.html')
    net.save_graph(file_path)
    _postprocess_html(file_path)
    return file_path, visualized_edges