MogensR commited on
Commit
a41fc30
Β·
1 Parent(s): ac26af7

Create test_smoke.py

Browse files
Files changed (1) hide show
  1. tests/test_smoke.py +267 -0
tests/test_smoke.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Smoke test for two-stage video processing
4
+ THIS IS A NEW FILE - Basic end-to-end test
5
+ Tests quality profiles, frame count preservation, and basic functionality
6
+ """
7
+ import os
8
+ import sys
9
+ import cv2
10
+ import numpy as np
11
+ import tempfile
12
+ import logging
13
+ import time
14
+ from pathlib import Path
15
+
16
+ # Add project root to path
17
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ from processing.two_stage.two_stage_processor import TwoStageProcessor
20
+ from models.loaders.matanyone_loader import MatAnyoneLoader
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ def create_test_video(path: str, frames: int = 30, fps: int = 30):
26
+ """Create a simple test video with a moving circle (simulating a person)"""
27
+ width, height = 640, 480
28
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
29
+ out = cv2.VideoWriter(path, fourcc, fps, (width, height))
30
+
31
+ if not out.isOpened():
32
+ raise RuntimeError(f"Failed to create test video at {path}")
33
+
34
+ for i in range(frames):
35
+ # Create frame with moving white circle on dark background
36
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
37
+ frame[:] = (30, 30, 30) # Dark gray background
38
+
39
+ # Draw a moving circle (simulating a person)
40
+ x = int(width/2 + 100 * np.sin(i * 0.2))
41
+ y = int(height/2 + 50 * np.cos(i * 0.15))
42
+ cv2.circle(frame, (x, y), 60, (255, 255, 255), -1)
43
+
44
+ # Add some variation to simulate clothing
45
+ cv2.circle(frame, (x, y-20), 20, (200, 100, 100), -1) # "shirt"
46
+
47
+ out.write(frame)
48
+
49
+ out.release()
50
+
51
+ # Verify the video was created
52
+ cap = cv2.VideoCapture(path)
53
+ actual_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
54
+ cap.release()
55
+
56
+ logger.info(f"Created test video: {path} ({actual_frames} frames)")
57
+ return actual_frames
58
+
59
+
60
+ def verify_output_video(path: str, expected_frames: int) -> bool:
61
+ """Verify output video exists and has correct frame count"""
62
+ if not os.path.exists(path):
63
+ logger.error(f"Output video not found: {path}")
64
+ return False
65
+
66
+ file_size = os.path.getsize(path)
67
+ if file_size < 1000:
68
+ logger.error(f"Output video too small: {file_size} bytes")
69
+ return False
70
+
71
+ cap = cv2.VideoCapture(path)
72
+ if not cap.isOpened():
73
+ logger.error(f"Cannot open output video: {path}")
74
+ return False
75
+
76
+ actual_frames = 0
77
+ while True:
78
+ ret, frame = cap.read()
79
+ if not ret:
80
+ break
81
+ actual_frames += 1
82
+
83
+ cap.release()
84
+
85
+ if actual_frames != expected_frames:
86
+ logger.error(f"Frame count mismatch: got {actual_frames}, expected {expected_frames}")
87
+ return False
88
+
89
+ logger.info(f"Output verified: {actual_frames} frames, {file_size:,} bytes")
90
+ return True
91
+
92
+
93
+ def test_quality_profiles():
94
+ """Test that different quality profiles produce different results"""
95
+ logger.info("="*60)
96
+ logger.info("Testing Quality Profiles")
97
+ logger.info("="*60)
98
+
99
+ with tempfile.TemporaryDirectory() as tmpdir:
100
+ tmpdir = Path(tmpdir)
101
+
102
+ # Create test video
103
+ test_video = tmpdir / "test_input.mp4"
104
+ expected_frames = create_test_video(str(test_video), frames=30, fps=30)
105
+
106
+ # Create a simple background
107
+ background = np.ones((480, 640, 3), dtype=np.uint8) * 128 # Gray
108
+ background[:240, :] = (100, 150, 200) # Blue top half
109
+
110
+ results = {}
111
+
112
+ for quality in ["speed", "balanced", "max"]:
113
+ logger.info(f"\nTesting quality mode: {quality}")
114
+ logger.info("-" * 40)
115
+
116
+ # Set quality environment variable
117
+ os.environ["BFX_QUALITY"] = quality
118
+
119
+ # Initialize processor (without models for basic test)
120
+ processor = TwoStageProcessor(
121
+ sam2_predictor=None, # Will use fallback
122
+ matanyone_model=None # Will use fallback
123
+ )
124
+
125
+ # Process video
126
+ output_path = tmpdir / f"output_{quality}.mp4"
127
+ start_time = time.time()
128
+
129
+ result, message = processor.process_full_pipeline(
130
+ video_path=str(test_video),
131
+ background=background,
132
+ output_path=str(output_path),
133
+ key_color_mode="auto",
134
+ chroma_settings=None,
135
+ progress_callback=lambda p, d: logger.debug(f"{p:.1%}: {d}"),
136
+ stop_event=None
137
+ )
138
+
139
+ process_time = time.time() - start_time
140
+
141
+ if result is None:
142
+ logger.error(f"Processing failed for {quality}: {message}")
143
+ continue
144
+
145
+ # Verify output
146
+ if verify_output_video(result, expected_frames):
147
+ results[quality] = {
148
+ "success": True,
149
+ "time": process_time,
150
+ "frames_refined": processor.frames_refined,
151
+ "total_frames": processor.total_frames_processed
152
+ }
153
+ logger.info(f"βœ“ {quality}: {process_time:.2f}s, "
154
+ f"{processor.frames_refined}/{processor.total_frames_processed} refined")
155
+ else:
156
+ results[quality] = {"success": False}
157
+ logger.error(f"βœ— {quality}: verification failed")
158
+
159
+ # Summary
160
+ logger.info("\n" + "="*60)
161
+ logger.info("SUMMARY")
162
+ logger.info("="*60)
163
+
164
+ all_passed = all(r.get("success", False) for r in results.values())
165
+
166
+ if all_passed:
167
+ # Check that quality modes are actually different
168
+ if len(results) >= 2:
169
+ times = [r["time"] for r in results.values() if "time" in r]
170
+ refined_counts = [r["frames_refined"] for r in results.values() if "frames_refined" in r]
171
+
172
+ if len(set(refined_counts)) > 1:
173
+ logger.info("βœ“ Quality profiles show different refinement counts")
174
+ else:
175
+ logger.warning("⚠ All quality profiles refined same number of frames")
176
+
177
+ if max(times) - min(times) > 0.1:
178
+ logger.info("βœ“ Quality profiles show different processing times")
179
+ else:
180
+ logger.warning("⚠ Quality profiles have similar processing times")
181
+
182
+ for quality, result in results.items():
183
+ if result.get("success"):
184
+ logger.info(f"βœ“ {quality:8s}: {result['time']:.2f}s, "
185
+ f"{result['frames_refined']}/{result['total_frames']} frames refined")
186
+ else:
187
+ logger.info(f"βœ— {quality:8s}: FAILED")
188
+
189
+ return all_passed
190
+
191
+
192
+ def test_frame_preservation():
193
+ """Test that no frames are lost during processing"""
194
+ logger.info("\n" + "="*60)
195
+ logger.info("Testing Frame Preservation")
196
+ logger.info("="*60)
197
+
198
+ with tempfile.TemporaryDirectory() as tmpdir:
199
+ tmpdir = Path(tmpdir)
200
+
201
+ # Test different frame counts
202
+ test_cases = [10, 25, 30, 60]
203
+
204
+ for frame_count in test_cases:
205
+ logger.info(f"\nTesting with {frame_count} frames...")
206
+
207
+ test_video = tmpdir / f"test_{frame_count}.mp4"
208
+ expected = create_test_video(str(test_video), frames=frame_count, fps=30)
209
+
210
+ os.environ["BFX_QUALITY"] = "speed" # Fast for this test
211
+
212
+ processor = TwoStageProcessor()
213
+ output_path = tmpdir / f"output_{frame_count}.mp4"
214
+
215
+ result, message = processor.process_full_pipeline(
216
+ video_path=str(test_video),
217
+ background=np.ones((480, 640, 3), dtype=np.uint8) * 100,
218
+ output_path=str(output_path),
219
+ key_color_mode="green",
220
+ )
221
+
222
+ if result and verify_output_video(result, expected):
223
+ logger.info(f"βœ“ {frame_count} frames: preserved correctly")
224
+ else:
225
+ logger.error(f"βœ— {frame_count} frames: FAILED")
226
+ return False
227
+
228
+ logger.info("\nβœ“ All frame preservation tests passed!")
229
+ return True
230
+
231
+
232
+ def main():
233
+ """Run all smoke tests"""
234
+ logger.info("\n" + "πŸ”₯"*20)
235
+ logger.info("BACKGROUNDFX PRO SMOKE TESTS")
236
+ logger.info("πŸ”₯"*20)
237
+
238
+ tests_passed = []
239
+
240
+ # Test 1: Quality profiles
241
+ try:
242
+ tests_passed.append(test_quality_profiles())
243
+ except Exception as e:
244
+ logger.error(f"Quality profile test crashed: {e}")
245
+ tests_passed.append(False)
246
+
247
+ # Test 2: Frame preservation
248
+ try:
249
+ tests_passed.append(test_frame_preservation())
250
+ except Exception as e:
251
+ logger.error(f"Frame preservation test crashed: {e}")
252
+ tests_passed.append(False)
253
+
254
+ # Final result
255
+ logger.info("\n" + "="*60)
256
+ if all(tests_passed):
257
+ logger.info("βœ… ALL SMOKE TESTS PASSED!")
258
+ logger.info("="*60)
259
+ return 0
260
+ else:
261
+ logger.error("❌ SOME TESTS FAILED")
262
+ logger.info("="*60)
263
+ return 1
264
+
265
+
266
+ if __name__ == "__main__":
267
+ exit(main())