iBrokeTheCode's picture
chore: Add source code for training
9470ff7
raw
history blame
7.66 kB
import os
import warnings
from io import BytesIO
import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.model_selection import train_test_split
# 💬 NOTE: Suppress all warnings. No category or module is given, so this "ignore"
# filter matches every warning category, including warnings raised by imported libraries.
warnings.filterwarnings("ignore")
def process_embeddings(df, col_name):
    """
    Expand a column of stringified embedding vectors into one column per element.

    Args:
    - df (pd.DataFrame): The DataFrame containing the embeddings column.
    - col_name (str): The name of the column containing the embeddings, stored
      as strings such as "[-0.123, 0.456, ...]".

    Returns:
    pd.DataFrame: A new DataFrame in which `col_name` has been removed and the
    columns `text_1` ... `text_N` have been appended, N being the length of the
    longest embedding. The row index of `df` is preserved and `df` itself is
    left unmodified.

    Steps:
    1. Parse each string in the specified column into a list.
    2. Build one new column per list element.
    3. Drop the original embeddings column and append the new columns.

    Example:
    df_processed = process_embeddings(df, 'embeddings')
    """
    # NOTE(review): eval executes arbitrary Python found in the column. That is only
    # acceptable for trusted data; consider ast.literal_eval if the input may be untrusted.
    # The parsed values are kept in a local Series so the caller's column is not overwritten.
    parsed = df[col_name].apply(eval)

    # Width of the widest embedding; 0 when the frame has no rows.
    width = max((len(vector) for vector in parsed), default=0)

    # 🔎 Example of the expanded frame
    #    text_1  text_2  text_3
    # 0  -0.123   0.456   0.789
    # 1   0.321  -0.654   0.987
    #
    # Reusing df.index is essential: concat(axis=1) aligns rows on the index, so a
    # fresh RangeIndex would misalign (and pad with NaN) whenever df has been
    # filtered, shuffled or otherwise carries a non-default index.
    embeddings_df = pd.DataFrame(
        parsed.to_list(),
        index=df.index,
        columns=[f"text_{i + 1}" for i in range(width)],
    )

    # Remove the original embeddings column and append the expanded ones
    return pd.concat([df.drop(columns=[col_name]), embeddings_df], axis=1)
def rename_image_embeddings(df):
    """
    Prefix purely numeric column labels with "image_".

    Args:
    - df (pd.DataFrame): The DataFrame containing columns to be renamed. Its
      columns are replaced in place.

    Returns:
    pd.DataFrame: The same DataFrame object, with renamed columns.

    Numeric labels may be strings ("0", "1", ...) or integers (0, 1, ...);
    every other label is kept unchanged.

    Example:
    df_renamed = rename_image_embeddings(df)
    """

    def _image_label(col):
        # str() lets integer labels through, not just the string labels that
        # str-only methods would accept.
        # isdecimal() (rather than isdigit()) only accepts characters int() can parse.
        text = str(col)
        return f"image_{int(text)}" if text.isdecimal() else col

    # From 0 1 2 label ➡️ image_0 image_1 image_2 label
    df.columns = [_image_label(col) for col in df.columns]
    return df
def preprocess_data(
    text_data,
    image_data,
    text_id="image_id",
    image_id="ImageName",
    embeddings_col="embeddings",
):
    """
    Preprocess and merge text and image dataframes.

    Args:
    - text_data (pd.DataFrame): DataFrame containing text data.
    - image_data (pd.DataFrame): DataFrame containing image data.
    - text_id (str): Column name for text data identifier.
    - image_id (str): Column name for image data identifier.
    - embeddings_col (str): Column name for embeddings data.

    Returns:
    pd.DataFrame: Merged and preprocessed DataFrame.

    This function:
    Process text and image embeddings.
    Drop rows whose identifier is missing.
    Reduce path-like text identifiers to their last "/"-separated component.
    Merge dataframes on the identifiers (pandas' default inner join).
    Drop the identifier columns.

    Example:
    merged_df = preprocess_data(text_df, image_df)
    """
    # Call previous functions to tune the text and image dataframes.
    # rename_image_embeddings assigns to .columns in place, so it is handed a shallow
    # copy: the caller's image_data keeps its original column labels.
    text_data = process_embeddings(text_data, embeddings_col)
    image_data = rename_image_embeddings(image_data.copy(deep=False))

    # Drop missing values in the ids - removes rows where the ID (used to join text ↔ image) is missing.
    image_data = image_data.dropna(subset=[image_id])
    # Explicit copy so the assignment below never writes through to another frame.
    text_data = text_data.dropna(subset=[text_id]).copy()

    # Cleans up text IDs: if the column contains file paths (like "data/images/123.jpg"),
    # it extracts just the file name ("123.jpg"). Non-string ids are passed through untouched
    # instead of raising AttributeError.
    text_data[text_id] = text_data[text_id].apply(
        lambda value: value.split("/")[-1] if isinstance(value, str) else value
    )

    # Merge dataframes - joins text and image embeddings using the IDs (text_id vs image_id).
    df = pd.merge(text_data, image_data, left_on=text_id, right_on=image_id)

    # Drop unnecessary columns - the ID columns are no longer needed after the merge.
    return df.drop(columns=[image_id, text_id])
