Spaces:
Runtime error
Runtime error
| import plotly.graph_objects as go | |
| import networkx as nx | |
| import networkx as nx | |
| from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges, | |
| Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource) | |
| from bokeh.palettes import Spectral4 | |
| from bokeh.plotting import from_networkx | |
def create_bokeh_plot(entities, relationships):
    """Build an interactive Bokeh plot of a knowledge graph.

    Args:
        entities: Mapping of entity id -> dict with ``'value'`` and ``'type'``
            keys. Each entity becomes a node labelled ``"value (type)"``.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Each becomes an undirected edge labelled with ``relation``.

    Returns:
        A ``bokeh.models.Plot`` showing the spring-laid-out graph, with node
        labels drawn at node positions and relation labels at edge midpoints.
    """
    # Create a NetworkX graph; the 'label' attributes end up as columns in the
    # node/edge data sources built by from_networkx.
    G = nx.Graph()
    for entity_id, entity_data in entities.items():
        G.add_node(entity_id, label=f"{entity_data['value']} ({entity_data['type']})")
    for source, relation, target in relationships:
        G.add_edge(source, target, label=relation)
    # add_edge silently creates nodes for ids missing from `entities`; give
    # them their id as label so the label lookups below cannot KeyError.
    for node, node_data in G.nodes(data=True):
        node_data.setdefault("label", str(node))

    plot = Plot(width=600, height=600,  # Increased size for better visibility
                x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
    plot.title.text = "Knowledge Graph Interaction"

    # Keyword arguments after the layout function are forwarded to it.
    graph_renderer = from_networkx(G, nx.spring_layout, scale=1, k=0.5,
                                   iterations=50, center=(0, 0))

    graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
    graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()
    plot.renderers.append(graph_renderer)

    # NodesAndLinkedEdges hit-tests nodes only, so hover tooltips always read
    # the node data source. A separate "Relation" hover tool would just repeat
    # the node label; relation names are drawn as edge labels below instead.
    node_hover = HoverTool(tooltips=[("Entity", "@label")], renderers=[graph_renderer])
    plot.add_tools(node_hover, TapTool(), BoxSelectTool())

    layout = graph_renderer.layout_provider.graph_layout

    # Node labels: look each node up in the layout explicitly (this also works
    # for an empty graph, where unpacking zip(*layout.values()) would fail).
    nodes = list(G.nodes())
    node_label_source = ColumnDataSource({
        'x': [layout[node][0] for node in nodes],
        'y': [layout[node][1] for node in nodes],
        'label': [G.nodes[node]['label'] for node in nodes],
    })
    plot.add_layout(LabelSet(x='x', y='y', text='label', source=node_label_source,
                             background_fill_color='white', text_font_size='8pt',
                             background_fill_alpha=0.7))

    # Edge labels: place each relation name at the midpoint of its edge.
    edge_x, edge_y, edge_text = [], [], []
    for start_node, end_node, relation in G.edges(data='label'):
        (x0, y0), (x1, y1) = layout[start_node], layout[end_node]
        edge_x.append((x0 + x1) / 2)
        edge_y.append((y0 + y1) / 2)
        edge_text.append(relation)
    edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_text})
    plot.add_layout(LabelSet(x='x', y='y', text='label', source=edge_label_source,
                             background_fill_color='white', text_font_size='8pt',
                             background_fill_alpha=0.7))

    return plot
| # def create_bokeh_plot(entities, relationships): | |
| # # Create a NetworkX graph | |
| # G = nx.Graph() | |
| # for entity_id, entity_data in entities.items(): | |
| # G.add_node(entity_id, **entity_data) | |
| # for source, relation, target in relationships: | |
| # G.add_edge(source, target) | |
| # # Create a Bokeh plot | |
| # plot = figure(title="Knowledge Graph", x_range=(-1.1,1.1), y_range=(-1.1,1.1), | |
| # width=400, height=400, tools="pan,wheel_zoom,box_zoom,reset") | |
| # # Create graph renderer | |
| # graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0,0)) | |
| # # Add graph renderer to plot | |
| # plot.renderers.append(graph_renderer) | |
| # return plot | |
def create_plotly_plot(entities, relationships):
    """Build a Plotly figure of a directed knowledge graph.

    Args:
        entities: Mapping of entity id -> dict with ``'value'`` and ``'type'``
            keys; every key is also stored as a node attribute.
        relationships: Iterable of ``(source, relation, target)`` triples, each
            added as a directed edge ``source -> target``.

    Returns:
        A ``go.Figure`` with three traces: the edge lines, the nodes (coloured
        by how many edges touch each node) and the relation labels placed at
        edge midpoints.
    """
    G = nx.DiGraph()  # Use DiGraph for directed edges
    for entity_id, entity_data in entities.items():
        G.add_node(entity_id, **entity_data)
    for source, relation, target in relationships:
        G.add_edge(source, target, relation=relation)

    pos = nx.spring_layout(G, k=0.5, iterations=50)  # Adjust layout parameters

    # Gather coordinates in plain lists and build each trace once: growing a
    # trace property with `+=` re-validates the whole array on every append.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for start, end, relation in G.edges(data="relation"):
        x0, y0 = pos[start]
        x1, y1 = pos[end]
        # The trailing None breaks the polyline so segments are not joined.
        edge_x.extend((x0, x1, None))
        edge_y.extend((y0, y1, None))
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_connections = [], [], [], []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        if node in entities:
            node_text.append(f"{entities[node]['value']} ({entities[node]['type']})")
        else:
            # Ids that only appear in `relationships` have no entity data.
            node_text.append(str(node))
        # degree() counts incoming and outgoing edges; neighbors() on a DiGraph
        # would count successors only and report pure targets as unconnected.
        node_connections.append(G.degree(node))

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=1, color="#888"),
        hoverinfo="none",
        mode="lines",
    )

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        hoverinfo="text",
        text=node_text,
        textposition="top center",
        marker=dict(
            showscale=True,
            colorscale="Viridis",
            reversescale=True,
            color=node_connections,
            size=15,
            # Nested title object: the flat `titleside` alias was removed in
            # Plotly 6 and raises ValueError there.
            colorbar=dict(
                thickness=15,
                title=dict(text="Node Connections", side="right"),
                xanchor="left",
            ),
            line=dict(width=2),
        ),
    )

    # All relation labels in a single text-only trace instead of one per edge.
    label_trace = go.Scatter(
        x=label_x,
        y=label_y,
        mode="text",
        text=label_text,
        textposition="middle center",
        hoverinfo="none",
        showlegend=False,
        textfont=dict(size=8),
    )

    fig = go.Figure(
        data=[edge_trace, node_trace, label_trace],
        layout=go.Layout(
            # `titlefont_size` was removed in Plotly 6; use title.font instead.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # Equal aspect ratio. Anchor only y to x: anchoring both axes to
            # each other is circular and plotly.js drops one of them.
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                       scaleanchor="x", scaleratio=1),
            width=800,
            height=600,
            newshape=dict(line_color="#009900"),
        ),
    )
    return fig