Spaces:
Sleeping
Sleeping
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import MagicMock, patch

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import modules to test
from model_config import ModelConfigManager
class MockDB:
    """In-memory stand-in for the database handed to ModelConfigManager in tests.

    Configs live in the ``configs`` dict, keyed by config id.
    """

    def __init__(self):
        # Backing store: config_id -> config object.
        self.configs = dict()

    def get_config(self, config_id):
        """Return the config stored under *config_id*, or ``None`` if there is none."""
        try:
            return self.configs[config_id]
        except KeyError:
            return None

    def add_config(self, config_id, config):
        """Store *config* under *config_id* (replacing any prior entry); return the id."""
        self.configs.update({config_id: config})
        return config_id
class TestModelConfigManager(unittest.TestCase):
    """Test the ModelConfigManager class against an isolated temporary config directory."""

    # Fields every stored config is expected to round-trip unchanged.
    _CONFIG_FIELDS = ("name", "description", "parameters")

    def setUp(self):
        """Build a manager whose ``config_dir`` points at a fresh temporary directory.

        A per-test ``TemporaryDirectory`` replaces a fixed, CWD-relative
        ``test_model_configs`` folder. Parallel runs therefore cannot collide,
        and a pre-existing folder with that name is never emptied or removed.
        Cleanup is registered with ``addCleanup``, so it also runs when the
        remainder of ``setUp`` raises (``tearDown`` would be skipped then), and
        it removes the directory recursively.
        """
        self.db = MockDB()
        temp_dir = tempfile.TemporaryDirectory(prefix="test_model_configs_")
        self.addCleanup(temp_dir.cleanup)
        self.config_dir = temp_dir.name
        # NOTE(review): config_dir is redirected only after construction; if the
        # constructor writes any files, they land in the manager's default
        # directory rather than this temp dir -- confirm that is intended.
        self.manager = ModelConfigManager(self.db)
        self.manager.config_dir = self.config_dir

    def _write_config(self, config_id, config):
        """Write *config* as JSON to ``<config_dir>/<config_id>.json``; return that path."""
        config_path = os.path.join(self.config_dir, f"{config_id}.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(config, f)
        return config_path

    def _read_config(self, config_path):
        """Load and return the JSON document stored at *config_path*."""
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def test_initialize_default_configs(self):
        """_initialize_default_configs writes one JSON file per default config."""
        self.manager._initialize_default_configs()
        for model_type in self.manager.default_configs:
            with self.subTest(model_type=model_type):
                expected = self.manager.default_configs[model_type]
                config_path = os.path.join(self.config_dir, f"{model_type}.json")
                self.assertTrue(os.path.exists(config_path))
                config = self._read_config(config_path)
                for field in self._CONFIG_FIELDS:
                    self.assertEqual(config[field], expected[field])

    def test_get_available_configs(self):
        """get_available_configs lists every config file, tagging each with its id."""
        test_configs = {
            "test1": {
                "name": "Test Config 1",
                "description": "Test description 1",
                "parameters": {"temperature": 0.7},
            },
            "test2": {
                "name": "Test Config 2",
                "description": "Test description 2",
                "parameters": {"temperature": 0.8},
            },
        }
        for config_id, config in test_configs.items():
            self._write_config(config_id, config)

        configs = self.manager.get_available_configs()

        # Exactly the two configs written above, each carrying its file-stem id.
        self.assertCountEqual([config["id"] for config in configs], list(test_configs))
        for config in configs:
            expected = test_configs[config["id"]]
            for field in self._CONFIG_FIELDS:
                self.assertEqual(config[field], expected[field])

    def test_get_config(self):
        """get_config returns the stored config plus its id, or None when missing."""
        config = {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {"temperature": 0.7},
        }
        self._write_config("test", config)

        result = self.manager.get_config("test")

        self.assertIsNotNone(result)
        self.assertEqual(result["id"], "test")
        for field in self._CONFIG_FIELDS:
            self.assertEqual(result[field], config[field])

        # Test non-existent config
        self.assertIsNone(self.manager.get_config("nonexistent"))

    def test_add_config(self):
        """add_config persists a new config and returns an id derived from its name."""
        name = "Test Config"
        description = "Test description"
        parameters = {"temperature": 0.7, "top_k": 50}

        config_id = self.manager.add_config(name, description, parameters)

        self.assertIsNotNone(config_id)
        self.assertEqual(config_id, "test_config")
        config_path = os.path.join(self.config_dir, f"{config_id}.json")
        self.assertTrue(os.path.exists(config_path))
        config = self._read_config(config_path)
        self.assertEqual(config["name"], name)
        self.assertEqual(config["description"], description)
        self.assertEqual(config["parameters"], parameters)

    def test_update_config(self):
        """update_config rewrites an existing config and refuses unknown ids."""
        config_path = self._write_config("test", {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {"temperature": 0.7},
        })
        new_name = "Updated Config"
        new_description = "Updated description"
        new_parameters = {"temperature": 0.8, "top_k": 60}

        success = self.manager.update_config("test", new_name, new_description, new_parameters)

        self.assertTrue(success)
        updated_config = self._read_config(config_path)
        self.assertEqual(updated_config["name"], new_name)
        self.assertEqual(updated_config["description"], new_description)
        self.assertEqual(updated_config["parameters"], new_parameters)

        # Test updating non-existent config
        success = self.manager.update_config("nonexistent", "New Name", "New Description", {})
        self.assertFalse(success)

    def test_delete_config(self):
        """delete_config removes user configs but refuses unknown ids and defaults."""
        config_path = self._write_config("test", {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {"temperature": 0.7},
        })

        self.assertTrue(self.manager.delete_config("test"))
        self.assertFalse(os.path.exists(config_path))

        # Test deleting non-existent config
        self.assertFalse(self.manager.delete_config("nonexistent"))

        # Default configs must be protected: deletion is refused and the file stays.
        for model_type in self.manager.default_configs:
            with self.subTest(model_type=model_type):
                default_path = self._write_config(
                    model_type, self.manager.default_configs[model_type]
                )
                self.assertFalse(self.manager.delete_config(model_type))
                self.assertTrue(os.path.exists(default_path))

    def test_apply_config_to_model_params(self):
        """apply_config_to_model_params overlays config parameters onto model params."""
        self._write_config("test", {
            "name": "Test Config",
            "description": "Test description",
            "parameters": {
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.9,
            },
        })
        model_params = {
            "temperature": 0.5,
            "max_length": 100,
        }

        # Pass copies: if the method mutates its argument in place, reusing
        # model_params would make the "nonexistent" check below compare the
        # mutated dict against itself and pass vacuously.
        result = self.manager.apply_config_to_model_params(dict(model_params), "test")

        self.assertEqual(result["temperature"], 0.7)
        self.assertEqual(result["top_k"], 50)
        self.assertEqual(result["top_p"], 0.9)
        self.assertEqual(result["max_length"], 100)

        # Test with non-existent config: parameters come back unchanged.
        result = self.manager.apply_config_to_model_params(dict(model_params), "nonexistent")
        self.assertEqual(result, {"temperature": 0.5, "max_length": 100})
if __name__ == "__main__":
    # Executing this file directly runs every TestCase it defines.
    unittest.main()