hello10000 committed on
Commit 034aa20 · 1 Parent(s): de21991

update files

.gitignore CHANGED
@@ -217,13 +217,15 @@ outputs
  yolo11n.pt
  yolo11s.pt
  best.pt
+ **/best.pt

  .claude
- best.pt

  runs
  .idea
  working_dir
  data
  upload_to_huggingface.py
- resources/best.pt
+ resources/best.pt
+ resources/model_version.json
+ .web
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
+ [server]
+ maxUploadSize=4096
+ # 4GB as maximum
app.py CHANGED
@@ -3,6 +3,7 @@ os.environ["HOME"] = "/tmp"
  os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
  os.makedirs("/tmp/.streamlit", exist_ok=True)

+
  import shutil
  import tempfile
  from pathlib import Path
@@ -17,9 +18,75 @@ def main():
  page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
  )

- st.title("🎬 Sora Watermark Cleaner")
- st.markdown("Remove watermarks from Sora-generated videos with ease")
-
+ # Header section with improved layout
+ st.markdown(
+ """
+ <div style='text-align: center; padding: 1rem 0;'>
+ <h1 style='margin-bottom: 0.5rem;'>
+ 🎬 Sora Watermark Cleaner
+ </h1>
+ <p style='font-size: 1.2rem; color: #666; margin-bottom: 1rem;'>
+ Remove watermarks from Sora-generated videos with AI-powered precision
+ </p>
+ </div>
+ """,
+ unsafe_allow_html=True,
+ )
+
+ # # Feature badges
+ # col1, col2, col3 = st.columns(3)
+ # with col1:
+ # st.markdown(
+ # """
+ # <div style='text-align: center; padding: 0.8rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+ # border-radius: 10px; color: white;'>
+ # <div style='font-size: 1.5rem;'>⚡</div>
+ # <div style='font-weight: bold;'>Fast Processing</div>
+ # <div style='font-size: 0.85rem; opacity: 0.9;'>GPU Accelerated</div>
+ # </div>
+ # """,
+ # unsafe_allow_html=True,
+ # )
+ # with col2:
+ # st.markdown(
+ # """
+ # <div style='text-align: center; padding: 0.8rem; background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
+ # border-radius: 10px; color: white;'>
+ # <div style='font-size: 1.5rem;'>🎯</div>
+ # <div style='font-weight: bold;'>High Precision</div>
+ # <div style='font-size: 0.85rem; opacity: 0.9;'>AI-Powered</div>
+ # </div>
+ # """,
+ # unsafe_allow_html=True,
+ # )
+ # with col3:
+ # st.markdown(
+ # """
+ # <div style='text-align: center; padding: 0.8rem; background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%);
+ # border-radius: 10px; color: white;'>
+ # <div style='font-size: 1.5rem;'>📦</div>
+ # <div style='font-weight: bold;'>Batch Support</div>
+ # <div style='font-size: 0.85rem; opacity: 0.9;'>Process Multiple</div>
+ # </div>
+ # """,
+ # unsafe_allow_html=True,
+ # )
+
+ # Footer info
+ st.markdown(
+ """
+ <div style='text-align: center; padding: 1rem 0; margin-top: 1rem;'>
+ <p style='color: #888; font-size: 0.9rem;'>
+ Built with ❤️ using Streamlit and AI |
+ <a href='https://github.com/linkedlist771/SoraWatermarkCleaner'
+ target='_blank' style='color: #667eea; text-decoration: none;'>
+ ⭐ Star on GitHub
+ </a>
+ </p>
+ </div>
+ """,
+ unsafe_allow_html=True,
+ )
  # Initialize SoraWM
  if "sora_wm" not in st.session_state:
  with st.spinner("Loading AI models..."):
@@ -27,84 +94,232 @@ def main():

  st.markdown("---")

- # File uploader
- uploaded_file = st.file_uploader(
- "Upload your video",
- type=["mp4", "avi", "mov", "mkv"],
- help="Select a video file to remove watermarks",
+ # Mode selection
+ mode = st.radio(
+ "Select input mode:",
+ ["📁 Upload Video File", "🗂️ Process Folder"],
+ horizontal=True,
  )

- if uploaded_file is not None:
- # Display video info
- st.success(f"✅ Uploaded: {uploaded_file.name}")
- st.video(uploaded_file)
-
- # Process button
- if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_path = Path(tmp_dir)
-
- # Save uploaded file
- input_path = tmp_path / uploaded_file.name
- with open(input_path, "wb") as f:
- f.write(uploaded_file.read())
-
- # Process video
- output_path = tmp_path / f"cleaned_{uploaded_file.name}"
-
- try:
- # Create progress bar and status text
- progress_bar = st.progress(0)
- status_text = st.empty()
-
- def update_progress(progress: int):
- progress_bar.progress(progress / 100)
- if progress < 50:
- status_text.text(f"🔍 Detecting watermarks... {progress}%")
- elif progress < 95:
- status_text.text(f"🧹 Removing watermarks... {progress}%")
- else:
- status_text.text(f"🎵 Merging audio... {progress}%")
-
- # Run the watermark removal with progress callback
- st.session_state.sora_wm.run(
- input_path, output_path, progress_callback=update_progress
- )
-
- # Complete the progress bar
- progress_bar.progress(100)
- status_text.text(" Processing complete!")
-
- st.success("✅ Watermark removed successfully!")
-
- # Display result
- st.markdown("### Result")
- st.video(str(output_path))
-
- # Download button
- with open(output_path, "rb") as f:
- st.download_button(
- label="⬇️ Download Cleaned Video",
- data=f,
- file_name=f"cleaned_{uploaded_file.name}",
- mime="video/mp4",
- use_container_width=True,
+ if mode == "📁 Upload Video File":
+ # File uploader
+ uploaded_file = st.file_uploader(
+ "Upload your video",
+ type=["mp4", "avi", "mov", "mkv"],
+ accept_multiple_files=False,
+ help="Select a video file to remove watermark",
+ )
+
+ if uploaded_file:
+ # Clear previous processed video if a new file is uploaded
+ if "current_file_name" not in st.session_state or st.session_state.current_file_name != uploaded_file.name:
+ st.session_state.current_file_name = uploaded_file.name
+ if "processed_video_data" in st.session_state:
+ del st.session_state.processed_video_data
+ if "processed_video_path" in st.session_state:
+ del st.session_state.processed_video_path
+ if "processed_video_name" in st.session_state:
+ del st.session_state.processed_video_name
+
+ # Display video info
+ st.success(f"✅ Uploaded: {uploaded_file.name}")
+
+ # Create two columns for before/after comparison
+ col_left, col_right = st.columns(2)
+
+ with col_left:
+ st.markdown("### 📥 Original Video")
+ st.video(uploaded_file)
+
+ with col_right:
+ st.markdown("### 🎬 Processed Video")
+ # Placeholder for processed video
+ if "processed_video_data" not in st.session_state:
+ st.info("Click 'Remove Watermark' to process the video")
+ else:
+ st.video(st.session_state.processed_video_data)
+
+ # Process button
+ if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_path = Path(tmp_dir)
+
+ try:
+ # Create progress bar and status text
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ def update_progress(progress: int):
+ progress_bar.progress(progress / 100)
+ if progress < 50:
+ status_text.text(f"🔍 Detecting watermarks... {progress}%")
+ elif progress < 95:
+ status_text.text(f"🧹 Removing watermarks... {progress}%")
+ else:
+ status_text.text(f"🎵 Merging audio... {progress}%")
+
+ # Single file processing
+ input_path = tmp_path / uploaded_file.name
+ with open(input_path, "wb") as f:
+ f.write(uploaded_file.read())
+
+ output_path = tmp_path / f"cleaned_{uploaded_file.name}"
+
+ st.session_state.sora_wm.run(
+ input_path, output_path, progress_callback=update_progress
  )

- except Exception as e:
- st.error(f" Error processing video: {str(e)}")
+ progress_bar.progress(100)
+ status_text.text(" Processing complete!")
+ st.success("✅ Watermark removed successfully!")

- # Footer
- st.markdown("---")
- st.markdown(
- """
- <div style='text-align: center'>
- <p>Built with ❤️ using Streamlit and AI</p>
- <p><a href='https://github.com/linkedlist771/SoraWatermarkCleaner'>GitHub Repository</a></p>
- </div>
- """,
- unsafe_allow_html=True,
- )
+ # Store processed video path and read video data
+ with open(output_path, "rb") as f:
+ video_data = f.read()
+
+ st.session_state.processed_video_path = output_path
+ st.session_state.processed_video_data = video_data
+ st.session_state.processed_video_name = f"cleaned_{uploaded_file.name}"
+
+ # Rerun to show the video in the right column
+ st.rerun()
+
+ except Exception as e:
+ st.error(f"❌ Error processing video: {str(e)}")
+
+ # Download button (show only if video is processed)
+ if "processed_video_data" in st.session_state:
+ st.download_button(
+ label="⬇️ Download Cleaned Video",
+ data=st.session_state.processed_video_data,
+ file_name=st.session_state.processed_video_name,
+ mime="video/mp4",
+ use_container_width=True,
+ )
+
+ else: # Folder mode
+ st.info("💡 Drag and drop your video folder here, or click to browse and select multiple video files")
+
+ # File uploader for multiple files (supports folder drag & drop)
+ uploaded_files = st.file_uploader(
+ "Upload videos from folder",
+ type=["mp4", "avi", "mov", "mkv"],
+ accept_multiple_files=True,
+ help="You can drag & drop an entire folder here, or select multiple video files",
+ key="folder_uploader"
+ )
+
+ if uploaded_files:
+ # Display uploaded files info
+ video_count = len(uploaded_files)
+ st.success(f"✅ {video_count} video file(s) uploaded")
+
+ # Show file list in an expander
+ with st.expander("📋 View uploaded files", expanded=False):
+ for i, file in enumerate(uploaded_files, 1):
+ file_size_mb = file.size / (1024 * 1024)
+ st.text(f"{i}. {file.name} ({file_size_mb:.2f} MB)")
+
+ # Process button
+ if st.button("🚀 Process All Videos", type="primary", use_container_width=True):
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ tmp_path = Path(tmp_dir)
+ input_folder = tmp_path / "input"
+ output_folder = tmp_path / "output"
+ input_folder.mkdir(exist_ok=True)
+ output_folder.mkdir(exist_ok=True)
+
+ try:
+ # Save all uploaded files to temp folder
+ status_text = st.empty()
+ status_text.text("📥 Saving uploaded files...")
+
+ for uploaded_file in uploaded_files:
+ # Preserve folder structure if file.name contains subdirectories
+ file_path = input_folder / uploaded_file.name
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(file_path, "wb") as f:
+ f.write(uploaded_file.read())
+
+ # Create progress tracking
+ progress_bar = st.progress(0)
+ current_file_text = st.empty()
+ processed_count = 0
+
+ def update_progress(progress: int):
+ # Calculate overall progress
+ overall_progress = (processed_count * 100 + progress) / video_count / 100
+ progress_bar.progress(overall_progress)
+
+ if progress < 50:
+ current_file_text.text(f"🔍 Processing file {processed_count + 1}/{video_count}: Detecting watermarks... {progress}%")
+ elif progress < 95:
+ current_file_text.text(f"🧹 Processing file {processed_count + 1}/{video_count}: Removing watermarks... {progress}%")
+ else:
+ current_file_text.text(f"🎵 Processing file {processed_count + 1}/{video_count}: Merging audio... {progress}%")
+
+ # Process each video file
+ for video_file in input_folder.rglob("*"):
+ if video_file.is_file() and video_file.suffix.lower() in [".mp4", ".avi", ".mov", ".mkv"]:
+ # Determine output path maintaining folder structure
+ rel_path = video_file.relative_to(input_folder)
+ output_path = output_folder / rel_path.parent / f"cleaned_{rel_path.name}"
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Process the video
+ st.session_state.sora_wm.run(
+ video_file, output_path, progress_callback=update_progress
+ )
+ processed_count += 1
+
+ progress_bar.progress(100)
+ current_file_text.text("✅ All videos processed!")
+ st.success(f"✅ {video_count} video(s) processed successfully!")
+
+ # Create download option for processed videos
+ st.markdown("### 📦 Download Processed Videos")
+
+ # Store processed files info in session state
+ if "batch_processed_files" not in st.session_state:
+ st.session_state.batch_processed_files = []
+
+ st.session_state.batch_processed_files.clear()
+
+ for processed_file in output_folder.rglob("*"):
+ if processed_file.is_file():
+ with open(processed_file, "rb") as f:
+ video_data = f.read()
+ rel_path = processed_file.relative_to(output_folder)
+ st.session_state.batch_processed_files.append({
+ "name": str(rel_path),
+ "data": video_data
+ })
+
+ st.rerun()
+
+ except Exception as e:
+ st.error(f"❌ Error processing videos: {str(e)}")
+ import traceback
+ st.error(f"Details: {traceback.format_exc()}")
+
+ # Show download buttons for processed files
+ if "batch_processed_files" in st.session_state and st.session_state.batch_processed_files:
+ st.markdown("---")
+ st.markdown("### ⬇️ Download Processed Videos")
+
+ for file_info in st.session_state.batch_processed_files:
+ col1, col2 = st.columns([3, 1])
+ with col1:
+ st.text(f"📹 {file_info['name']}")
+ with col2:
+ st.download_button(
+ label="⬇️ Download",
+ data=file_info['data'],
+ file_name=file_info['name'],
+ mime="video/mp4",
+ key=f"download_{file_info['name']}",
+ use_container_width=True
+ )


  if __name__ == "__main__":
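In the new folder mode, overall progress combines the per-file progress reported by `SoraWM.run` with the number of files already finished. A standalone sketch of that arithmetic (the `video_count`/`processed_count` names mirror the diff; the sample values are illustrative only):

```python
# Overall progress for a batch of video_count files, where `progress` is the
# current file's 0-100 value and `processed_count` files are already done.
video_count = 3

def overall_progress(processed_count: int, progress: int) -> float:
    # Same formula as in the diff: scaled into a 0.0-1.0 fraction for st.progress.
    return (processed_count * 100 + progress) / video_count / 100

assert overall_progress(0, 0) == 0.0
assert overall_progress(1, 50) == 0.5   # second of three files is half done
assert overall_progress(2, 100) == 1.0  # all three files finished
```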
datasets/make_yolo_images.py CHANGED
@@ -2,48 +2,58 @@ from pathlib import Path

  import cv2
  from tqdm import tqdm
-
+ from sorawm.watermark_detector import SoraWaterMarkDetector
  from sorawm.configs import ROOT

  videos_dir = ROOT / "videos"
  datasets_dir = ROOT / "datasets"
  images_dir = datasets_dir / "images"
  images_dir.mkdir(exist_ok=True, parents=True)
+ detector = SoraWaterMarkDetector()
+

  if __name__ == "__main__":
- fps_save_interval = 1 # Save every 1th frame
+ fps_save_interval = 1 # Save every frame (interval = 1)

- idx = 0
+ video_idx = 0
+ image_idx = 0 # global image index
+ total_failed = 0 # total number of frames where detection failed
+
  for video_path in tqdm(list(videos_dir.rglob("*.mp4"))):
  # Open the video file
  cap = cv2.VideoCapture(str(video_path))
-
+ video_name = video_path.name
  if not cap.isOpened():
  print(f"Error opening video: {video_path}")
  continue

  frame_count = 0

- while True:
- ret, frame = cap.read()
-
- # Break if no more frames
- if not ret:
- break
-
- # Save frame at the specified interval
- if frame_count % fps_save_interval == 0:
- # Create filename: image_idx_framecount.jpg
- image_filename = f"image_{idx:06d}_frame_{frame_count:06d}.jpg"
- image_path = images_dir / image_filename
-
- # Save the frame
- cv2.imwrite(str(image_path), frame)
-
- frame_count += 1
-
- # Release the video capture object
- cap.release()
- idx += 1
-
- print(f"Processed {idx} videos, extracted frames saved to {images_dir}")
+ try:
+ while True:
+ ret, frame = cap.read()
+
+ # Break if no more frames
+ if not ret:
+ break
+
+ # Save frame at the specified interval
+ if frame_count % fps_save_interval == 0:
+ if not detector.detect(frame)["detected"]:
+ # Create filename: image_idx_framecount.jpg
+ image_filename = f"{video_name}_failed_image_frame_{frame_count:06d}.jpg"
+ image_path = images_dir / image_filename
+ # Save the frame
+ cv2.imwrite(str(image_path), frame)
+ image_idx += 1
+ total_failed += 1
+
+ frame_count += 1
+
+ finally:
+ # Release the video capture object
+ cap.release()
+
+ video_idx += 1
+
+ print(f"Processed {video_idx} videos, extracted {total_failed} failed detection frames to {images_dir}")
mds/reward.md ADDED
@@ -0,0 +1 @@
+ ![](../assests/wechat_reward.jpg)
model_version.json ADDED
@@ -0,0 +1 @@
+ {"sha256": "79b44170111bd206d4964966b3b35adef1b3b15e7acf6427a95d35a2c715f987"}
pyproject.toml CHANGED
@@ -12,6 +12,7 @@ dependencies = [
  "fastapi==0.108.0",
  "ffmpeg-python>=0.2.0",
  "fire>=0.7.1",
+ "greenlet>=3.2.4",
  "httpx>=0.28.1",
  "huggingface-hub>=0.35.3",
  "jupyter>=1.1.1",
resources/19700121_1645_68e0a027836c8191a50bea3717ea7485.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:b7dffeb423765b5fa4ff4dc3571c5c83d45baee0c6abb86d512733fa75b27a72
- size 4336615
resources/53abf3fd-11a9-4dd7-a348-34920775f8ad.png DELETED
Git LFS Details
  • SHA256: 33eeb15039b0a661b1d589b155ec30bd79afe7fcd416704197f7371b77bb75a0
  • Pointer size: 131 Bytes
  • Size of remote file: 231 kB
resources/app.png DELETED
Git LFS Details
  • SHA256: 110fa5105f4e9eaa43fc63ef4b1b26ba392e77d3f2827ed3e9005a874c73f9e0
  • Pointer size: 132 Bytes
  • Size of remote file: 4.43 MB
resources/dog_vs_sam.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:bb22d5365c17e2bb73a9bed0598e057fc836ef60aa997982fea885e53e73283b
- size 5603713
resources/first_frame.json DELETED
The diff for this file is too large to render. See raw diff
resources/first_frame.png DELETED
Git LFS Details
  • SHA256: 56336473668766f3cbd52056669a8cd0142b001a6be63108cbd162724fb59962
  • Pointer size: 131 Bytes
  • Size of remote file: 599 kB
resources/puppies.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:7dd8a9d43d43cd5b901946c7ba14e8018e18d47e75f0f774157327ba7e9181c5
- size 10866692
resources/sora_watermark_removed.mp4 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:a7c1776fdb29087a57c73a7d66e02684f7c32f5245aa442cb2b76c64c5f9f7a1
- size 1570689
resources/watermark_template.png DELETED
Binary file (54.1 kB)
sorawm/configs.py CHANGED
@@ -8,6 +8,9 @@ WATER_MARK_TEMPLATE_IMAGE_PATH = RESOURCES_DIR / "watermark_template.png"

  WATER_MARK_DETECT_YOLO_WEIGHTS = RESOURCES_DIR / "best.pt"

+ WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON = RESOURCES_DIR / "model_version.json"
+
+

  OUTPUT_DIR = Path("/tmp/output")
  OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
sorawm/core.py CHANGED
@@ -15,12 +15,39 @@ from sorawm.utils.imputation_utils import (
  find_idxs_interval,
  )

+ VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm"]

  class SoraWM:
  def __init__(self):
  self.detector = SoraWaterMarkDetector()
  self.cleaner = WaterMarkCleaner()

+ def run_batch(self, input_video_dir_path: Path,
+ output_video_dir_path: Path | None = None,
+ progress_callback: Callable[[int], None] | None = None,
+ ):
+ if output_video_dir_path is None:
+ output_video_dir_path = input_video_dir_path.parent / "watermark_removed"
+ logger.warning(f"output_video_dir_path is not set, using {output_video_dir_path} as output_video_dir_path")
+ output_video_dir_path.mkdir(parents=True, exist_ok=True)
+ input_video_paths = []
+ for ext in VIDEO_EXTENSIONS:
+ input_video_paths.extend(input_video_dir_path.rglob(f"*{ext}"))
+
+ video_lengths = len(input_video_paths)
+ logger.info(f"Found {video_lengths} video(s) to process")
+
+ for idx, input_video_path in enumerate(tqdm(input_video_paths, desc="Processing videos")):
+ output_video_path = output_video_dir_path / input_video_path.name
+ if progress_callback:
+ def batch_progress_callback(single_video_progress: int):
+ overall_progress = int((idx / video_lengths) * 100 + (single_video_progress / video_lengths))
+ progress_callback(min(overall_progress, 100))
+
+ self.run(input_video_path, output_video_path, progress_callback=batch_progress_callback)
+ else:
+ self.run(input_video_path, output_video_path, progress_callback=None)
+
  def run(
  self,
  input_video_path: Path,
@@ -62,7 +89,7 @@ class SoraWM:
  .run_async(pipe_stdin=True)
  )

- frame_and_mask = {}
+ frame_bboxes = {}
  detect_missed = []
  bbox_centers = []
  bboxes = []
@@ -75,13 +102,13 @@
  ):
  detection_result = self.detector.detect(frame)
  if detection_result["detected"]:
- frame_and_mask[idx] = {"frame": frame, "bbox": detection_result["bbox"]}
+ frame_bboxes[idx] = { "bbox": detection_result["bbox"]}
  x1, y1, x2, y2 = detection_result["bbox"]
  bbox_centers.append((int((x1 + x2) / 2), int((y1 + y2) / 2)))
  bboxes.append((x1, y1, x2, y2))

  else:
- frame_and_mask[idx] = {"frame": frame, "bbox": None}
+ frame_bboxes[idx] = {"bbox": None}
  detect_missed.append(idx)
  bbox_centers.append(None)
  bboxes.append(None)
@@ -115,28 +142,30 @@ class SoraWM:
  interval_idx < len(interval_bboxes)
  and interval_bboxes[interval_idx] is not None
  ):
- frame_and_mask[missed_idx]["bbox"] = interval_bboxes[interval_idx]
+ frame_bboxes[missed_idx]["bbox"] = interval_bboxes[interval_idx]
  logger.debug(f"Filled missed frame {missed_idx} with bbox:\n"
  f" {interval_bboxes[interval_idx]}")
  else:
  # if the interval has no valid bbox, use the previous and next frame to complete (fallback strategy)
  before = max(missed_idx - 1, 0)
  after = min(missed_idx + 1, total_frames - 1)
- before_box = frame_and_mask[before]["bbox"]
- after_box = frame_and_mask[after]["bbox"]
+ before_box = frame_bboxes[before]["bbox"]
+ after_box = frame_bboxes[after]["bbox"]
  if before_box:
- frame_and_mask[missed_idx]["bbox"] = before_box
+ frame_bboxes[missed_idx]["bbox"] = before_box
  elif after_box:
- frame_and_mask[missed_idx]["bbox"] = after_box
+ frame_bboxes[missed_idx]["bbox"] = after_box
  else:
  del bboxes
  del bbox_centers
  del detect_missed
+
+ input_video_loader = VideoLoader(input_video_path)

- for idx in tqdm(range(total_frames), desc="Remove watermarks"):
- frame_info = frame_and_mask[idx]
- frame = frame_info["frame"]
- bbox = frame_info["bbox"]
+ for idx, frame in enumerate(tqdm(input_video_loader, total=total_frames, desc="Remove watermarks")):
+ # for idx in tqdm(range(total_frames), desc="Remove watermarks"):
+ # frame_info =
+ bbox = frame_bboxes[idx]["bbox"]
  if bbox is not None:
  x1, y1, x2, y2 = bbox
  mask = np.zeros((height, width), dtype=np.uint8)
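For reference, a minimal usage sketch of the new `run_batch` API as declared in this diff; the directory paths below are placeholders, not part of the commit:

```python
from pathlib import Path

from sorawm.core import SoraWM

sora_wm = SoraWM()
# Recursively collects videos by extension under the input directory and writes
# cleaned files with the same names into the output directory; progress is 0-100 overall.
sora_wm.run_batch(
    Path("videos"),              # placeholder input directory
    Path("watermark_removed"),   # placeholder output directory (optional)
    progress_callback=lambda p: print(f"overall progress: {p}%"),
)
```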
sorawm/utils/download_utils.py CHANGED
@@ -3,33 +3,90 @@ from pathlib import Path
  import requests
  from loguru import logger
  from tqdm import tqdm
-
- from sorawm.configs import WATER_MARK_DETECT_YOLO_WEIGHTS
+ import hashlib
+ import json
+ from sorawm.configs import WATER_MARK_DETECT_YOLO_WEIGHTS, WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON

  DETECTOR_URL = "https://github.com/linkedlist771/SoraWatermarkCleaner/releases/download/V0.0.1/best.pt"
+ REMOTE_MODEL_VERSION_URL = "https://raw.githubusercontent.com/linkedlist771/SoraWatermarkCleaner/refs/heads/main/model_version.json"

+ def generate_sha256_hash(file_path: Path) -> str:
+ with open(file_path, "rb") as f:
+ return hashlib.sha256(f.read()).hexdigest()

- def download_detector_weights():
- if not WATER_MARK_DETECT_YOLO_WEIGHTS.exists():
- logger.debug(f"llama weights not found, downloading from {DETECTOR_URL}")
- WATER_MARK_DETECT_YOLO_WEIGHTS.parent.mkdir(parents=True, exist_ok=True)
-
+ def download_detector_weights(force_download: bool = False):
+ ## 1. check if the model exists and whether we need to download it
+ if not WATER_MARK_DETECT_YOLO_WEIGHTS.exists() or force_download:
+ logger.debug(f"Downloading weights from {DETECTOR_URL}")
+ WATER_MARK_DETECT_YOLO_WEIGHTS.parent.mkdir(parents=True, exist_ok=True)
+ temp_file = WATER_MARK_DETECT_YOLO_WEIGHTS.with_suffix(".tmp")
+
  try:
  response = requests.get(DETECTOR_URL, stream=True, timeout=300)
  response.raise_for_status()
- total_size = int(response.headers.get("content-length", 0))
- with open(WATER_MARK_DETECT_YOLO_WEIGHTS, "wb") as f:
+ total_size = int(response.headers.get("content-length", 0))
+ with open(temp_file, "wb") as f:
  with tqdm(
  total=total_size, unit="B", unit_scale=True, desc="Downloading"
  ) as pbar:
  for chunk in response.iter_content(chunk_size=8192):
  if chunk:
  f.write(chunk)
- pbar.update(len(chunk))
-
+ pbar.update(len(chunk))
+ if WATER_MARK_DETECT_YOLO_WEIGHTS.exists():
+ WATER_MARK_DETECT_YOLO_WEIGHTS.unlink()
+ temp_file.rename(WATER_MARK_DETECT_YOLO_WEIGHTS)
+
  logger.success(f"✓ Weights downloaded: {WATER_MARK_DETECT_YOLO_WEIGHTS}")
-
+ new_hash = generate_sha256_hash(WATER_MARK_DETECT_YOLO_WEIGHTS)
+ WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON.parent.mkdir(parents=True, exist_ok=True)
+ with WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON.open("w") as f:
+ json.dump({"sha256": new_hash}, f)
+ logger.debug(f"Hash updated: {new_hash[:8]}...")
+ return
+
  except requests.exceptions.RequestException as e:
- if WATER_MARK_DETECT_YOLO_WEIGHTS.exists():
- WATER_MARK_DETECT_YOLO_WEIGHTS.unlink()
- raise RuntimeError(f"Downing failed: {e}")
+ if temp_file.exists():
+ temp_file.unlink()
+ raise RuntimeError(f"Download failed: {e}")
+
+ ## 2. check the local hash: if it exists, compare it with the remote one (with a timeout);
+ ## if not, generate it, then compare.
+ local_sha256_hash = None
+ # WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON
+ if not WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON.exists():
+ pass
+ else:
+ # if a hash file is present, read the cached value
+ with WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON.open("r") as f:
+ hash_data = json.load(f)
+ local_sha256_hash = hash_data.get("sha256", None)
+
+ if local_sha256_hash is None:
+
+ # generate the hash and update the config
+ logger.info(f"Generating sha256 hash for {WATER_MARK_DETECT_YOLO_WEIGHTS}")
+ local_sha256_hash = generate_sha256_hash(WATER_MARK_DETECT_YOLO_WEIGHTS)
+ WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON.parent.mkdir(parents=True, exist_ok=True)
+ with WATER_MARK_DETECT_YOLO_WEIGHTS_HASH_JSON.open("w") as f:
+ json.dump({"sha256": local_sha256_hash}, f)
+ remote_sha256_hash = None
+ try:
+ response = requests.get(REMOTE_MODEL_VERSION_URL, timeout=10)
+ response.raise_for_status()
+ remote_sha256_hash = response.json().get("sha256", None)
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Failed to get remote sha256 hash: {e}")
+ remote_sha256_hash = None
+
+ ## 3. after the comparison, if there is a new version, download it, replace the local weights,
+ ## and update the hash
+ logger.debug(f"Local hash: {local_sha256_hash}, Remote hash: {remote_sha256_hash}")
+ if remote_sha256_hash is None:
+ pass
+ else:
+ if local_sha256_hash != remote_sha256_hash:
+ logger.info(f"Hash mismatch detected, updating model...")
+ download_detector_weights(force_download=True)
+ else:
+ logger.debug("Model is up-to-date")
sorawm/watermark_cleaner.py CHANGED
@@ -36,294 +36,3 @@ class WaterMarkCleaner:
  inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
  return inpaint_result

-
- if __name__ == "__main__":
- from pathlib import Path
-
- import cv2
- import numpy as np
- from tqdm import tqdm
-
- # ========= Configuration =========
- video_path = Path("resources/puppies.mp4")
- save_video = True
- out_path = Path("outputs/dog_vs_sam_detected.mp4")
- window = "Sora watermark (threshold+morph+shape + tracking)"
-
- # Tracking / fallback strategy parameters
- PREV_ROI_EXPAND = 2.2 # expansion factor (>1) applied to the previous box's width/height
- AREA1 = (1000, 2000) # area range for the main detection stage
- AREA2 = (600, 4000) # area range for the fallback stage
- # =======================
-
- cleaner = SoraWaterMarkCleaner(video_path, video_path)
-
- # Pre-fetch one frame to determine the size/FPS
- first_frame = None
- for first_frame in cleaner.input_video_loader:
- break
- assert first_frame is not None, "Unable to read a video frame"
- H, W = first_frame.shape[:2]
- fps = getattr(cleaner.input_video_loader, "fps", 30)
-
- # Output video (original | bw | all-contours | vis as a four-panel strip)
- writer = None
- if save_video:
- out_path.parent.mkdir(parents=True, exist_ok=True)
- fourcc = cv2.VideoWriter_fourcc(*"avc1")
- writer = cv2.VideoWriter(str(out_path), fourcc, fps, (W * 4, H))
- if not writer.isOpened():
- fourcc = cv2.VideoWriter_fourcc(*"MJPG")
- writer = cv2.VideoWriter(str(out_path), fourcc, fps, (W * 4, H))
- assert writer.isOpened(), "Unable to create the output video file"
-
- cv2.namedWindow(window, cv2.WINDOW_NORMAL)
-
- # ---- Helper functions ----
- def _clip_rect(x0, y0, x1, y1, w_img, h_img):
- x0 = max(0, min(x0, w_img - 1))
- x1 = max(0, min(x1, w_img))
- y0 = max(0, min(y0, h_img - 1))
- y1 = max(0, min(y1, h_img))
- if x1 <= x0:
- x1 = x0 + 1
- if y1 <= y0:
- y1 = y0 + 1
- return x0, y0, x1, y1
-
- def _cnt_bbox(cnt):
- x, y, w, h = cv2.boundingRect(cnt)
- return (x, y, x + w, y + h)
-
- def _bbox_center(b):
- x0, y0, x1, y1 = b
- return ((x0 + x1) // 2, (y0 + y1) // 2)
-
- def detect_flower_like(image, prev_bbox=None):
- """
- Detection pipeline:
- brightness range → adaptive threshold → find contours only inside 3 regions + (optionally) the expanded ROI from the previous frame
- Three regions: 1) top-left 20% 2) bottom-left 20% 3) middle horizontal band y∈[0.4H, 0.6H], x∈[0,W]
- Returns: bw_region, best_cnt, contours_region, region_boxes, prev_roi_box
- """
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-
- # 208 ± 20% brightness range
- low, high = int(round(208 * 0.9)), int(round(208 * 1.1))
- mask = ((gray >= low) & (gray <= high)).astype(np.uint8) * 255
-
- # Adaptive threshold, restricted to the brightness range
- bw = cv2.adaptiveThreshold(
- gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 31, -5
- )
- bw = cv2.bitwise_and(bw, mask)
-
- # -------- Three regions: top-left / bottom-left / middle band --------
- h_img, w_img = gray.shape[:2]
- r_top_left = (0, 0, int(0.2 * w_img), int(0.2 * h_img))
- r_bot_left = (0, int(0.8 * h_img), int(0.2 * w_img), h_img)
- y0, y1 = int(0.40 * h_img), int(0.60 * h_img) # middle band
- r_mid_band = (0, y0, w_img, y1)
-
- region_mask = np.zeros_like(bw, dtype=np.uint8)
- for x0, ys, x1, ye in (r_top_left, r_bot_left):
- region_mask[ys:ye, x0:x1] = 255
- region_mask[y0:y1, :] = 255
-
- # -------- In addition: the expanded ROI from the previous frame --------
- prev_roi_box = None
- if prev_bbox is not None:
- px0, py0, px1, py1 = prev_bbox
- pw, ph = (px1 - px0), (py1 - py0)
- cx, cy = _bbox_center(prev_bbox)
- rw = int(pw * PREV_ROI_EXPAND)
- rh = int(ph * PREV_ROI_EXPAND)
- rx0, ry0 = cx - rw // 2, cy - rh // 2
- rx1, ry1 = cx + rw // 2, cy + rh // 2
- rx0, ry0, rx1, ry1 = _clip_rect(rx0, ry0, rx1, ry1, w_img, h_img)
- region_mask[ry0:ry1, rx0:rx1] = 255
- prev_roi_box = (rx0, ry0, rx1, ry1)
-
- bw_region = cv2.bitwise_and(bw, region_mask)
-
- # -------- Contours + shape filtering --------
- def select_candidates(bw_bin, area_rng):
- contours, _ = cv2.findContours(
- bw_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
- )
- cand = []
- for cnt in contours:
- area = cv2.contourArea(cnt)
- if area < area_rng[0] or area > area_rng[1]:
- continue
- peri = cv2.arcLength(cnt, True)
- if peri == 0:
- continue
- circularity = 4.0 * np.pi * area / (peri * peri)
- if 0.55 <= circularity <= 0.95:
- cand.append(cnt)
- return contours, cand
-
- contours_region, cand1 = select_candidates(bw_region, AREA1)
-
- best_cnt = None
- if cand1:
- # If there is a previous frame, prefer the candidate closest to its center; otherwise take the largest area
- if prev_bbox is None:
- best_cnt = max(cand1, key=lambda c: cv2.contourArea(c))
- else:
- pcx, pcy = _bbox_center(prev_bbox)
- best_cnt = max(
- cand1,
- key=lambda c: -(
- (((_cnt_bbox(c)[0] + _cnt_bbox(c)[2]) // 2 - pcx) ** 2)
- + (((_cnt_bbox(c)[1] + _cnt_bbox(c)[3]) // 2 - pcy) ** 2)
- ),
- )
- else:
- # Fallback 1: relax the area range, but only inside the previous frame's ROI
- if prev_roi_box is not None:
- rx0, ry0, rx1, ry1 = prev_roi_box
- roi = np.zeros_like(bw_region)
- roi[ry0:ry1, rx0:rx1] = bw_region[ry0:ry1, rx0:rx1]
- _, cand2 = select_candidates(roi, AREA2)
- if cand2:
- if prev_bbox is None:
- best_cnt = max(cand2, key=lambda c: cv2.contourArea(c))
- else:
- pcx, pcy = _bbox_center(prev_bbox)
- best_cnt = max(
- cand2,
- key=lambda c: -(
- (((_cnt_bbox(c)[0] + _cnt_bbox(c)[2]) // 2 - pcx) ** 2)
- + (
- ((_cnt_bbox(c)[1] + _cnt_bbox(c)[3]) // 2 - pcy)
- ** 2
- )
- ),
- )
- else:
- # Fallback 2: candidates over all regions, pick the one closest to the previous center
- if prev_bbox is not None:
- _, cand3 = select_candidates(bw_region, AREA2)
- if cand3:
- pcx, pcy = _bbox_center(prev_bbox)
- best_cnt = max(
- cand3,
- key=lambda c: -(
- (
- ((_cnt_bbox(c)[0] + _cnt_bbox(c)[2]) // 2 - pcx)
- ** 2
- )
- + (
- ((_cnt_bbox(c)[1] + _cnt_bbox(c)[3]) // 2 - pcy)
- ** 2
- )
- ),
- )
-
- region_boxes = (r_top_left, r_bot_left, r_mid_band, (y0, y1))
- return bw_region, best_cnt, contours_region, region_boxes, prev_roi_box
-
- # ---- Temporal tracking state (a dict avoids nonlocal/global) ----
- state = {"bbox": None} # stores the previous frame's bounding box (x0,y0,x1,y1)
-
- def process_and_show(frame, idx):
- img = frame.copy()
- bw, best, contours, region_boxes, prev_roi_box = detect_flower_like(
- img, state["bbox"]
- )
- r_top_left, r_bot_left, r_mid_band, (y0, y1) = region_boxes
-
- # All contours (yellow)
- allc = img.copy()
- if contours:
- cv2.drawContours(allc, contours, -1, (0, 255, 255), 1)
-
- # Draw the three regions: red boxes + red lines at the top/bottom of the middle band
- def draw_rect(im, rect, color=(0, 0, 255), th=2):
- x0, y0r, x1, y1r = rect
- cv2.rectangle(im, (x0, y0r), (x1, y1r), color, th)
-
- draw_rect(allc, r_top_left)
- draw_rect(allc, r_bot_left)
- draw_rect(allc, (r_mid_band[0], r_mid_band[1], r_mid_band[2], r_mid_band[3]))
- cv2.line(allc, (0, y0), (img.shape[1], y0), (0, 0, 255), 2)
- cv2.line(allc, (0, y1), (img.shape[1], y1), (0, 0, 255), 2)
-
- # Draw the previous frame's expanded ROI (cyan)
- if prev_roi_box is not None:
- x0, y0r, x1, y1r = prev_roi_box
- cv2.rectangle(allc, (x0, y0r), (x1, y1r), (255, 255, 0), 2)
-
- # Final detection
- vis = img.copy()
- title = "no-detect"
- if best is not None:
- cv2.drawContours(vis, [best], -1, (0, 255, 0), 2)
- x0, y0r, x1, y1r = _cnt_bbox(best)
- state["bbox"] = (x0, y0r, x1, y1r) # update the tracking state
- M = cv2.moments(best)
- if M["m00"] > 0:
- cx, cy = int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])
- cv2.circle(vis, (cx, cy), 4, (0, 0, 255), -1)
- title = "detected"
- else:
- # If nothing is detected, keep the previous state
- cv2.putText(
- vis,
- "No detection (kept last state)",
- (12, 28),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.8,
- (0, 0, 255),
- 2,
- )
- if state["bbox"] is not None:
- x0, y0r, x1, y1r = state["bbox"]
- cv2.rectangle(vis, (x0, y0r), (x1, y1r), (255, 255, 0), 2)
-
- # Four panels: original | bw within regions | all contours | final detection
- panel = np.hstack([img, cv2.cvtColor(bw, cv2.COLOR_GRAY2BGR), allc, vis])
- cv2.putText(
- panel,
- f"Frame {idx} | {title}",
- (12, 28),
- cv2.FONT_HERSHEY_SIMPLEX,
- 0.9,
- (255, 255, 255),
- 2,
- )
-
- cv2.imshow(window, panel)
- if writer is not None:
- if panel.shape[:2] != (H, W * 4):
- panel = cv2.resize(panel, (W * 4, H), interpolation=cv2.INTER_AREA)
- writer.write(panel)
-
- # Process the first frame that was already fetched
- process_and_show(first_frame, 0)
-
- # Continue with the same iteration pattern
- for idx, frame in enumerate(
- tqdm(cleaner.input_video_loader, desc="Processing frames", initial=1, unit="f")
- ):
- process_and_show(frame, idx)
- key = cv2.waitKey(max(1, int(1000 / max(1, int(fps))))) & 0xFF
- if key == ord("q"):
- break
- elif key == ord(" "):
- while True:
- k = cv2.waitKey(50) & 0xFF
- if k in (ord(" "), ord("q")):
- if k == ord("q"):
- idx = 10**9
- break
- if idx >= 10**9:
- break
-
- if writer is not None:
- writer.release()
- print(f"[OK] Visualization video saved: {out_path}")
-
- cv2.destroyAllWindows()
uv.lock CHANGED
@@ -2661,6 +2661,7 @@ dependencies = [
  { name = "fastapi" },
  { name = "ffmpeg-python" },
  { name = "fire" },
+ { name = "greenlet" },
  { name = "httpx" },
  { name = "huggingface-hub" },
  { name = "jupyter" },
@@ -2695,6 +2696,7 @@ requires-dist = [
  { name = "fastapi", specifier = "==0.108.0" },
  { name = "ffmpeg-python", specifier = ">=0.2.0" },
  { name = "fire", specifier = ">=0.7.1" },
+ { name = "greenlet", specifier = ">=3.2.4" },
  { name = "httpx", specifier = ">=0.28.1" },
  { name = "huggingface-hub", specifier = ">=0.35.3" },
  { name = "jupyter", specifier = ">=1.1.1" },
version/model_version.json ADDED
@@ -0,0 +1 @@
+ {}