from configs import CFG
from costum_datasets import make_pairs
from text_image_audio import OneEncoder
import torch
import gradio as gr
import torchaudio

# Construct pairs of images and captions
training_pairs = make_pairs(CFG.image_dir, CFG.image_dir, 5)  # 413,915 -> 82,783 images

# Sort the pairs by image path
training_pairs = sorted(training_pairs, key=lambda x: x[0])
coco_images, coco_captions = zip(*training_pairs)

# Keep only the first pair for each image (preserves the sorted order)
unique_images = set()
unique_pairs = [(item[0], item[1]) for item in training_pairs
                if item[0] not in unique_images and not unique_images.add(item[0])]
coco_images, _ = zip(*unique_pairs)

# Load the pretrained OneEncoder model
model = OneEncoder.from_pretrained("bilalfaye/OneEncoder-text-image-audio")

# Load precomputed COCO image embeddings and keep the first 3000
coco_image_features = torch.load("image_embeddings_best.pt", map_location=CFG.device)
coco_image_features = coco_image_features[:3000]
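
# Assumption: the rows of image_embeddings_best.pt are aligned with the sorted,
# de-duplicated coco_images list, so slicing to the first 3000 embeddings presumably
# restricts retrieval to the first 3000 images of that list.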

def text_image(query):
    # image_retrieval is assumed to plot the top matches and save the figure as
    # "img.png" in the working directory; Gradio then displays that file.
    model.text_image_encoder.image_retrieval(query,
                                             image_paths=coco_images,
                                             image_embeddings=coco_image_features,
                                             n=9,
                                             plot=True,
                                             temperature=0.0)
    return "img.png"

def audio_image(query):
    # Load the audio with torchaudio (returns a waveform tensor and its sample rate)
    waveform, sample_rate = torchaudio.load(query)

    # Convert multi-channel audio to mono by averaging the channels
    if waveform.shape[0] > 1:
        mono_audio = waveform.mean(dim=0, keepdim=True)
    else:
        # Audio is already mono
        mono_audio = waveform

    # Resample to 16 kHz if not already at that rate
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        mono_audio = resampler(mono_audio)

    # Convert to a 1-D numpy array for the audio processor
    mono_audio = mono_audio.squeeze(0).numpy()
    audio_encoding = model.process_audio([mono_audio])

    model.image_retrieval(audio_encoding,
                          image_paths=coco_images,
                          image_embeddings=coco_image_features,
                          n=9,
                          plot=True,
                          temperature=0.0,
                          display_audio=False)
    return "img.png"

# Gradio interface with one tab per query modality
iface = gr.TabbedInterface(
    [
        gr.Interface(
            fn=text_image,
            inputs=gr.Textbox(label="Text Query"),
            outputs="image",
            title="Retrieve images using text as query",
            description="Light demo of OneEncoder using one layer in the UP (Universal Projection) module. Only the COCO train set is used in this example (3000 images).",
        ),
        gr.Interface(
            fn=audio_image,
            inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Provide Audio Query"),
            outputs="image",
            title="Retrieve images using audio as query",
            description="Light demo of OneEncoder using one layer in the UP (Universal Projection) module. Only the COCO train set is used in this example (3000 images).",
        ),
    ],
    tab_names=["Text - Image", "Audio - Image"],
)

iface.launch(debug=True, share=True)
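
# Note: share=True creates a temporary public link when the script is run locally
# (e.g. in a notebook); on Hugging Face Spaces the app is already served publicly,
# so the flag is presumably not needed there.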