"""Streamlit front-end for removing watermarks from videos with SoraWM."""
import os
# Point HOME and Streamlit's config directory at /tmp *before* `streamlit` is
# imported below, and make sure that directory exists.
# NOTE(review): presumably needed because the default home directory is not
# writable in the deployment environment (e.g. a container) — confirm.
os.environ["HOME"] = "/tmp"
os.environ["STREAMLIT_CONFIG_DIR"] = "/tmp/.streamlit"
os.makedirs("/tmp/.streamlit", exist_ok=True)
import shutil
import tempfile
from pathlib import Path
# Imported only after the environment tweaks above so they take effect.
import streamlit as st
from sorawm.core import SoraWM
# Extensions accepted by both uploaders and recognized when scanning the batch input folder.
_VIDEO_TYPES = ["mp4", "avi", "mov", "mkv"]
_VIDEO_SUFFIXES = frozenset(f".{ext}" for ext in _VIDEO_TYPES)

# Radio labels; the selected label decides which input mode is rendered.
_MODE_SINGLE = "📁 Upload Video File"
_MODE_FOLDER = "🗂️ Process Folder"

# Session-state keys holding the single-file result; reset when a different file is uploaded.
_SINGLE_RESULT_KEYS = ("processed_video_data", "processed_video_name")


def _safe_relative_path(name: str) -> Path:
    """Turn a client-supplied upload name into a relative path that cannot escape its parent.

    Upload names come from the browser and are untrusted. Empty, ``.`` and
    ``..`` components and anything carrying an anchor (root or drive) are
    dropped. The remaining directory components are kept, so nested names
    still map to nested paths.

    Returns ``Path("upload")`` when nothing usable is left.
    """
    parts = [
        part
        for part in name.replace("\\", "/").split("/")
        if part not in ("", ".", "..") and not Path(part).anchor
    ]
    return Path(*parts) if parts else Path("upload")


def _guess_mime(file_name: str) -> str:
    """Return the MIME type implied by *file_name*'s extension, or a generic binary type."""
    import mimetypes

    mime, _ = mimetypes.guess_type(file_name)
    return mime or "application/octet-stream"


def _progress_stage(progress: int):
    """Map a 0-100 progress value to an ``(icon, label)`` pair for the status text.

    NOTE(review): the 50/95 thresholds presumably mirror how SoraWM reports its
    detect / remove / merge phases — confirm against ``SoraWM.run``.
    """
    if progress < 50:
        return "🔍", "Detecting watermarks"
    if progress < 95:
        return "🧹", "Removing watermarks"
    return "🎵", "Merging audio"


def _make_batch_progress(progress_bar, status_text, index: int, total: int):
    """Build a progress callback for file *index* (0-based) out of *total* files.

    The bar shows overall progress across all files, clamped to [0, 1]. The
    text shows the current file's stage.
    """

    def update_progress(progress: int) -> None:
        overall = (index * 100 + progress) / (total * 100)
        progress_bar.progress(min(max(overall, 0.0), 1.0))
        icon, label = _progress_stage(progress)
        status_text.text(
            f"{icon} Processing file {index + 1}/{total}: {label}... {progress}%"
        )

    return update_progress


def _render_header() -> None:
    """Render the centered page title and tagline."""
    st.markdown(
        "<div style='text-align: center;'>"
        "<h1>🎬 Sora Watermark Cleaner</h1>"
        "<p>Remove watermarks from Sora-generated videos with AI-powered precision</p>"
        "</div>",
        unsafe_allow_html=True,
    )


def _process_single(uploaded_file, safe_name: str) -> bool:
    """Run SoraWM on one upload and store the cleaned bytes in session state.

    Returns True on success. On failure it shows an error and returns False.
    """
    progress_bar = st.progress(0)
    status_text = st.empty()

    def update_progress(progress: int) -> None:
        progress_bar.progress(min(max(progress, 0), 100) / 100)
        icon, label = _progress_stage(progress)
        status_text.text(f"{icon} {label}... {progress}%")

    output_name = f"cleaned_{safe_name}"
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)
            input_path = tmp_path / safe_name
            output_path = tmp_path / output_name
            # getvalue() returns the whole upload regardless of the stream position.
            input_path.write_bytes(uploaded_file.getvalue())
            st.session_state.sora_wm.run(
                input_path, output_path, progress_callback=update_progress
            )
            # Read before the temp dir is deleted; only the bytes outlive it.
            video_data = output_path.read_bytes()
    except Exception as e:
        st.error(f"❌ Error processing video: {e}")
        return False

    progress_bar.progress(100)
    status_text.text("✅ Processing complete!")
    st.success("✅ Watermark removed successfully!")
    st.session_state.processed_video_data = video_data
    st.session_state.processed_video_name = output_name
    return True


