| import torch | |
| import torch.nn as nn | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
# Make the directory containing this script the working directory, so that
# relative paths used later in the file resolve next to the script no matter
# where it was launched from.
script_path = os.path.abspath(__file__)
script_dir = os.path.split(script_path)[0]
os.chdir(script_dir)
class SimpleNN(nn.Module):
    """Fully connected network: two hidden ReLU layers with dropout, one output.

    Architecture: ``input_dim -> hidden_dim -> hidden_dim -> 1``. Dropout is
    applied after each hidden activation.

    The layer attribute names (``fc1``, ``fc2``, ``fc3``) determine the keys
    of the module's ``state_dict``; they are kept unchanged so existing
    checkpoints continue to load.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of both hidden layers. Defaults to 100.
        dropout: Dropout probability used after each hidden layer.
            Defaults to 0.5.
    """

    def __init__(self, input_dim, hidden_dim=100, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the network output for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1. No activation is applied to the output.
        """
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        return self.fc3(x)
# Number of input features; must match the first layer stored in the checkpoint.
input_dim = 51
model = SimpleNN(input_dim)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; the weights are converted to NumPy below anyway.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted
# source (consider weights_only=True where the installed torch supports it).
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
# Inference mode: disables dropout.
model.eval()
# One (out_features, in_features) array per linear layer, in forward order.
weights = [
    layer.weight.detach().cpu().numpy()
    for layer in (model.fc1, model.fc2, model.fc3)
]
# Node count per layer, derived from the weight shapes so it cannot drift
# from the actual model: input width, then each layer's output width.
layers = [weights[0].shape[1], *(w.shape[0] for w in weights)]
def _build_network_graph(layers, weights):
    """Build the layered graph, node positions, edge list and edge weights.

    Returns:
        A tuple ``(graph, pos, edges, edge_weights)`` where ``edges`` and
        ``edge_weights`` are parallel lists (same length, same order).
    """
    graph = nx.Graph()
    pos = {}
    layer_nodes = []
    for layer_idx, num_nodes in enumerate(layers):
        names = [f'L{layer_idx}_N{j}' for j in range(num_nodes)]
        for j, name in enumerate(names):
            # One column per layer, nodes stacked vertically around y = 0.
            pos[name] = (layer_idx, -j + num_nodes // 2)
        layer_nodes.append(names)
    # Add nodes explicitly so a node without any edge is still drawn.
    graph.add_nodes_from(pos)

    edges = []
    edge_weights = []
    for layer_idx, (sources, targets) in enumerate(zip(layer_nodes, layer_nodes[1:])):
        for j, src in enumerate(sources):
            for k, dst in enumerate(targets):
                edges.append((src, dst))
                # Row index is the target node, column index the source node.
                edge_weights.append(weights[layer_idx][k, j])
    graph.add_edges_from(edges)
    return graph, pos, edges, edge_weights


def draw_neural_network(layers, weights):
    """Draw a fully connected layered network with edges coloured by weight.

    Args:
        layers: Sequence of node counts, one entry per layer.
        weights: Sequence of 2-D arrays, one per pair of adjacent layers.
            ``weights[i][k, j]`` is the weight of the connection from node
            ``j`` of layer ``i`` to node ``k`` of layer ``i + 1``.

    Opens a matplotlib figure and blocks in ``plt.show()``.
    """
    graph, pos, edges, edge_weights = _build_network_graph(layers, weights)

    edge_style = {}
    if edges:
        # Pass the edge list explicitly: nx.draw otherwise iterates
        # graph.edges(), whose order (adjacency order) differs from the order
        # in which the edges were inserted, so the colour sequence would be
        # applied to the wrong edges.
        edge_style = dict(
            edgelist=edges,
            edge_color=edge_weights,
            edge_cmap=plt.cm.viridis,
            edge_vmin=min(edge_weights),
            edge_vmax=max(edge_weights),
        )

    plt.figure(figsize=(10, 10))
    nx.draw(graph, pos, with_labels=False, node_size=700,
            node_color='lightblue', width=2, **edge_style)
    # Label every node slightly above its centre.
    for name, (x, y) in pos.items():
        plt.text(x, y + 0.1, name, ha='center', va='center')
    plt.title("Neural Network Visualization")
    plt.show()
# Render the network described by the module-level layer sizes and weights.
draw_neural_network(layers=layers, weights=weights)