Spaces:
Runtime error
Runtime error
| import os | |
| import xml.etree.ElementTree as ET | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import List, Dict, Any | |
| from collections import defaultdict | |
| from accelerate import Accelerator | |
class DynamicModel(nn.Module):
    """Feed-forward network assembled from per-section layer specifications.

    ``sections`` maps a section name to an ordered list of layer-spec dicts.
    Each spec must provide ``'input_size'`` and ``'output_size'`` (the
    ``nn.Linear`` dimensions) and may provide ``'activation'``:
    ``'relu'`` (the default), ``'tanh'`` or ``'sigmoid'`` wrap the linear layer
    in an ``nn.Sequential`` with that activation; any other value yields a
    bare ``nn.Linear``.

    ``forward`` runs every layer of every section in insertion order, so the
    whole model is a single sequential chain.

    Section names must be valid ``nn.ModuleDict`` keys.

    Raises:
        ValueError: if a layer's ``input_size`` differs from the preceding
            layer's ``output_size`` (checked across section boundaries too).
            Such a model could never run ``forward``, so it is rejected at
            construction with a readable message instead of a matmul error.
    """

    # Activation name -> module class appended after the Linear layer.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
    }

    def __init__(self, sections: Dict[str, List[Dict[str, Any]]]):
        super().__init__()
        self.sections = nn.ModuleDict()
        prev_output = None  # output_size of the most recently built layer
        for section_name, layers in sections.items():
            self.sections[section_name] = nn.ModuleList()
            for index, layer_params in enumerate(layers):
                in_size = layer_params['input_size']
                if prev_output is not None and in_size != prev_output:
                    raise ValueError(
                        f"Layer {index} of section {section_name!r} expects "
                        f"input_size={in_size}, but the preceding layer "
                        f"produces {prev_output} features"
                    )
                self.sections[section_name].append(self.create_layer(layer_params))
                prev_output = layer_params['output_size']

    def create_layer(self, layer_params: Dict[str, Any]) -> nn.Module:
        """Build one Linear layer, followed by its activation when recognised."""
        layer = nn.Linear(layer_params['input_size'], layer_params['output_size'])
        activation_cls = self._ACTIVATIONS.get(layer_params.get('activation', 'relu'))
        if activation_cls is None:
            # Unknown / absent activation name: plain linear layer.
            return layer
        return nn.Sequential(layer, activation_cls())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every layer of every section, in insertion order."""
        for layers in self.sections.values():
            for layer in layers:
                x = layer(x)
        return x
def parse_xml_file(
    file_path: str,
    input_size: int = 128,
    output_size: int = 256,
) -> List[Dict[str, Any]]:
    """Build one placeholder layer spec per ``<prov>`` element in an XML file.

    Every ``<prov>`` element below the document root, at any depth, yields a
    spec dict with ``'input_size'``, ``'output_size'`` and ``'activation'``
    (always ``'relu'``). The XML is only used to count ``<prov>`` elements;
    the sizes are placeholders.

    The specs are meant to be stacked sequentially. Only the first layer
    takes ``input_size``; every later layer takes ``output_size`` as its
    input, so consecutive layers have matching dimensions. Giving every
    layer ``input_size`` made any file with 2+ ``<prov>`` elements
    unrunnable.

    Args:
        file_path: Path of the XML file to read.
        input_size: Input width of the first layer.
        output_size: Output width of every layer.

    Returns:
        The list of layer-spec dicts; empty when the file has no ``<prov>``.

    Raises:
        xml.etree.ElementTree.ParseError: if the file is not well-formed XML.
        OSError: if the file cannot be opened.
    """
    # NOTE(review): xml.etree is not hardened against maliciously crafted XML
    # (see "XML vulnerabilities" in the Python docs); only parse trusted files.
    root = ET.parse(file_path).getroot()
    layers = []
    current_input = input_size
    for _ in root.findall('.//prov'):
        layers.append({
            'input_size': current_input,
            'output_size': output_size,
            'activation': 'relu',  # Default activation
        })
        # The next layer consumes this layer's output.
        current_input = output_size
    return layers
def _section_key(directory: str) -> str:
    """Turn a walked directory path into a name ``nn.ModuleDict`` accepts."""
    # basename() of a path ending in a separator is '', so normalise first.
    key = os.path.basename(os.path.normpath(directory)).replace('.', '_') or 'root'
    # ModuleDict rejects keys that shadow an existing attribute or method
    # (e.g. a directory called "train", "eval" or "keys"), so disambiguate.
    if hasattr(nn.ModuleDict(), key):
        key += '_section'
    return key


def _chain_layer_sizes(
    sections: Dict[str, List[Dict[str, Any]]],
) -> Dict[str, List[Dict[str, Any]]]:
    """Return a copy of ``sections`` whose layer sizes chain end to end.

    The first layer keeps its own ``input_size``. Every later layer (across
    section boundaries, in iteration order) gets the previous layer's
    ``output_size`` as its ``input_size``. The input dicts are not mutated.
    """
    chained: Dict[str, List[Dict[str, Any]]] = {}
    prev_output = None
    for name, layers in sections.items():
        chained[name] = []
        for params in layers:
            params = dict(params)
            if prev_output is not None:
                params['input_size'] = prev_output
            chained[name].append(params)
            prev_output = params['output_size']
    return chained


def create_model_from_folder(folder_path: str) -> DynamicModel:
    """Build a DynamicModel from every ``.xml`` file under ``folder_path``.

    Files are grouped into sections by the name of the directory that
    contains them. Directories and files are visited in sorted order, so the
    resulting layer order is reproducible.

    Files that cannot be read or parsed are reported and skipped. Files
    without any ``<prov>`` element contribute no section. Because
    ``DynamicModel`` runs all layers as one chain, layer input sizes are
    rewired so each matches the previous layer's output size.

    Args:
        folder_path: Root directory to scan recursively.

    Returns:
        The assembled model. It has no sections if nothing usable was found.
    """
    sections: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for root, dirs, files in os.walk(folder_path):
        # os.walk yields entries in filesystem order. Sorting dirs in place
        # also fixes the order in which subdirectories are descended into.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith('.xml'):
                continue
            file_path = os.path.join(root, file)
            try:
                layers = parse_xml_file(file_path)
            except (ET.ParseError, OSError) as e:
                # Best effort: report the bad file and keep going.
                print(f"Error processing {file_path}: {str(e)}")
                continue
            if layers:
                sections[_section_key(root)].extend(layers)
    return DynamicModel(_chain_layer_sizes(sections))
def main(folder_path: str = 'Xml_Data', num_epochs: int = 10) -> None:
    """Build a model from XML files, smoke-test it, then train on random data.

    Args:
        folder_path: Directory scanned recursively for ``.xml`` files.
        num_epochs: Number of passes over the synthetic training dataset.

    Raises:
        SystemExit: if no layers could be built from ``folder_path``.
    """
    # Create the accelerator first so all console output can go through
    # accelerator.print, which prints only on the main process.
    accelerator = Accelerator()

    model = create_model_from_folder(folder_path)
    accelerator.print(f"Created dynamic PyTorch model with sections: {list(model.sections.keys())}")

    # The first nn.Linear in registration order is the first layer applied in
    # forward(), so it fixes the expected input width. Searching modules()
    # works both for layers wrapped in nn.Sequential and for bare nn.Linear
    # layers; indexing `[0]` fails on the latter.
    first_linear = next(
        (module for module in model.modules() if isinstance(module, nn.Linear)),
        None,
    )
    if first_linear is None:
        raise SystemExit(
            f"No layers were built from XML files under {folder_path!r}; "
            "nothing to run."
        )
    input_features = first_linear.in_features

    # Smoke test: one forward pass on a random sample of the expected width.
    sample_input = torch.randn(1, input_features)
    output = model(sample_input)
    accelerator.print(f"Sample output shape: {output.shape}")

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # NOTE(review): layers built with an activation end in it (ReLU by
    # default), so the values fed to CrossEntropyLoss are not raw logits.
    # Consider a final layer with no activation.
    criterion = nn.CrossEntropyLoss()

    # Synthetic binary-labelled dataset. Labels 0/1 require the final layer
    # to produce at least 2 output features.
    dataset = torch.utils.data.TensorDataset(
        torch.randn(100, input_features),
        torch.randint(0, 2, (100,)),
    )
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
    )

    # Wrap model, optimizer and loader for the configured (possibly
    # distributed) setup.
    model, optimizer, train_dataloader = accelerator.prepare(
        model,
        optimizer,
        train_dataloader,
    )

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            accelerator.backward(loss)
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_dataloader)
        accelerator.print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")


if __name__ == "__main__":
    main()