# FlexSED / api.py
# Last change: commit 33db348 "Update api.py" by OpenSound (verified).
import torch
import librosa
import os
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, ClapTextModelWithProjection
from src.models.transformer import Dasheng_Encoder
from src.models.sed_decoder import Decoder, TSED_Wrapper
from src.utils import load_yaml_with_includes
class FlexSED:
    """Sound event detection queried by free-text event names.

    Loads the FlexSED encoder/decoder checkpoint, the CLAP text encoder and
    its tokenizer once, then provides inference plus plotting / video helpers
    for the resulting per-event probability curves.
    """

    def __init__(
        self,
        config_path='src/configs/model.yml',
        ckpt_path='ckpts/flexsed_as.pt',
        ckpt_url='https://huggingface.co/Higobeatz/FlexSED/resolve/main/ckpts/flexsed_as.pt',
        device='cuda'
    ):
        """
        Initialize FlexSED with model, CLAP, and tokenizer loaded once.

        If the checkpoint is not available locally at ``ckpt_path``, it is
        fetched from ``ckpt_url`` with ``torch.hub.load_state_dict_from_url``
        (which caches it in the torch hub directory, not at ``ckpt_path``).

        Args:
            config_path: YAML config providing the 'encoder', 'decoder',
                'ft_blocks' and 'frozen_encoder' entries.
            ckpt_path: Local checkpoint file; it must contain a 'model' entry
                holding the wrapper's state dict.
            ckpt_url: Fallback URL used when ``ckpt_path`` does not exist.
            device: Device the SED model runs on. The CLAP text model is left
                on its default (CPU) device.
        """
        self.device = device
        params = load_yaml_with_includes(config_path)

        # NOTE(security): both branches unpickle the checkpoint through
        # torch.load; only point these at trusted checkpoint sources.
        if not os.path.exists(ckpt_path):
            print(f"[FlexSED] Downloading checkpoint from {ckpt_url} ...")
            state_dict = torch.hub.load_state_dict_from_url(ckpt_url, map_location="cpu")
        else:
            state_dict = torch.load(ckpt_path, map_location="cpu")

        # Encoder + Decoder
        encoder = Dasheng_Encoder(**params['encoder']).to(self.device)
        decoder = Decoder(**params['decoder']).to(self.device)
        self.model = TSED_Wrapper(encoder, decoder, params['ft_blocks'], params['frozen_encoder'])
        self.model.load_state_dict(state_dict['model'])
        # Move the wrapper as a whole too, in case it owns parameters/buffers
        # of its own beyond the encoder/decoder (not visible from here).
        self.model.to(self.device)
        self.model.eval()

        # CLAP text model used to turn event names into query embeddings.
        self.clap = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
        self.clap.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

    def run_inference(self, audio_path, events, norm_audio=True):
        """
        Run inference on audio for given events.

        Args:
            audio_path: Audio file to analyse; loaded with librosa and
                resampled to 16 kHz.
            events: Non-empty iterable of event names. Underscores become
                spaces and each name is wrapped in a "The sound of ..." prompt.
            norm_audio: If True, peak-normalize the waveform before inference.

        Returns:
            CPU tensor of sigmoid probabilities, one row per event in the order
            given; presumably shaped [num_events, 1, T] (the plotting helpers
            squeeze dim 1) -- TODO confirm against TSED_Wrapper.

        Raises:
            ValueError: If ``events`` is empty.
        """
        events = list(events)
        if not events:
            raise ValueError("events must contain at least one event name")

        audio, _ = librosa.load(audio_path, sr=16000)
        # from_numpy avoids torch.tensor([ndarray]), which is slow and warns.
        audio = torch.from_numpy(audio).unsqueeze(0).to(self.device)
        if norm_audio:
            eps = 1e-9
            audio = audio / (audio.abs().max() + eps)

        prompts = [f"The sound of {event.replace('_', ' ').capitalize()}" for event in events]
        with torch.no_grad():
            # One padded batch for all prompts: text_embeds is [N, D]; the
            # model consumes it as a [1, N, D] query (same layout as before).
            inputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
            query = self.clap(**inputs).text_embeds.unsqueeze(0).to(self.device)

            mel = self.model.forward_to_spec(audio)
            preds = self.model(mel, query)
            preds = torch.sigmoid(preds).cpu()
        return preds

    # ---------- Shared drawing helper ----------
    @staticmethod
    def _draw_event_map(data, events, x_max):
        """Open a new pyplot figure and draw an event-probability heatmap.

        The caller is responsible for saving and closing the figure.

        Args:
            data: 2-D array [num_events, frames] with values in [0, 1].
            events: Row labels, one per row of ``data``.
            x_max: Right edge of the time axis, in seconds.
        """
        plt.figure(figsize=(12, len(events) * 0.6 + 2))
        plt.imshow(
            data,
            aspect="auto",
            cmap="Blues",
            extent=[0, x_max, 0, len(events)],
            vmin=0, vmax=1, origin="lower"
        )
        plt.colorbar(label="Probability")
        # Row i spans [i, i + 1] in the extent, so labels sit at i + 0.5.
        plt.yticks(np.arange(len(events)) + 0.5, events)
        plt.xlabel("Time (s)")
        plt.ylabel("Events")
        plt.title("Event Predictions")

    # ---------- Multi-event plotting ----------
    @staticmethod
    def plot_and_save_multi(preds, events, sr=25, out_dir="./plots", fname="all_events"):
        """Save a static heatmap of per-event probabilities as a PNG.

        Args:
            preds: CPU tensor from ``run_inference``; squeezing dim 1 must
                yield [num_events, T].
            events: Event names in the same order as the rows of ``preds``.
            sr: Prediction frames per second, used to label the time axis.
            out_dir: Output directory (created if missing).
            fname: File name stem; ``.png`` is appended.

        Returns:
            Path of the written image.
        """
        os.makedirs(out_dir, exist_ok=True)
        preds_np = preds.squeeze(1).numpy()  # [num_events, T]
        T = preds_np.shape[1]
        FlexSED._draw_event_map(preds_np, events, T / sr)
        save_path = os.path.join(out_dir, f"{fname}.png")
        try:
            plt.savefig(save_path, dpi=200, bbox_inches="tight")
        finally:
            plt.close()  # never leak the figure, even if saving fails
        return save_path

    def to_multi_plot(self, preds, events, out_dir="./plots", fname="all_events", sr=25):
        """Instance convenience wrapper around ``plot_and_save_multi``."""
        return self.plot_and_save_multi(preds, events, sr=sr, out_dir=out_dir, fname=fname)

    # ---------- Multi-event video ----------
    @staticmethod
    def make_multi_event_video(preds, events, sr=25, out_dir="./videos",
                               audio_path=None, fps=25, highlight=True, fname="all_events"):
        """Render an MP4 in which the prediction heatmap is revealed over time.

        Args:
            preds: CPU tensor from ``run_inference``; squeezing dim 1 must
                yield [num_events, T].
            events: Event names in the same order as the rows of ``preds``.
            sr: Prediction frames per second.
            out_dir: Output directory (created if missing).
            audio_path: Optional audio file muxed into the video, truncated
                to the prediction duration (or the file length if shorter).
            fps: Video frame rate.
            highlight: If True, keep the full time axis and zero out
                not-yet-reached predictions; otherwise grow the axis.
            fname: File name stem; ``.mp4`` is appended.

        Returns:
            Path of the written video.

        Raises:
            ValueError: If the predictions are too short to yield any frame.
        """
        # Heavy optional dependencies are imported lazily. This code targets
        # the moviepy 1.x API (moviepy.editor, subclip, set_audio).
        import tempfile
        from moviepy.editor import ImageSequenceClip, AudioFileClip
        from tqdm import tqdm

        os.makedirs(out_dir, exist_ok=True)
        preds_np = preds.squeeze(1).numpy()  # [num_events, T]
        T = preds_np.shape[1]
        duration = T / sr
        n_frames = int(duration * fps)
        if n_frames == 0:
            raise ValueError("predictions are too short to render any video frame")

        save_path = os.path.join(out_dir, f"{fname}.mp4")
        # Private temp dir: concurrent runs cannot clobber each other's frames,
        # and the PNGs are removed even if rendering or encoding fails.
        with tempfile.TemporaryDirectory(prefix="flexsed_frames_") as frame_dir:
            frames = []
            for i in tqdm(range(n_frames)):
                # Last prediction frame revealed at video frame i.
                t = int(i * T / n_frames)
                if highlight:
                    mask = np.zeros_like(preds_np)
                    mask[:, :t + 1] = preds_np[:, :t + 1]
                    FlexSED._draw_event_map(mask, events, T / sr)
                else:
                    FlexSED._draw_event_map(preds_np[:, :t + 1], events, (t + 1) / sr)
                frame_path = os.path.join(frame_dir, f"frame_{i:04d}.png")
                try:
                    plt.savefig(frame_path, dpi=150, bbox_inches="tight")
                finally:
                    plt.close()
                frames.append(frame_path)

            # Frames are read during write_videofile, so encode inside the
            # temp-dir context.
            clip = ImageSequenceClip(frames, fps=fps)
            source_audio = None
            try:
                if audio_path is not None:
                    source_audio = AudioFileClip(audio_path)
                    # Never request audio past the end of the file.
                    audio = source_audio.subclip(0, min(duration, source_audio.duration))
                    clip = clip.set_audio(audio)
                clip.write_videofile(
                    save_path,
                    fps=fps,
                    codec="mpeg4",
                    audio_codec="aac"
                )
            finally:
                if source_audio is not None:
                    source_audio.close()  # release the audio reader process
        return save_path

    def to_multi_video(self, preds, events, audio_path, out_dir="./videos", fname="all_events",
                       sr=25, fps=25):
        """Instance convenience wrapper around ``make_multi_event_video``."""
        return self.make_multi_event_video(
            preds, events, sr=sr, audio_path=audio_path, out_dir=out_dir, fps=fps, fname=fname
        )
if __name__ == "__main__":
    # Demo: detect three events in example2.wav and save a combined heatmap.
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    flexsed = FlexSED(device=device)
    events = ["Door", "Laughter", "Dog"]
    preds = flexsed.run_inference("example2.wav", events)
    # Combined plot & video
    flexsed.to_multi_plot(preds, events, fname="example2")
    # Optional animated video (requires moviepy and tqdm):
    # flexsed.to_multi_video(preds, events, audio_path="example2.wav", fname="example2")