def _render_single_file_mode() -> None:
    """Upload one video, show it beside its cleaned version, and offer a download."""
    uploaded_file = st.file_uploader(
        "Upload your video",
        type=_VIDEO_TYPES,
        accept_multiple_files=False,
        help="Select a video file to remove watermark",
    )
    if not uploaded_file:
        return

    # A different upload (name or size changed) invalidates the previous result.
    upload_id = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("current_upload_id") != upload_id:
        st.session_state.current_upload_id = upload_id
        st.session_state.current_file_name = uploaded_file.name
        for key in _SINGLE_RESULT_KEYS:
            st.session_state.pop(key, None)

    safe_name = _safe_relative_path(uploaded_file.name).name
    st.success(f"✅ Uploaded: {uploaded_file.name}")

    # Before/after comparison.
    col_left, col_right = st.columns(2)
    with col_left:
        st.markdown("### 📥 Original Video")
        st.video(uploaded_file)
    with col_right:
        st.markdown("### 🎬 Processed Video")
        if "processed_video_data" in st.session_state:
            st.video(st.session_state.processed_video_data)
        else:
            st.info("Click 'Remove Watermark' to process the video")

    if st.button("🚀 Remove Watermark", type="primary", use_container_width=True):
        # st.rerun() raises a control-flow exception, so it is kept outside any try block.
        if _process_single(uploaded_file, safe_name):
            st.rerun()  # redraw so the right column shows the processed video

    if "processed_video_data" in st.session_state:
        output_name = st.session_state.processed_video_name
        st.download_button(
            label="⬇️ Download Cleaned Video",
            data=st.session_state.processed_video_data,
            file_name=output_name,
            mime=_guess_mime(output_name),
            use_container_width=True,
        )


def _process_batch(uploaded_files) -> bool:
    """Clean every uploaded video and store the results in ``batch_processed_files``.

    Each stored entry is ``{"name": <relative output path>, "data": <bytes>}``.
    Returns True on success. On failure it shows the error with its traceback
    and returns False.
    """
    results = []
    try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = Path(tmp_dir)
            input_folder = tmp_path / "input"
            output_folder = tmp_path / "output"
            input_folder.mkdir()
            output_folder.mkdir()

            status_text = st.empty()
            status_text.text("📥 Saving uploaded files...")
            for uploaded_file in uploaded_files:
                # Keep any (sanitized) sub-directories present in the upload name.
                file_path = input_folder / _safe_relative_path(uploaded_file.name)
                file_path.parent.mkdir(parents=True, exist_ok=True)
                file_path.write_bytes(uploaded_file.getvalue())

            # Count what actually landed on disk: uploads sharing a name overwrite each other.
            video_files = sorted(
                path
                for path in input_folder.rglob("*")
                if path.is_file() and path.suffix.lower() in _VIDEO_SUFFIXES
            )
            total = len(video_files)

            progress_bar = st.progress(0)
            current_file_text = st.empty()
            for index, video_file in enumerate(video_files):
                rel_path = video_file.relative_to(input_folder)
                output_path = output_folder / rel_path.parent / f"cleaned_{rel_path.name}"
                output_path.parent.mkdir(parents=True, exist_ok=True)
                st.session_state.sora_wm.run(
                    video_file,
                    output_path,
                    progress_callback=_make_batch_progress(
                        progress_bar, current_file_text, index, total
                    ),
                )
                results.append(
                    {
                        "name": str(output_path.relative_to(output_folder)),
                        "data": output_path.read_bytes(),
                    }
                )

            progress_bar.progress(100)
            current_file_text.text("✅ All videos processed!")
    except Exception as e:
        st.error(f"❌ Error processing videos: {e}")
        st.exception(e)
        return False

    st.success(f"✅ {len(results)} video(s) processed successfully!")
    st.session_state.batch_processed_files = results
    return True


def _render_batch_downloads() -> None:
    """Show one download button per processed file stored in session state."""
    processed_files = st.session_state.get("batch_processed_files")
    if not processed_files:
        return
    st.markdown("---")
    st.markdown("### ⬇️ Download Processed Videos")
    for file_info in processed_files:
        col1, col2 = st.columns([3, 1])
        with col1:
            st.text(f"📹 {file_info['name']}")
        with col2:
            st.download_button(
                label="⬇️ Download",
                data=file_info["data"],
                file_name=file_info["name"],
                mime=_guess_mime(file_info["name"]),
                key=f"download_{file_info['name']}",
                use_container_width=True,
            )


def _render_folder_mode() -> None:
    """Upload many videos (e.g. a dragged folder), process them all, and offer downloads."""
    st.info(
        "💡 Drag and drop your video folder here, or click to browse and select multiple video files"
    )
    uploaded_files = st.file_uploader(
        "Upload videos from folder",
        type=_VIDEO_TYPES,
        accept_multiple_files=True,
        help="You can drag & drop an entire folder here, or select multiple video files",
        key="folder_uploader",
    )
    if not uploaded_files:
        return

    st.success(f"✅ {len(uploaded_files)} video file(s) uploaded")
    with st.expander("📋 View uploaded files", expanded=False):
        for i, file in enumerate(uploaded_files, 1):
            file_size_mb = file.size / (1024 * 1024)
            st.text(f"{i}. {file.name} ({file_size_mb:.2f} MB)")

    if st.button("🚀 Process All Videos", type="primary", use_container_width=True):
        # st.rerun() raises a control-flow exception, so it is kept outside any try block.
        if _process_batch(uploaded_files):
            st.rerun()  # redraw so the download list reflects the new results

    _render_batch_downloads()


def main():
    """Streamlit entry point: configure the page, load SoraWM once, and render the chosen mode."""
    st.set_page_config(
        page_title="Sora Watermark Cleaner", page_icon="🎬", layout="centered"
    )
    _render_header()

    # Load the models once per browser session; later reruns reuse the stored instance.
    if "sora_wm" not in st.session_state:
        with st.spinner("Loading AI models..."):
            st.session_state.sora_wm = SoraWM()

    st.markdown("---")
    mode = st.radio(
        "Select input mode:",
        [_MODE_SINGLE, _MODE_FOLDER],
        horizontal=True,
    )
    if mode == _MODE_SINGLE:
        _render_single_file_mode()
    else:
        _render_folder_mode()
# Entry point: build the UI when this file is executed as a script
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()