Spaces:
Sleeping
Sleeping
devjas1
(FEAT)[Enhanced Results Widget]: Integrate advanced probability breakdown, QC, and provenance export
fe030dd
| """ | |
| Diagnostic script to inspect the weights within a PyTorch .pth file. | |
| This utility loads a model's state dictionary and prints summary statistics | |
| (mean, std, min, max) for each parameter tensor. It helps diagnose issues | |
| like corrupted weights from failed or interrupted training runs, which might | |
| result in a model producing constant, incorrect outputs. | |
| Usage: | |
| python scripts/inspect_weights.py path/to/your/model_weights.pth | |
| """ | |
| import torch | |
| import argparse | |
| import os | |
| from pathlib import Path | |
| import sys | |
# Put the parent of the directory containing this file (the project root,
# given the documented `scripts/inspect_weights.py` location) on sys.path so
# project modules are importable when this script is run directly.
sys.path.append(os.fspath(Path(__file__).resolve().parent.parent))
def _load_state_dict(file_path: str):
    """Load ``file_path`` onto the CPU and unwrap a nested model state dict.

    Training checkpoints often store the weights under a ``"model_state_dict"``
    or ``"model"`` key. When the loaded object is a dict containing one of
    those keys (checked in that order), the nested value is returned.
    Otherwise the loaded object is returned unchanged, so callers must still
    verify that it is a dict.
    """
    cpu = torch.device("cpu")
    try:
        # weights_only=True refuses arbitrary pickled objects. This is safer
        # for files of unknown origin and suppresses torch's warning about it.
        checkpoint = torch.load(file_path, map_location=cpu, weights_only=True)
    except TypeError:
        # Older torch versions do not accept the ``weights_only`` keyword.
        # NOTE(review): this fallback is a full pickle load, which can run
        # arbitrary code embedded in the file. Only inspect trusted files.
        checkpoint = torch.load(file_path, map_location=cpu)
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "model"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def _tensor_stats(param: torch.Tensor):
    """Summarise one tensor for the report.

    Returns ``(cells, check_std, finite)``:

    * ``cells``: the formatted mean, std, min and max columns. A column is
      ``"N/A"`` when its value cannot be computed (empty tensor, unsupported
      dtype, or a single element for std).
    * ``check_std``: the standard deviation to feed the "dead weights" check.
      It is ``None`` when not meaningful (non-floating-point dtype, fewer
      than two elements, or non-finite values).
    * ``finite``: ``False`` if the tensor contains any NaN or Inf.
    """
    unavailable = ("N/A",) * 4
    if param.numel() == 0:
        # min()/max() raise on empty tensors and mean() would be NaN.
        return unavailable, None, True
    try:
        values = param.float()
        finite = bool(torch.isfinite(values).all())
        # With the default correction, torch returns NaN (and warns) for the
        # std of a single element, so skip it there.
        std = values.std().item() if values.numel() > 1 else None
        cells = (
            f"{values.mean().item():.4e}",
            "N/A" if std is None else f"{std:.4e}",
            f"{values.min().item():.4e}",
            f"{values.max().item():.4e}",
        )
    except (RuntimeError, TypeError):
        return unavailable, None, True
    # Integer buffers (e.g. step counters) say nothing about whether the
    # learned weights are dead, and NaN/Inf stds would poison the comparison.
    check_std = std if param.is_floating_point() and finite else None
    return cells, check_std, finite


def _print_diagnosis(check_stds, nonfinite_names, std_threshold):
    """Print the closing diagnosis from the collected per-tensor statistics."""
    print("\nDiagnosis:")
    print(
        "- If you see all zeros, NaNs, or very small (e.g., e-38) uniform values, the weights file is likely corrupted."
    )
    if nonfinite_names:
        print(
            f"- WARNING: {len(nonfinite_names)} tensor(s) contain NaN or Inf values "
            f"(first: {nonfinite_names[0]}). The weights file is likely corrupted."
        )
    if not check_stds:
        # all() over an empty sequence is True; don't report a dead model
        # when there was nothing to measure.
        print(
            "- No multi-element floating-point tensors with finite values were found; the standard-deviation check was skipped."
        )
    elif all(s < std_threshold for s in check_stds):
        print(
            "- WARNING: All parameter standard deviations are extremely low. The model may be 'dead' and insensitive to input."
        )
    elif not nonfinite_names:
        print(
            "- The weight statistics appear varied, suggesting the file is not corrupted with zeros/NaNs."
        )
        print(
            "- If the model still produces constant output, it is likely poorly trained."
        )
    print("\nRecommendation: Retraining the model is the correct solution.")


def inspect_weights(file_path: str, *, std_threshold: float = 1e-6):
    """
    Loads a model state_dict from a .pth file and prints statistics
    for each parameter tensor to help diagnose corrupted weights.

    Args:
        file_path: Path to a ``.pth`` file. It may hold a bare state dict or
            a checkpoint dict with the weights under ``"model_state_dict"``
            or ``"model"``.
        std_threshold: The model is flagged as possibly "dead" when every
            multi-element, finite, floating-point tensor has a standard
            deviation below this value. Defaults to the previously
            hard-coded 1e-6.

    Returns:
        None. All findings are printed to stdout. Load and inspection errors
        are reported (with a traceback on stderr) rather than raised.
    """
    if not os.path.exists(file_path):
        print(f"❌ Error: File not found at {file_path}")
        return
    print(f"🔍 Inspecting weights for: {file_path}\n")
    try:
        state_dict = _load_state_dict(file_path)
        if not isinstance(state_dict, dict):
            # e.g. a bare tensor or list saved with torch.save.
            print(
                f"❌ Error: Expected a state dictionary, but the file contains "
                f"a {type(state_dict).__name__}."
            )
            return
        if not state_dict:
            print("⚠️ State dictionary is empty.")
            return
        print(
            f"{'Parameter Name':<40} {'Shape':<20} {'Mean':<15} {'Std Dev':<15} {'Min':<15} {'Max':<15}"
        )
        print("-" * 120)
        check_stds = []
        nonfinite_names = []
        for name, param in state_dict.items():
            if not isinstance(param, torch.Tensor):
                print(f"{name:<40} {'Non-Tensor':<20} {str(param):<60}")
                continue
            cells, check_std, finite = _tensor_stats(param)
            if check_std is not None:
                check_stds.append(check_std)
            if not finite:
                nonfinite_names.append(name)
            mean_val, std_val, min_val, max_val = cells
            shape_str = str(list(param.shape))
            print(
                f"{name:<40} {shape_str:<20} {mean_val:<15} {std_val:<15} {min_val:<15} {max_val:<15}"
            )
        print("\n" + "-" * 120)
        print("✅ Inspection complete.")
        _print_diagnosis(check_stds, nonfinite_names, std_threshold)
    except Exception as e:
        # Top-level boundary of a CLI diagnostic: report the failure and show
        # the traceback instead of crashing.
        print(f"❌ An error occurred while inspecting the weights file: {e}")
        import traceback

        traceback.print_exc()
if __name__ == "__main__":
    # Command-line entry point: accept one positional path and pass it
    # straight through to inspect_weights.
    cli = argparse.ArgumentParser(
        description="Inspect PyTorch model weights in a .pth file."
    )
    cli.add_argument(
        "file_path", type=str, help="Path to the .pth model weights file."
    )
    inspect_weights(cli.parse_args().file_path)