class ImageDownloader:
    """
    Image downloader class to download images from URLs.

    Args:
    - image_dir (str): Directory to save images. Created if it does not exist.
    - image_size (tuple): Size of the images to be saved.
    - overwrite (bool): Whether to overwrite existing images.
    - timeout (float | tuple): Timeout in seconds passed to `requests.get`, so a
      stalled server cannot block the whole run.

    Methods:
    - download_images(df, print_every=1000): Download images from URLs in a DataFrame.
        Args:
        - df (pd.DataFrame): DataFrame containing image URLs.
        - print_every (int): Print progress every n images.
        Returns:
        pd.DataFrame: DataFrame with image paths added.

    Example:
    downloader = ImageDownloader()
    df = downloader.download_images(df)
    """

    def __init__(
        self,
        image_dir="data/images/",
        image_size=(224, 224),
        overwrite=False,
        timeout=30,
    ):
        self.image_dir = image_dir
        self.image_size = image_size
        self.overwrite = overwrite
        self.timeout = timeout

        # Create the directory if it doesn't exist; exist_ok avoids the race between
        # an existence check and the creation.
        os.makedirs(self.image_dir, exist_ok=True)

    def download_images(self, df, print_every=1000):
        """
        Download, resize and save the image of every row, then record the local paths.

        Args:
        - df (pd.DataFrame): DataFrame with a "sku" column (used as file name) and
          an "image" column (the URL). An "image_path" column is added in place.
        - print_every (int): Print progress every n images; 0 or None disables the
          progress output.

        Returns:
        pd.DataFrame: The same DataFrame with the "image_path" column holding the
        saved file path, or NaN for rows whose image could not be downloaded.
        """
        image_paths = []
        total = len(df)

        for position, (sku, image_url) in enumerate(zip(df["sku"], df["image"])):
            if print_every and position % print_every == 0:
                print(f"Downloading image {position}/{total}")

            image_path = os.path.join(self.image_dir, f"{sku}.jpg")

            # Reuse a previously downloaded file unless overwriting was requested.
            if os.path.exists(image_path) and not self.overwrite:
                print(f"Image {sku} is already in the path.")
                image_paths.append(image_path)
                continue

            # Deliberately best-effort: one broken URL or image must not abort the
            # whole batch, so any failure is reported and recorded as NaN.
            try:
                response = requests.get(image_url, timeout=self.timeout)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                # JPEG cannot store palette or alpha modes (e.g. "P", "RGBA"); convert
                # those before resizing so they are saved instead of failing on save.
                if img.mode not in ("L", "RGB", "CMYK"):
                    img = img.convert("RGB")
                img = img.resize(self.image_size, Image.Resampling.LANCZOS)
                img.save(image_path)
                image_paths.append(image_path)
            except Exception as e:
                print(f"Could not download image for SKU: {sku}. Error: {e}")
                image_paths.append(np.nan)

        df["image_path"] = image_paths
        return df
def train_test_split_and_feature_extraction(
    df, test_size=0.3, random_state=42, label_column="class_id"
):
    """
    Split the data into train and test sets and extract features and labels.

    Args:
    - df (pd.DataFrame): DataFrame containing the data.

    Keyword Args:
    - test_size (float): Size of the test set.
    - random_state (int): Random state for reproducibility
    - label_column (str): Name of the column holding the class labels.

    Returns:
    pd.DataFrame: Train DataFrame.
    pd.DataFrame: Test DataFrame.
    list: List of columns with text embeddings.
    list: List of columns with image embeddings.
    list: List of columns with class labels.

    Example:
    train_df, test_df, text_columns, image_columns, label_columns = train_test_split_and_feature_extraction(df)
    """
    # Split the data into train and test sets using the test_size and random_state parameters
    train_df, test_df = train_test_split(
        df, test_size=test_size, random_state=random_state
    )

    # Only string labels can carry a prefix; skipping the others avoids an
    # AttributeError on frames that still have integer column labels.
    string_columns = [col for col in df.columns if isinstance(col, str)]

    # Names of the columns with the text embeddings, as a list (even if there is only one column).
    # NOTE(review): selection is by prefix only, so any other column starting with
    # "text_" / "image_" is returned as well - confirm df holds no such extra columns.
    text_columns = [col for col in string_columns if col.startswith("text_")]

    # Names of the columns with the image embeddings, as a list (even if there is only one column)
    image_columns = [col for col in string_columns if col.startswith("image_")]

    # Name of the column with the class labels, as a list (even if there is only one column)
    label_columns = [label_column]

    return train_df, test_df, text_columns, image_columns, label_columns