Spaces:
Configuration error
Configuration error
| import pytest | |
| import aiohttp | |
| from aiohttp import ClientResponse | |
| import itertools | |
| import os | |
| from unittest.mock import AsyncMock, patch, MagicMock | |
| from model_filemanager import download_model, validate_model_subdirectory, track_download_progress, create_model_path, check_file_exists, DownloadStatusType, DownloadModelStatus, validate_filename | |
class AsyncIteratorMock:
    """
    A mock class that simulates an asynchronous iterator.

    This is used to mimic the behavior of aiohttp's content iterator: it wraps
    an ordinary iterable so that it can be consumed with ``async for``.
    """

    def __init__(self, seq):
        # Convert the input sequence into an iterator
        self.iter = iter(seq)

    def __aiter__(self):
        # This method is called when 'async for' is used
        return self

    async def __anext__(self):
        # This method is called for each iteration in an 'async for' loop
        try:
            return next(self.iter)
        except StopIteration:
            # StopAsyncIteration is the asynchronous equivalent of
            # StopIteration. The caught exception carries no extra
            # information, so suppress it instead of chaining it into the
            # traceback.
            raise StopAsyncIteration from None
class ContentMock:
    """
    Stand-in for the ``content`` attribute of an aiohttp ClientResponse.

    Only ``iter_chunked`` is provided; it hands back an async iterator over
    the chunks supplied at construction time.
    """

    def __init__(self, chunks):
        # Keep the predefined chunks around so iter_chunked can replay them
        self.chunks = chunks

    def iter_chunked(self, chunk_size):
        # Mirrors aiohttp's content.iter_chunked(). chunk_size is accepted to
        # keep the call signature but is deliberately unused: the predefined
        # chunks are returned exactly as given.
        predefined = self.chunks
        return AsyncIteratorMock(predefined)
@pytest.mark.asyncio
async def test_download_model_success():
    """download_model writes a 200 response to a file and reports completion.

    The request function, ``create_model_path``, ``check_file_exists``,
    ``open`` and ``time.time`` are all mocked, so the test touches neither the
    network nor the disk.
    """
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.status = 200
    mock_response.headers = {'Content-Length': '1000'}
    # Create a mock for content that returns an async iterator directly
    chunks = [b'a' * 500, b'b' * 300, b'c' * 200]
    mock_response.content = ContentMock(chunks)

    mock_make_request = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # Mock file operations: mock_file is what `with open(...) as f` yields
    mock_open = MagicMock()
    mock_file = MagicMock()
    mock_open.return_value.__enter__.return_value = mock_file

    # Endless fake clock (0, 0.1, 0.2, ...) so every time.time() call gets a value
    time_values = itertools.count(0, 0.1)

    with patch('model_filemanager.create_model_path', return_value=('models/checkpoints/model.sft', 'checkpoints/model.sft')), \
         patch('model_filemanager.check_file_exists', return_value=None), \
         patch('builtins.open', mock_open), \
         patch('time.time', side_effect=time_values):  # Simulate time passing
        result = await download_model(
            mock_make_request,
            'model.sft',
            'http://example.com/model.sft',
            'checkpoints',
            mock_progress_callback
        )

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Successfully downloaded model.sft'
    assert result.status == 'completed'
    assert result.already_existed is False

    # Check progress callback calls
    assert mock_progress_callback.call_count >= 3  # At least start, one progress update, and completion

    # Check initial call
    mock_progress_callback.assert_any_call(
        'checkpoints/model.sft',
        DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.sft", False)
    )

    # Check final call
    mock_progress_callback.assert_any_call(
        'checkpoints/model.sft',
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.sft", False)
    )

    # Verify file writing: every chunk must have been written to the file
    mock_file.write.assert_any_call(b'a' * 500)
    mock_file.write.assert_any_call(b'b' * 300)
    mock_file.write.assert_any_call(b'c' * 200)

    # Verify request was made
    mock_make_request.assert_called_once_with('http://example.com/model.sft')
@pytest.mark.asyncio
async def test_download_model_url_request_failure():
    """download_model returns an error status when the response is a 404."""
    # Mock dependencies
    mock_response = AsyncMock(spec=ClientResponse)
    mock_response.status = 404  # Simulate a "Not Found" error
    mock_get = AsyncMock(return_value=mock_response)
    mock_progress_callback = AsyncMock()

    # Mock create_model_path, and make check_file_exists return None
    # (file doesn't exist) so the download is actually attempted.
    # NOTE(review): the patched create_model_path returns 'mock/path/...' as
    # the relative path, while the callback assertions below expect
    # 'mock_directory/...'. Confirm which value download_model forwards to
    # the progress callback.
    with patch('model_filemanager.create_model_path', return_value=('/mock/path/model.safetensors', 'mock/path/model.safetensors')), \
         patch('model_filemanager.check_file_exists', return_value=None):
        # Call the function
        result = await download_model(
            mock_get,
            'model.safetensors',
            'http://example.com/model.safetensors',
            'mock_directory',
            mock_progress_callback
        )

    # Assert the expected behavior
    assert isinstance(result, DownloadModelStatus)
    assert result.status == 'error'
    assert result.message == 'Failed to download model.safetensors. Status code: 404'
    assert result.already_existed is False

    # Check that progress_callback was called with the correct arguments
    mock_progress_callback.assert_any_call(
        'mock_directory/model.safetensors',
        DownloadModelStatus(
            status=DownloadStatusType.PENDING,
            progress_percentage=0,
            message='Starting download of model.safetensors',
            already_existed=False
        )
    )
    # assert_called_with checks the most recent call: the error report
    mock_progress_callback.assert_called_with(
        'mock_directory/model.safetensors',
        DownloadModelStatus(
            status=DownloadStatusType.ERROR,
            progress_percentage=0,
            message='Failed to download model.safetensors. Status code: 404',
            already_existed=False
        )
    )

    # Verify that the get method was called with the correct URL
    mock_get.assert_called_once_with('http://example.com/model.safetensors')
@pytest.mark.asyncio
async def test_download_model_invalid_model_subdirectory():
    """download_model returns an error status for the subdirectory '../bad_path'."""
    mock_make_request = AsyncMock()
    mock_progress_callback = AsyncMock()

    result = await download_model(
        mock_make_request,
        'model.sft',
        'http://example.com/model.sft',
        '../bad_path',
        mock_progress_callback
    )

    # Assert the result
    assert isinstance(result, DownloadModelStatus)
    assert result.message == 'Invalid model subdirectory'
    assert result.status == 'error'
    assert result.already_existed is False
| # For create_model_path function | |
def test_create_model_path(tmp_path, monkeypatch):
    """create_model_path returns the expected full and relative paths.

    The directory that contains the returned file path must also exist once
    the call has returned.
    """
    models_root = tmp_path / "models"
    monkeypatch.setattr('folder_paths.models_dir', str(models_root))

    name = "test_model.sft"
    subdir = "test_dir"
    file_path, relative_path = create_model_path(name, subdir, models_root)

    expected_full = str(models_root / subdir / name)
    assert file_path == expected_full

    expected_relative = f"{subdir}/{name}"
    assert relative_path == expected_relative

    parent_dir = os.path.dirname(file_path)
    assert os.path.exists(parent_dir)
@pytest.mark.asyncio
async def test_check_file_exists_when_file_exists(tmp_path):
    """check_file_exists reports a completed, already-existing download for a present file."""
    file_path = tmp_path / "existing_model.sft"
    file_path.touch()  # Create an empty file

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "existing_model.sft", mock_callback, "test/existing_model.sft")

    assert result is not None
    assert result.status == "completed"
    assert result.message == "existing_model.sft already exists"
    assert result.already_existed is True

    # The callback must be notified exactly once, with the same status
    mock_callback.assert_called_once_with(
        "test/existing_model.sft",
        DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.sft already exists", already_existed=True)
    )
@pytest.mark.asyncio
async def test_check_file_exists_when_file_does_not_exist(tmp_path):
    """check_file_exists returns None and stays silent when the file is absent."""
    # The path is never created, so the file does not exist
    file_path = tmp_path / "non_existing_model.sft"

    mock_callback = AsyncMock()

    result = await check_file_exists(str(file_path), "non_existing_model.sft", mock_callback, "test/non_existing_model.sft")

    assert result is None
    mock_callback.assert_not_called()
@pytest.mark.asyncio
async def test_track_download_progress_no_content_length():
    """track_download_progress completes when the response has no Content-Length header."""
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {}  # No Content-Length header
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 500, b'b' * 500])

    mock_callback = AsyncMock()
    # Replace open() so nothing is written to disk
    mock_open = MagicMock(return_value=MagicMock())

    with patch('builtins.open', mock_open):
        result = await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=0.1
        )

    assert result.status == "completed"
    # Check that progress was reported even without knowing the total size
    mock_callback.assert_any_call(
        'models/model.sft',
        DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.sft", already_existed=False)
    )
@pytest.mark.asyncio
async def test_track_download_progress_interval():
    """track_download_progress reports progress over time with interval=1.0.

    With a simulated clock that advances 0.5 s per reading, the callback must
    be invoked at least three times, starting with an ``in_progress`` status
    below 50% and ending with ``completed`` at 100%.
    """
    mock_response = AsyncMock(spec=aiohttp.ClientResponse)
    mock_response.headers = {'Content-Length': '1000'}
    mock_response.content.iter_chunked.return_value = AsyncIteratorMock([b'a' * 100] * 10)

    mock_callback = AsyncMock()
    # Replace open() so nothing is written to disk
    mock_open = MagicMock(return_value=MagicMock())

    # Fake clock returning 0, 0.5, 1.0, ... on successive calls. An endless
    # counter is used rather than a fixed-length list so the mock cannot run
    # out of values however many times the clock is read.
    mock_time = MagicMock(side_effect=itertools.count(0, 0.5))

    with patch('builtins.open', mock_open), \
         patch('time.time', mock_time):
        await track_download_progress(
            mock_response, '/mock/path/model.sft', 'model.sft',
            mock_callback, 'models/model.sft', interval=1.0
        )

    # Print out the actual call count and the arguments of each call for debugging
    print(f"mock_callback was called {mock_callback.call_count} times")
    for number, recorded in enumerate(mock_callback.call_args_list, start=1):
        reported = recorded[0][1]  # second positional argument: the status object
        print(f"Call {number}: {reported.status}, Progress: {reported.progress_percentage:.2f}%")

    # Assert that progress was updated at least 3 times (start, at least one interval, and end)
    assert mock_callback.call_count >= 3, f"Expected at least 3 calls, but got {mock_callback.call_count}"

    # Verify the first and last calls
    first_call = mock_callback.call_args_list[0]
    assert first_call[0][1].status == "in_progress"
    # Allow for some initial progress, but it should be less than 50%
    assert 0 <= first_call[0][1].progress_percentage < 50, f"First call progress was {first_call[0][1].progress_percentage}%"

    last_call = mock_callback.call_args_list[-1]
    assert last_call[0][1].status == "completed"
    assert last_call[0][1].progress_percentage == 100
def test_valid_subdirectory():
    """A name made of letters, digits and a dash is accepted."""
    verdict = validate_model_subdirectory("valid-model123")
    assert verdict is True
def test_subdirectory_too_long():
    """A 51-character name is rejected."""
    overlong_name = "a" * 51
    verdict = validate_model_subdirectory(overlong_name)
    assert verdict is False
def test_subdirectory_with_double_dots():
    """A name containing a '..' path component is rejected."""
    verdict = validate_model_subdirectory("model/../unsafe")
    assert verdict is False
def test_subdirectory_with_slash():
    """A name containing a forward slash is rejected."""
    verdict = validate_model_subdirectory("model/unsafe")
    assert verdict is False
def test_subdirectory_with_special_characters():
    """A name containing '@' is rejected."""
    verdict = validate_model_subdirectory("model@unsafe")
    assert verdict is False
def test_subdirectory_with_underscore_and_dash():
    """A name that mixes an underscore and a dash is accepted."""
    verdict = validate_model_subdirectory("valid_model-name")
    assert verdict is True
def test_empty_subdirectory():
    """The empty string is rejected."""
    verdict = validate_model_subdirectory("")
    assert verdict is False
def test_validate_filename(filename, expected):
    """validate_filename(filename) must equal the expected value supplied for the case."""
    # NOTE(review): no @pytest.mark.parametrize decorator or fixtures named
    # `filename` / `expected` are visible in this file; pytest cannot supply
    # these arguments without one. Confirm where the cases are defined.
    outcome = validate_filename(filename)
    assert outcome == expected