Commit 9470ff7
Parent(s): d7c8166

chore: Add source code for training

Files changed:
- requirements_train.txt +15 -0
- src/__init__.py +0 -0
- src/classifiers_classic_ml.py +298 -0
- src/classifiers_mlp.py +522 -0
- src/nlp_models.py +242 -0
- src/utils.py +227 -0
- src/vision_embeddings_tf.py +470 -0
requirements_train.txt
ADDED
@@ -0,0 +1,15 @@
pandas~=1.5.0
numpy~=1.23.3
pillow==10.4.0
requests==2.26.0
matplotlib==3.4.2
seaborn==0.13.2
plotly==5.23.0
pytest==8.3.3
scikit-learn==0.24.2
torch==2.0.0
tensorflow==2.10.0
transformers==4.44.2
openai==1.37.0
python-dotenv==1.0.1
tensorflow-gpu==2.10.0
src/__init__.py
ADDED
File without changes
src/classifiers_classic_ml.py
ADDED
@@ -0,0 +1,298 @@
import warnings
from itertools import cycle

import matplotlib

# 💬 NOTE: Handle plotting issues when running tests or displaying in notebooks
try:
    get_ipython  # Only exists in Jupyter
    matplotlib.use("module://matplotlib_inline.backend_inline")
except Exception:
    matplotlib.use("Agg")  # Fix error with tests

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.metrics import (
    accuracy_score,
    auc,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_curve,
)

warnings.filterwarnings("ignore")


def visualize_embeddings(
    X_train, X_test, y_train, y_test, plot_type="2D", method="PCA"
):
    """
    Visualizes high-dimensional embeddings (e.g., text or image embeddings) using dimensionality
    reduction techniques (PCA or t-SNE) and plots the results in 2D or 3D using Plotly for
    interactive visualizations.

    Args:
        X_train (np.ndarray): Training data embeddings of shape (n_samples, n_features).
        X_test (np.ndarray): Test data embeddings of shape (n_samples, n_features).
        y_train (np.ndarray): True labels for the training data.
        y_test (np.ndarray): True labels for the test data.
        plot_type (str, optional): Type of plot to generate, either '2D' or '3D'. Default is '2D'.
        method (str, optional): Dimensionality reduction method to use, either 'PCA' or 't-SNE'. Default is 'PCA'.

    Returns:
        PCA or TSNE: The fitted dimensionality-reduction instance used for the plot.

    Side Effects:
        - Displays an interactive 2D or 3D scatter plot of the reduced embeddings, with points colored by their class labels.

    Notes:
        - PCA is a linear dimensionality reduction method, while t-SNE is non-linear and captures more complex relationships.
        - Perplexity is set to 10 for t-SNE. It can be tuned if necessary for better visualization of data clusters.
        - The function raises a `ValueError` if an invalid method is specified.
        - The function uses Plotly to display interactive plots.

    Example:
        visualize_embeddings(X_train, X_test, y_train, y_test, plot_type='3D', method='t-SNE')

    Visualization Details:
        - For 3D visualization, the reduced embeddings are plotted in a 3D scatter plot, with axes labeled as 'col1', 'col2', and 'col3'.
        - For 2D visualization, the embeddings are plotted in a 2D scatter plot, with axes labeled as 'col1' and 'col2'.
        - Class labels are represented by different colors in the scatter plots.
    """
    perplexity = 10

    if plot_type == "3D":
        if method == "PCA":
            # Create an instance of PCA for 3D visualization and fit it on the training data
            red = PCA(n_components=3)
            red.fit(X_train)

            # Use the trained model to transform the test data
            reduced_embeddings = red.transform(X_test)
        elif method == "t-SNE":
            # Implement t-SNE for 3D visualization
            red = TSNE(
                n_components=3, perplexity=perplexity, random_state=42, init="pca"
            )

            # Fit t-SNE on the test data and transform it
            reduced_embeddings = red.fit_transform(X_test)
        else:
            raise ValueError("Invalid method. Please choose either 'PCA' or 't-SNE'.")

        df_reduced = pd.DataFrame(reduced_embeddings, columns=["col1", "col2", "col3"])
        df_reduced["Class"] = y_test

        # 3D scatter plot
        fig = px.scatter_3d(
            df_reduced, x="col1", y="col2", z="col3", color="Class", title="3D"
        )

    else:  # 2D
        if method == "PCA":
            # Create an instance of PCA for 2D visualization and fit it on the training data
            red = PCA(n_components=2)
            red.fit(X_train)

            # Use the trained model to transform the test data
            reduced_embeddings = red.transform(X_test)
        elif method == "t-SNE":
            # Implement t-SNE for 2D visualization
            red = TSNE(
                n_components=2, perplexity=perplexity, random_state=42, init="pca"
            )

            # Fit t-SNE on the test data and transform it
            reduced_embeddings = red.fit_transform(X_test)
        else:
            raise ValueError("Invalid method. Please choose either 'PCA' or 't-SNE'.")

        df_reduced = pd.DataFrame(reduced_embeddings, columns=["col1", "col2"])
        df_reduced["Class"] = y_test

        # 2D scatter plot
        fig = px.scatter(df_reduced, x="col1", y="col2", color="Class", title="2D")

    fig.update_layout(
        title=f"Embeddings - {method} {plot_type} Visualization", scene=dict()
    )

    fig.show()

    return red


def test_model(X_test, y_test, model):
    """
    Evaluates a trained model on a test set by computing key performance metrics and visualizing the results.

    The function generates a confusion matrix, plots ROC curves (for binary or multi-class classification),
    and prints the classification report. It also computes overall accuracy, weighted precision, weighted recall,
    and weighted F1-score for the test data.

    Args:
        X_test (np.ndarray): Test set feature data.
        y_test (np.ndarray): True labels for the test set.
        model (sklearn-like model): A trained machine learning model with `predict` and `predict_proba` methods.

    Returns:
        tuple:
            - accuracy (float): Overall accuracy of the model on the test set.
            - precision (float): Weighted precision score across all classes.
            - recall (float): Weighted recall score across all classes.
            - f1 (float): Weighted F1-score across all classes.

    Side Effects:
        - Displays a confusion matrix as a heatmap.
        - Plots ROC curves for binary or multi-class classification.
        - Prints the classification report with precision, recall, F1-score, and support for each class.

    Example:
        accuracy, precision, recall, f1 = test_model(X_test, y_test, trained_model)

    Notes:
        - If `y_test` is multi-dimensional (e.g., one-hot encoded), it will be squeezed to 1D.
        - For binary classification, a single ROC curve is plotted. For multi-class classification,
          an ROC curve is plotted for each class with a unique color.
        - Weighted precision, recall, and F1-score are computed to handle class imbalance in multi-class classification.
    """
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)
    y_test = y_test.squeeze() if y_test.ndim > 1 else y_test

    # Confusion matrix
    cm = confusion_matrix(y_test, y_pred)
    plt.figure(figsize=(10, 5))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    # ROC curve
    fig, ax = plt.subplots(figsize=(6, 6))

    # Binary classification
    if y_pred_proba.shape[1] == 2:
        fpr, tpr, _ = roc_curve(y_test, y_pred_proba[:, 1])
        ax.plot(
            fpr,
            tpr,
            color="aqua",
            lw=2,
            label=f"ROC curve (area = {auc(fpr, tpr):.2f})",
        )
        ax.plot([0, 1], [0, 1], "k--", label="Chance level (AUC = 0.5)")
    # Multiclass classification
    else:
        y_onehot_test = pd.get_dummies(y_test).values
        colors = cycle(
            [
                "aqua",
                "darkorange",
                "cornflowerblue",
                "red",
                "green",
                "yellow",
                "purple",
                "pink",
                "brown",
                "black",
            ]
        )

        for class_id, color in zip(range(y_onehot_test.shape[1]), colors):
            fpr, tpr, _ = roc_curve(
                y_onehot_test[:, class_id], y_pred_proba[:, class_id]
            )
            ax.plot(
                fpr,
                tpr,
                color=color,
                lw=2,
                label=f"ROC curve for class {class_id} (area = {auc(fpr, tpr):.2f})",
            )

        ax.plot([0, 1], [0, 1], "k--", label="Chance level (AUC = 0.5)")
    ax.set_axisbelow(True)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend(loc="lower right")
    plt.show()

    cr = classification_report(y_test, y_pred)
    print(cr)

    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average="weighted")
    recall = recall_score(y_test, y_pred, average="weighted")
    f1 = f1_score(y_test, y_pred, average="weighted")

    return accuracy, precision, recall, f1


def train_and_evaluate_model(X_train, X_test, y_train, y_test, models=None, test=True):
    """
    Trains and evaluates multiple machine learning models on a given dataset, then visualizes the data embeddings
    using PCA before training. This function trains each model on the training data, evaluates them on the test data,
    and computes performance metrics (accuracy, precision, recall, and F1-score).

    Args:
        X_train (np.ndarray): Feature matrix for the training data.
        X_test (np.ndarray): Feature matrix for the test data.
        y_train (np.ndarray): True labels for the training data.
        y_test (np.ndarray): True labels for the test data.
        models (list of tuples, optional): A list of tuples, where each tuple contains the model name as a string and
                                           the corresponding scikit-learn model instance.
                                           If None, default models are Random Forest and Logistic Regression.
        test (bool, optional): Whether to evaluate each trained model on the test set. Default is True.

    Returns:
        list: A list of trained model tuples, where each tuple contains the model name and the trained model instance.

    Side Effects:
        - Displays a PCA 2D visualization of the embeddings using the `visualize_embeddings` function.
        - Trains each model on the training set.
        - Prints evaluation metrics (accuracy, precision, recall, F1-score) for each model on the test set.
        - Displays confusion matrix and ROC curve for each model using the `test_model` function.

    Example:
        models = train_and_evaluate_model(X_train, X_test, y_train, y_test)

    Notes:
        - The `models` argument can be customized to include any classification models from scikit-learn.
        - The function uses PCA for the embedding visualization. You can modify the `visualize_embeddings` function call for other visualization methods or dimensionality reduction techniques.
        - Default models are Random Forest and Logistic Regression.
    """

    visualize_embeddings(X_train, X_test, y_train, y_test, plot_type="2D", method="PCA")

    if not models:
        # Define the default ML models
        models = [
            (
                "Random Forest",
                RandomForestClassifier(n_estimators=100, random_state=42),
            ),
            ("Logistic Regression", LogisticRegression(max_iter=1000, random_state=42)),
        ]

    for name, model in models:
        print("#" * 20, f" {name} ", "#" * 20)
        # Train the model on the training set
        model.fit(X_train, y_train)

        # Evaluate the model on the test set using the test_model function
        if test:
            accuracy, precision, recall, f1 = test_model(X_test, y_test, model)

    return models
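
A minimal usage sketch for the module above, assuming precomputed embeddings (replaced here by synthetic scikit-learn data) and that the repository root is on PYTHONPATH; the dataset sizes are illustrative:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from src.classifiers_classic_ml import train_and_evaluate_model

# Synthetic stand-in for precomputed text/image embeddings (illustrative shapes)
X, y = make_classification(
    n_samples=500, n_features=64, n_informative=10, n_classes=3, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Shows the PCA projection, then trains and evaluates the default models
trained_models = train_and_evaluate_model(X_train, X_test, y_train, y_test)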
src/classifiers_mlp.py
ADDED
@@ -0,0 +1,522 @@
import os
from itertools import cycle

import matplotlib
import tensorflow as tf

# 💬 NOTE: Handle plotting issues when running tests or displaying in notebooks
try:
    get_ipython  # Only exists in Jupyter
    matplotlib.use("module://matplotlib_inline.backend_inline")
except Exception:
    matplotlib.use("Agg")  # Fix error with tests

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import Input, Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import BatchNormalization, Concatenate, Dense, Dropout
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.utils import Sequence


class MultimodalDataset(Sequence):
    """
    Custom Keras Dataset class for multimodal data handling, designed for models that
    take both text and image data as inputs. It facilitates batching and shuffling
    of data for efficient training in Keras models.

    This class supports loading and batching multimodal data (text and images), as well as handling
    label encoding. It is compatible with Keras and can be used to train models that require both
    text and image inputs. It also supports optional shuffling at the end of each epoch for better
    training performance.

    Args:
        df (pd.DataFrame): The DataFrame containing the dataset with text, image, and label columns.
        text_cols (list): List of column names corresponding to text data. Can be a single column or multiple columns.
        image_cols (list): List of column names corresponding to image data (usually file paths or image pixel data).
        label_col (str): Column name corresponding to the target labels.
        encoder (LabelEncoder, optional): A pre-fitted LabelEncoder instance for encoding the labels.
                                          If None, a new LabelEncoder is fitted based on the provided data.
        batch_size (int, optional): Number of samples per batch. Default is 32.
        shuffle (bool, optional): Whether to shuffle the dataset at the end of each epoch. Default is True.

    Attributes:
        text_data (np.ndarray): Array of text data from the DataFrame. None if `text_cols` is not provided.
        image_data (np.ndarray): Array of image data from the DataFrame. None if `image_cols` is not provided.
        labels (np.ndarray): One-hot encoded labels corresponding to the dataset's classes.
        encoder (LabelEncoder): Fitted LabelEncoder used to encode target labels.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Flag indicating whether to shuffle the data after each epoch.
        indices (np.ndarray): Array of indices representing the dataset. Used for shuffling batches.

    Methods:
        __len__():
            Returns the number of batches per epoch based on the dataset size and batch size.

        __getitem__(idx):
            Retrieves a single batch of data, including both text and image inputs and the corresponding labels.
            The method returns a tuple in the format ({'text': text_batch, 'image': image_batch}, label_batch),
            where 'text' and 'image' are only included if their respective columns were provided.

        on_epoch_end():
            Updates the index order after each epoch, shuffling if needed.
    """

    def __init__(
        self,
        df,
        text_cols,
        image_cols,
        label_col,
        encoder=None,
        batch_size=32,
        shuffle=True,
    ):
        """
        Initializes the MultimodalDataset object.

        Args:
            df (pd.DataFrame): The dataset as a DataFrame, containing text, image, and label data.
            text_cols (list): List of column names representing text features.
            image_cols (list): List of column names representing image features (e.g., file paths or pixel data).
            label_col (str): Column name corresponding to the target labels.
            encoder (LabelEncoder, optional): LabelEncoder for encoding the target labels. If None, a new LabelEncoder will be created.
            batch_size (int, optional): Batch size for loading data. Default is 32.
            shuffle (bool, optional): Whether to shuffle the data at the end of each epoch. Default is True.

        Raises:
            ValueError: If both text_cols and image_cols are None or empty.
        """
        if text_cols:
            # Get the text data from the DataFrame as a NumPy array
            self.text_data = df[text_cols].astype(np.float32).values
        else:
            # Else, set text data to None
            self.text_data = None

        if image_cols:
            # Get the image data from the DataFrame as a NumPy array
            self.image_data = df[image_cols].astype(np.float32).values
        else:
            # Else, set image data to None
            self.image_data = None

        if not text_cols and not image_cols:
            raise ValueError(
                "At least one of text_cols or image_cols must be provided."
            )

        # Get the labels from the DataFrame and encode them
        self.labels = df[label_col].values

        # Use provided encoder or fit a new one
        if encoder is None:
            self.encoder = LabelEncoder()
            self.labels = self.encoder.fit_transform(self.labels)
        else:
            self.encoder = encoder
            self.labels = self.encoder.transform(self.labels)

        # One-hot encode labels for multi-class classification
        num_classes = len(self.encoder.classes_)
        self.labels = np.eye(num_classes)[self.labels]

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """
        Returns the number of batches per epoch based on the dataset size and batch size.

        Returns:
            int: The number of batches per epoch.
        """
        return int(np.floor(len(self.labels) / self.batch_size))

    def __getitem__(self, idx):
        """
        Retrieves a single batch of data (text and/or image) and the corresponding labels.

        Args:
            idx (int): Index of the batch to retrieve.

        Returns:
            tuple: A tuple containing the batch of text and/or image inputs and the corresponding labels.
                   The input data is returned as a dictionary with keys 'text' and 'image', depending on the provided columns.
                   If no text or image columns were provided, only the other is returned.
        """
        indices = self.indices[idx * self.batch_size : (idx + 1) * self.batch_size]

        if self.text_data is not None:
            text_batch = self.text_data[indices]
        if self.image_data is not None:
            image_batch = self.image_data[indices]
        label_batch = self.labels[indices]

        if self.text_data is None:
            return {"image": image_batch}, label_batch
        elif self.image_data is None:
            return {"text": text_batch}, label_batch
        else:
            return {"text": text_batch, "image": image_batch}, label_batch

    def on_epoch_end(self):
        """
        Updates the index order after each epoch, shuffling the data if needed.

        This method is called at the end of each epoch and will shuffle the data if the `shuffle` flag is set to True.
        """
        self.indices = np.arange(len(self.labels))
        if self.shuffle:
            np.random.shuffle(self.indices)


# Early Fusion Model
def create_early_fusion_model(
    text_input_size, image_input_size, output_size, hidden=[128], p=0.2
):
    """
    Creates a multimodal early fusion model combining text and image inputs. The model concatenates the text and
    image features, passes them through fully connected layers with optional dropout and batch normalization,
    and produces a multi-class classification output.

    Args:
        text_input_size (int): Size of the input vector for the text data.
        image_input_size (int): Size of the input vector for the image data.
        output_size (int): Number of classes for the output layer (i.e., size of the softmax output).
        hidden (int or list, optional): Specifies the number of hidden units in the dense layers.
                                        If an integer, a single dense layer with the specified units is created.
                                        If a list, multiple dense layers are created with the respective units. Default is [128].
        p (float, optional): Dropout rate to apply after each dense layer. Default is 0.2.

    Returns:
        Model (keras.Model): An uncompiled Keras model with text and/or image inputs and a softmax output for classification.

    Model Architecture:
        - The model accepts two inputs: one for text features and one for image features.
        - The features are concatenated into a single vector.
        - Dense layers with ReLU activation are applied, each followed by dropout and batch normalization.
        - The output layer uses a softmax activation for multi-class classification.

    Example:
        model = create_early_fusion_model(text_input_size=300, image_input_size=2048, output_size=10, hidden=[128, 64], p=0.3)
        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    """

    if text_input_size is None and image_input_size is None:
        raise ValueError(
            "At least one of text_input_size and image_input_size must be provided."
        )

    # Define inputs
    if text_input_size is not None:
        # Define the input layer for the text data
        text_input = Input(shape=(text_input_size,), name="text")
    if image_input_size is not None:
        # Define the input layer for the image data
        image_input = Input(shape=(image_input_size,), name="image")

    # Merge or select inputs
    if text_input_size is not None and image_input_size is not None:
        # Concatenate text and image inputs if both are provided
        x = Concatenate(name="fusion_layer")([text_input, image_input])
    elif text_input_size is not None:
        x = text_input
    elif image_input_size is not None:
        x = image_input

    # Hidden layers
    if isinstance(hidden, int):
        # Add a single dense layer with activation, dropout and normalization
        x = Dense(hidden, activation="relu")(x)
        x = Dropout(p)(x)
        x = BatchNormalization()(x)
    elif isinstance(hidden, list):
        for h in hidden:
            # Add multiple dense layers based on the hidden list, each with activation, dropout and normalization
            x = Dense(h, activation="relu")(x)
            x = Dropout(p)(x)
            x = BatchNormalization()(x)

    # Output layer with softmax activation
    output = Dense(output_size, activation="softmax", name="output")(x)

    # Create the model
    if text_input_size is not None and image_input_size is not None:
        # Define the model with both text and image inputs
        model = Model(inputs=[text_input, image_input], outputs=output)
    elif text_input_size is not None:
        # Define the model with only text input
        model = Model(inputs=text_input, outputs=output)
    elif image_input_size is not None:
        # Define the model with only image input
        model = Model(inputs=image_input, outputs=output)
    else:
        raise ValueError(
            "At least one of text_input_size and image_input_size must be provided."
        )

    return model


def test_model(y_test, y_pred, y_prob=None, encoder=None):
    """
    Evaluates a trained model's performance using various metrics such as accuracy, precision, recall, F1-score,
    and visualizations including a confusion matrix and ROC curves.

    Args:
        y_test (np.ndarray): Ground truth one-hot encoded labels for the test data.
        y_pred (np.ndarray): Predicted class labels by the model for the test data (after argmax transformation).
        y_prob (np.ndarray, optional): Predicted probabilities for each class from the model. Required for ROC curves. Default is None.
        encoder (LabelEncoder, optional): A fitted LabelEncoder instance used to inverse transform one-hot encoded and predicted labels to their original categorical form.

    Returns:
        accuracy (float): Accuracy score of the model on the test data.
        precision (float): Weighted precision score of the model on the test data.
        recall (float): Weighted recall score of the model on the test data.
        f1 (float): Weighted F1 score of the model on the test data.

    This function performs the following steps:
        - Inverse transforms the one-hot encoded `y_test` and predicted `y_pred` values to their original labels using the provided LabelEncoder.
        - Computes the confusion matrix and plots it as a heatmap using Seaborn.
        - If `y_prob` is provided, computes and plots the ROC curves for each class.
        - Prints the classification report, which includes precision, recall, F1-score, and support for each class.
        - Returns the overall accuracy, weighted precision, recall, and F1-score of the model.

    Visualizations:
        - Confusion Matrix: A heatmap of the confusion matrix comparing the true labels with the predicted labels.
        - ROC Curves: Plots ROC curves for each class if predicted probabilities are provided (`y_prob`).

    Example:
        accuracy, precision, recall, f1 = test_model(y_test, y_pred, y_prob, encoder)
    """
    # Handle label decoding
    y_test_binarized = y_test
    y_test = encoder.inverse_transform(np.argmax(y_test, axis=1))
    y_pred = encoder.inverse_transform(y_pred)

    cm = confusion_matrix(y_test, y_pred)
    fig, ax = plt.subplots(figsize=(15, 15))
    sns.heatmap(cm, annot=True, cmap="Blues", fmt="g", ax=ax)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    if y_prob is not None:
        fig, ax = plt.subplots(figsize=(15, 15))

        colors = cycle(["aqua", "darkorange", "cornflowerblue"])

        for i, color in zip(range(y_prob.shape[1]), colors):
            fpr, tpr, _ = roc_curve(y_test_binarized[:, i], y_prob[:, i])
            ax.plot(fpr, tpr, color=color, lw=2, label=f"Class {i}")

        ax.plot([0, 1], [0, 1], "k--")
        plt.title("ROC Curve")
        plt.ylabel("True Positive Rate")
        plt.xlabel("False Positive Rate")
        plt.legend()
        plt.show()

    cr = classification_report(y_test, y_pred)
    print(cr)

    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average="weighted")
    recall = recall_score(y_test, y_pred, average="weighted")
    f1 = f1_score(y_test, y_pred, average="weighted")

    return accuracy, precision, recall, f1


def train_mlp(
    train_loader,
    test_loader,
    text_input_size,
    image_input_size,
    output_size,
    num_epochs=50,
    report=False,
    lr=0.001,
    set_weights=True,
    adam=False,
    p=0.0,
    seed=1,
    patience=40,
    save_results=True,
    train_model=True,
    test_mlp_model=True,
):
    """
    Trains a multimodal early fusion model using both text and image data.

    The function handles the training process of the model by combining text and image features,
    computes class weights if needed, applies an optimizer (SGD or Adam), and implements early stopping
    to prevent overfitting. The model is evaluated on the test set, and key performance metrics are computed.

    Args:
        train_loader (MultimodalDataset): Keras-compatible data loader for the training set with both text and image data.
        test_loader (MultimodalDataset): Keras-compatible data loader for the test set with both text and image data.
        text_input_size (int): The size of the input vector for the text data.
        image_input_size (int): The size of the input vector for the image data.
        output_size (int): Number of output classes for the softmax layer.
        num_epochs (int, optional): Number of training epochs. Default is 50.
        report (bool, optional): Whether to generate a detailed classification report and display metrics. Default is False.
        lr (float, optional): Learning rate for the optimizer. Default is 0.001.
        set_weights (bool, optional): Whether to compute and apply class weights to handle imbalanced datasets. Default is True.
        adam (bool, optional): Whether to use the Adam optimizer instead of SGD. Default is False.
        p (float, optional): Dropout rate for regularization in the model. Default is 0.0.
        seed (int, optional): Seed for random number generators to ensure reproducibility. Default is 1.
        patience (int, optional): Number of epochs with no improvement on validation loss before early stopping. Default is 40.
        save_results (bool, optional): Whether to save test-set predictions to a CSV file in the results folder. Default is True.
        train_model (bool, optional): Whether to run the training loop. Default is True.
        test_mlp_model (bool, optional): Whether to evaluate the model on the test set. Default is True.

    Returns:
        tuple: (model, test_accuracy, f1, macro_auc) — the trained Keras model together with the test-set
               accuracy, macro F1-score, and macro one-vs-rest AUC. The metrics are None if
               `test_mlp_model` is False.

    Side Effects:
        - Trains the early fusion model and restores the best weights based on validation loss.
        - Generates plots showing the training and validation accuracy over epochs.
        - If `report` is True, calls `test_model` to print detailed evaluation metrics and plots.

    Training Process:
        - The function creates a fusion model combining text and image inputs.
        - Class weights are computed to balance the dataset if `set_weights` is True.
        - The model is trained using categorical cross-entropy loss and the chosen optimizer (Adam or SGD).
        - Early stopping is applied based on validation loss to prevent overfitting.
        - After training, the model is evaluated on the test set, and accuracy, F1-score, and AUC are calculated.

    Example:
        train_mlp(train_loader, test_loader, text_input_size=300, image_input_size=2048, output_size=10, num_epochs=30, lr=0.001, adam=True, report=True)

    Notes:
        - `train_loader` and `test_loader` should be instances of `MultimodalDataset` or compatible Keras data loaders.
        - If the dataset is imbalanced, setting `set_weights=True` is recommended to ensure better model performance on minority classes.
    """

    if seed is not None:
        np.random.seed(seed)
        tf.random.set_seed(seed)

    # Create an early fusion model using the provided input sizes, output size and dropout rate
    model = create_early_fusion_model(
        text_input_size, image_input_size, output_size, p=p
    )

    # Compute class weights for imbalanced datasets
    class_weights = None
    if set_weights:
        class_indices = np.argmax(train_loader.labels, axis=1)
        # Compute class weights using the training labels
        weights = compute_class_weight(
            class_weight="balanced",
            classes=np.unique(class_indices),
            y=class_indices,
        )
        class_weights = {i: w for i, w in enumerate(weights)}

    # Choose the loss function for multi-class classification
    loss = CategoricalCrossentropy()

    # Choose the optimizer
    if adam:
        # Use the Adam optimizer with the specified learning rate
        optimizer = Adam(learning_rate=lr)
    else:
        # Use the SGD optimizer with the specified learning rate
        optimizer = SGD(learning_rate=lr)

    # Compile the model with the chosen optimizer and loss function
    model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

    # Define an early stopping callback with the specified patience
    early_stopping = EarlyStopping(
        monitor="val_loss",
        patience=patience,
        restore_best_weights=True,
    )

    # Train the model using the training data and validation data
    history = None
    if train_model:
        history = model.fit(
            train_loader,
            validation_data=test_loader,
            epochs=num_epochs,
            class_weight=class_weights,
            callbacks=[early_stopping],
            verbose=1,
        )

    if test_mlp_model:
        # Test the model on the test set
        y_true, y_pred, y_prob = [], [], []
        for batch in test_loader:
            features, labels = batch
            if len(features) == 1:
                text = features["text"] if "text" in features else features["image"]
                preds = model.predict(text)
            else:
                text, image = features["text"], features["image"]
                preds = model.predict([text, image])
            y_true.extend(labels)
            y_pred.extend(np.argmax(preds, axis=1))
            y_prob.extend(preds)

        y_true, y_pred, y_prob = np.array(y_true), np.array(y_pred), np.array(y_prob)

        test_accuracy = accuracy_score(np.argmax(y_true, axis=1), y_pred)
        f1 = f1_score(np.argmax(y_true, axis=1), y_pred, average="macro")

        macro_auc = roc_auc_score(y_true, y_prob, average="macro", multi_class="ovr")

        # Plot the training history (only available if the model was trained in this call)
        if history is not None:
            plt.plot(history.history["accuracy"], label="Train Accuracy")
            plt.plot(history.history["val_accuracy"], label="Validation Accuracy")
            plt.xlabel("Epoch")
            plt.ylabel("Accuracy")
            plt.legend()
            plt.show()

        if report:
            test_model(y_true, y_pred, y_prob, encoder=train_loader.encoder)

        # Store results in a dataframe and save in the results folder
        if text_input_size is not None and image_input_size is not None:
            model_type = "multimodal"
        elif text_input_size is not None:
            model_type = "text"
        elif image_input_size is not None:
            model_type = "image"

        if save_results:
            results = pd.DataFrame(
                {"Predictions": y_pred, "True Labels": np.argmax(y_true, axis=1)}
            )
            # Create the results folder if it does not exist
            os.makedirs("results", exist_ok=True)
            results.to_csv(f"results/{model_type}_results.csv", index=False)
    else:
        test_accuracy, f1, macro_auc = None, None, None

    return model, test_accuracy, f1, macro_auc
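
A minimal end-to-end sketch for this module, assuming a DataFrame whose columns already hold precomputed text and image embedding values; the column names, dimensions, and label values below are illustrative, and the repository root is assumed to be on PYTHONPATH:

import numpy as np
import pandas as pd

from src.classifiers_mlp import MultimodalDataset, train_mlp

# Illustrative data: 256-dim text and 512-dim image embeddings, 3 classes
rng = np.random.default_rng(0)
n, text_dim, image_dim = 640, 256, 512
features = rng.normal(size=(n, text_dim + image_dim)).astype(np.float32)
text_cols = [f"t{i}" for i in range(text_dim)]
image_cols = [f"i{i}" for i in range(image_dim)]
df = pd.DataFrame(features, columns=text_cols + image_cols)
df["label"] = rng.choice(["a", "b", "c"], size=n)

train_df, test_df = df.iloc[:512], df.iloc[512:]

# The test loader reuses the training encoder so label indices stay consistent
train_loader = MultimodalDataset(train_df, text_cols, image_cols, "label")
test_loader = MultimodalDataset(
    test_df, text_cols, image_cols, "label",
    encoder=train_loader.encoder, shuffle=False,
)

model, acc, f1, macro_auc = train_mlp(
    train_loader, test_loader,
    text_input_size=text_dim, image_input_size=image_dim, output_size=3,
    num_epochs=5, adam=True, save_results=False,
)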
src/nlp_models.py
ADDED
@@ -0,0 +1,242 @@
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoModel, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HuggingFaceEmbeddings:
|
| 11 |
+
"""
|
| 12 |
+
A class to handle text embedding generation using a Hugging Face pre-trained transformer model.
|
| 13 |
+
This class loads the model, tokenizes the input text, generates embeddings, and provides an option
|
| 14 |
+
to save the embeddings to a CSV file.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
model_name (str, optional): The name of the Hugging Face pre-trained model to use for generating embeddings.
|
| 18 |
+
Default is 'sentence-transformers/all-MiniLM-L6-v2'.
|
| 19 |
+
path (str, optional): The path to the CSV file containing the text data. Default is 'data/file.csv'.
|
| 20 |
+
save_path (str, optional): The directory path where the embeddings will be saved. Default is 'Models'.
|
| 21 |
+
device (str, optional): The device to run the model on ('cpu' or 'cuda'). If None, it will automatically detect
|
| 22 |
+
a GPU if available; otherwise, it defaults to CPU.
|
| 23 |
+
|
| 24 |
+
Attributes:
|
| 25 |
+
model_name (str): The name of the Hugging Face model used for embedding generation.
|
| 26 |
+
tokenizer (transformers.AutoTokenizer): The tokenizer corresponding to the chosen model.
|
| 27 |
+
model (transformers.AutoModel): The pre-trained model loaded for embedding generation.
|
| 28 |
+
path (str): Path to the input CSV file.
|
| 29 |
+
save_path (str): Directory where the embeddings CSV will be saved.
|
| 30 |
+
device (torch.device): The device on which the model and data are processed (CPU or GPU).
|
| 31 |
+
|
| 32 |
+
Methods:
|
| 33 |
+
get_embedding(text):
|
| 34 |
+
Generates embeddings for a given text input using the pre-trained model.
|
| 35 |
+
|
| 36 |
+
get_embedding_df(column, directory, file):
|
| 37 |
+
Reads a CSV file, computes embeddings for a specified text column, and saves the resulting DataFrame
|
| 38 |
+
with embeddings to a new CSV file in the specified directory.
|
| 39 |
+
|
| 40 |
+
Example:
|
| 41 |
+
embedding_instance = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
|
| 42 |
+
path='data/products.csv', save_path='output')
|
| 43 |
+
text_embedding = embedding_instance.get_embedding("Sample product description.")
|
| 44 |
+
embedding_instance.get_embedding_df(column='description', directory='output', file='product_embeddings.csv')
|
| 45 |
+
|
| 46 |
+
Notes:
|
| 47 |
+
- The Hugging Face model and tokenizer are downloaded from the Hugging Face hub.
|
| 48 |
+
- The function supports large models and can run on either GPU or CPU, depending on device availability.
|
| 49 |
+
- The input text will be truncated and padded to a maximum length of 512 tokens to fit into the model.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 55 |
+
path="data/file.csv",
|
| 56 |
+
save_path=None,
|
| 57 |
+
device=None,
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
Initializes the HuggingFaceEmbeddings class with the specified model and paths.
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
model_name (str, optional): The name of the Hugging Face pre-trained model. Default is 'sentence-transformers/all-MiniLM-L6-v2'.
|
| 64 |
+
path (str, optional): The path to the CSV file containing text data. Default is 'data/file.csv'.
|
| 65 |
+
save_path (str, optional): Directory path where the embeddings will be saved. Default is 'Models'.
|
| 66 |
+
device (str, optional): Device to use for model processing. Defaults to 'cuda' if available, otherwise 'cpu'.
|
| 67 |
+
"""
|
| 68 |
+
self.model_name = model_name
|
| 69 |
+
# Load the Hugging Face tokenizer from a pre-trained model
|
| 70 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 71 |
+
|
| 72 |
+
# Load the model from the Hugging Face model hub from the specified model name
|
| 73 |
+
self.model = AutoModel.from_pretrained(model_name)
|
| 74 |
+
self.path = path
|
| 75 |
+
self.save_path = save_path or "Models"
|
| 76 |
+
|
| 77 |
+
# Define device
|
| 78 |
+
if device is None:
|
| 79 |
+
# Note: If you have a mac, you may want to change 'cuda' to 'mps' to use GPU
|
| 80 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
+
else:
|
| 82 |
+
self.device = torch.device(device)
|
| 83 |
+
print(f"Using device: {self.device}")
|
| 84 |
+
|
| 85 |
+
# Move model to the specified device
|
| 86 |
+
self.model.to(self.device)
|
| 87 |
+
print(f"Model moved to device: {self.device}")
|
| 88 |
+
print(f"Model: {model_name}")
|
| 89 |
+
|
| 90 |
+
def get_embedding(self, text):
|
| 91 |
+
"""
|
| 92 |
+
Generates embeddings for a given text using the Hugging Face model.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
text (str): The input text for which embeddings will be generated.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
np.ndarray: A numpy array containing the embedding vector for the input text.
|
| 99 |
+
"""
|
| 100 |
+
# Tokenize the input text using the Hugging Face tokenizer
|
| 101 |
+
inputs = self.tokenizer(
|
| 102 |
+
text, return_tensors="pt", truncation=True, padding=True, max_length=512
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Move the inputs to the device
|
| 106 |
+
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
| 107 |
+
|
| 108 |
+
with torch.no_grad():
|
| 109 |
+
# Generate the embeddings using the Hugging Face model from the tokenized input
|
| 110 |
+
outputs = self.model(**inputs)
|
| 111 |
+
|
| 112 |
+
# Extract the embeddings from the model output, send to cpu and return the numpy array
|
| 113 |
+
last_hidden_state = outputs.last_hidden_state
|
| 114 |
+
|
| 115 |
+
embeddings = last_hidden_state.mean(dim=1)
|
| 116 |
+
embeddings = embeddings.cpu().numpy()
|
| 117 |
+
|
| 118 |
+
return embeddings[0]
|
| 119 |
+
|
| 120 |
+
def get_embedding_df(self, column, directory, file):
|
| 121 |
+
# Load the CSV file
|
| 122 |
+
df = pd.read_csv(self.path)
|
| 123 |
+
# Generate embeddings for the specified column using the `get_embedding` method
|
| 124 |
+
df["embeddings"] = df[column].apply(
|
| 125 |
+
lambda x: self.get_embedding(str(x)).tolist() if pd.notnull(x) else None
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
os.makedirs(directory, exist_ok=True)
|
| 129 |
+
|
| 130 |
+
# Save the DataFrame with the embeddings to a new CSV file in the specified directory
|
| 131 |
+
output_path = os.path.join(directory, file)
|
| 132 |
+
df.to_csv(output_path, index=False)
|
| 133 |
+
|
| 134 |
+
print(f"✅ Embeddings saved to {output_path}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class GPT:
|
| 138 |
+
"""
|
| 139 |
+
A class to interact with the OpenAI GPT API for generating text embeddings from a given dataset.
|
| 140 |
+
This class provides methods to retrieve embeddings for text data and save them to a CSV file.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
path (str, optional): The path to the CSV file containing the text data. Default is 'data/file.csv'.
|
| 144 |
+
embedding_model (str, optional): The embedding model to use for generating text embeddings.
|
| 145 |
+
Default is 'text-embedding-3-small'.
|
| 146 |
+
|
| 147 |
+
Attributes:
|
| 148 |
+
path (str): Path to the CSV file.
|
| 149 |
+
embedding_model (str): The embedding model used for generating text embeddings.
|
| 150 |
+
|
| 151 |
+
Methods:
|
| 152 |
+
get_embedding(text):
|
| 153 |
+
Generates and returns the embedding vector for the given text using the OpenAI API.
|
| 154 |
+
|
| 155 |
+
get_embedding_df(column, directory, file):
|
| 156 |
+
Reads a CSV file, computes the embeddings for a specified text column, and saves the embeddings
|
| 157 |
+
to a new CSV file in the specified directory.
|
| 158 |
+
|
| 159 |
+
Example:
|
| 160 |
+
gpt_instance = GPT(path='data/products.csv', embedding_model='text-embedding-ada-002')
|
| 161 |
+
text_embedding = gpt_instance.get_embedding("Sample product description.")
|
| 162 |
+
gpt_instance.get_embedding_df(column='description', directory='output', file='product_embeddings.csv')
|
| 163 |
+
|
| 164 |
+
Notes:
|
| 165 |
+
- The OpenAI API key must be stored in a `.env` file with the variable name `OPENAI_API_KEY`.
|
| 166 |
+
- The OpenAI Python package should be installed (`pip install openai`), and an active OpenAI API key is required.
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
def __init__(self, path="data/file.csv", embedding_model="text-embedding-3-small"):
|
| 170 |
+
"""
|
| 171 |
+
Initializes the GPT class with the provided CSV file path and embedding model.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
path (str, optional): The path to the CSV file containing the text data. Default is 'data/file.csv'.
|
| 175 |
+
embedding_model (str, optional): The embedding model to use for generating text embeddings.
|
| 176 |
+
Default is 'text-embedding-3-small'.
|
| 177 |
+
"""
|
| 178 |
+
import openai
|
| 179 |
+
from dotenv import find_dotenv, load_dotenv
|
| 180 |
+
|
| 181 |
+
# Load the OpenAI API key from the .env file
|
| 182 |
+
_ = load_dotenv(find_dotenv()) # read local .env file
|
| 183 |
+
# Set the OpenAI API key
|
| 184 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
| 185 |
+
|
| 186 |
+
self.path = path
|
| 187 |
+
self.embedding_model = embedding_model
|
| 188 |
+
|
| 189 |
+
def get_embedding(self, text):
|
| 190 |
+
"""
|
| 191 |
+
Generates and returns the embedding vector for the given text using the OpenAI API.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
text (str): The input text to generate the embedding for.
|
| 195 |
+
|
| 196 |
+
Returns:
|
| 197 |
+
list: A list containing the embedding vector for the input text.
|
| 198 |
+
"""
|
| 199 |
+
from openai import OpenAI
|
| 200 |
+
|
| 201 |
+
# Instantiate the OpenAI client
|
| 202 |
+
client = OpenAI()
|
| 203 |
+
|
| 204 |
+
# Optional. Do text preprocessing if needed (e.g., removing newlines)
|
| 205 |
+
text = text.replace("\n", " ").strip()
|
| 206 |
+
|
| 207 |
+
# Call the OpenAI API to generate the embeddings and return only the embedding data
|
| 208 |
+
response = client.embeddings.create(model=self.embedding_model, input=text)
|
| 209 |
+
|
| 210 |
+
embeddings_np = np.array(response.data[0].embedding, dtype=np.float32)
|
| 211 |
+
return embeddings_np
|
| 212 |
+
|
| 213 |
    def get_embedding_df(self, column, directory, file):
        """
        Reads a CSV file, computes the embeddings for a specified text column, and saves the results in a new CSV file.

        Args:
            column (str): The name of the column in the CSV file that contains the text data.
            directory (str): The directory where the output CSV file will be saved.
            file (str): The name of the output CSV file.

        Side Effects:
            - Saves a new CSV file containing the original data along with the computed embeddings to the specified directory.
        """
        # Load the CSV file
        df = pd.read_csv(self.path)

        if column not in df.columns:
            raise ValueError(f"Column '{column}' not found in CSV")

        # Generate embeddings in a new column 'embeddings' for the specified column using the `get_embedding` method
        df["embeddings"] = df[column].apply(
            lambda x: json.dumps(self.get_embedding(str(x)).tolist())
        )

        os.makedirs(directory, exist_ok=True)

        # Save the DataFrame with the embeddings to a new CSV file in the specified directory
        output_path = os.path.join(directory, file)
        df.to_csv(output_path, index=False)

        print(f"✅ Embeddings saved to {output_path}")
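💬 NOTE: A minimal usage sketch for the class above; the CSV path and column name ('data/products.csv', 'description') are placeholders, and the round trip assumes OPENAI_API_KEY is available via `.env` as described in the docstring:

    import json
    import pandas as pd

    gpt = GPT(path="data/products.csv", embedding_model="text-embedding-3-small")
    gpt.get_embedding_df(column="description", directory="output", file="product_embeddings.csv")

    # Embeddings are stored as JSON strings, so parse them when reading the file back
    df = pd.read_csv("output/product_embeddings.csv")
    first_vector = json.loads(df["embeddings"].iloc[0])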
src/utils.py
ADDED
@@ -0,0 +1,227 @@
import os
import warnings
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.model_selection import train_test_split

# 💬 NOTE: Suppress all warnings
warnings.filterwarnings("ignore")
def process_embeddings(df, col_name):
    """
    Process embeddings in a DataFrame column.

    Args:
    - df (pd.DataFrame): The DataFrame containing the embeddings column.
    - col_name (str): The name of the column containing the embeddings.

    Returns:
    pd.DataFrame: The DataFrame with processed embeddings.

    Steps:
    1. Convert the values in the specified column to lists.
    2. Extract values from lists and create new columns for each element.
    3. Remove the original embeddings column.

    Example:
    df_processed = process_embeddings(df, 'embeddings')
    """
    # Convert the values (e.g. "[-0.123, 0.456, ...]") in the column to lists
    df[col_name] = df[col_name].apply(eval)

    # Extract values from lists and create new columns
    """ 🔎 Example
        text_1  text_2  text_3
    0   -0.123   0.456   0.789
    1    0.321  -0.654   0.987
    """
    embeddings_df = pd.DataFrame(
        df[col_name].to_list(),
        columns=[f"text_{i + 1}" for i in range(df[col_name].str.len().max())],
    )
    df = pd.concat([df, embeddings_df], axis=1)

    # Remove the original "embeddings" column
    df = df.drop(columns=[col_name])

    return df
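# 💬 NOTE: A minimal sketch with hypothetical values. This pairs with GPT.get_embedding_df in
# src/nlp_models.py, which stores each vector as a JSON string; JSON lists are also valid
# Python literals, so the `eval` call above parses them back into lists:
#   df = pd.DataFrame({"embeddings": ["[-0.123, 0.456]", "[0.321, -0.654]"]})
#   process_embeddings(df, "embeddings")  # -> columns text_1, text_2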
def rename_image_embeddings(df):
    """
    Rename columns in a DataFrame for image embeddings.

    Args:
    - df (pd.DataFrame): The DataFrame containing columns to be renamed.

    Returns:
    pd.DataFrame: The DataFrame with renamed columns.

    Example:
    df_renamed = rename_image_embeddings(df)
    """
    # From 0 1 2 label ➡️ image_0 image_1 image_2 label
    df.columns = [f"image_{int(col)}" if col.isdigit() else col for col in df.columns]

    return df
def preprocess_data(
    text_data,
    image_data,
    text_id="image_id",
    image_id="ImageName",
    embeddings_col="embeddings",
):
    """
    Preprocess and merge text and image dataframes.

    Args:
    - text_data (pd.DataFrame): DataFrame containing text data.
    - image_data (pd.DataFrame): DataFrame containing image data.
    - text_id (str): Column name for text data identifier.
    - image_id (str): Column name for image data identifier.
    - embeddings_col (str): Column name for embeddings data.

    Returns:
    pd.DataFrame: Merged and preprocessed DataFrame.

    This function:
    Processes text and image embeddings.
    Normalizes text_id values to bare file names.
    Merges the dataframes on the identifier columns.
    Drops unnecessary columns.

    Example:
    merged_df = preprocess_data(text_df, image_df)
    """
    # Call previous functions to tune the text and image dataframes
    text_data = process_embeddings(text_data, embeddings_col)
    image_data = rename_image_embeddings(image_data)

    # Drop missing values in the ID columns - removes rows where the ID (used to join text ↔ image) is missing
    image_data = image_data.dropna(subset=[image_id])
    text_data = text_data.dropna(subset=[text_id])

    # Clean up text IDs: if the column contains file paths (like "data/images/123.jpg"), extract just the file name ("123.jpg")
    text_data[text_id] = text_data[text_id].apply(lambda x: x.split("/")[-1])

    # Merge dataframes - joins text and image embeddings using the IDs (text_id vs image_id)
    df = pd.merge(text_data, image_data, left_on=text_id, right_on=image_id)

    # Drop unnecessary columns - removes the original ID columns since they're no longer needed after the merge
    df.drop([image_id, text_id], axis=1, inplace=True)

    return df
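# 💬 NOTE: A minimal sketch of the expected inputs (hypothetical values):
#   text_df:  an "image_id" column (e.g., "data/images/123.jpg") plus an "embeddings" column
#   image_df: an "ImageName" column (e.g., "123.jpg") plus numeric feature columns 0, 1, 2, ...
#   merged_df = preprocess_data(text_df, image_df)  # -> text_* and image_* feature columns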
class ImageDownloader:
    """
    Image downloader class to download images from URLs.

    Args:
    - image_dir (str): Directory to save images.
    - image_size (tuple): Size of the images to be saved.
    - overwrite (bool): Whether to overwrite existing images.

    Methods:
    - download_images(df, print_every=1000): Download images from URLs in a DataFrame.
        Args:
        - df (pd.DataFrame): DataFrame containing image URLs.
        - print_every (int): Print progress every n images.
        Returns:
        pd.DataFrame: DataFrame with image paths added.

    Example:
    downloader = ImageDownloader()
    df = downloader.download_images(df)
    """

    def __init__(
        self, image_dir="data/images/", image_size=(224, 224), overwrite=False
    ):
        self.image_dir = image_dir
        self.image_size = image_size
        self.overwrite = overwrite

        # Create the directory if it doesn't exist
        if not os.path.exists(self.image_dir):
            os.makedirs(self.image_dir)

    def download_images(self, df, print_every=1000):
        # Bulk-download images from a DataFrame of URLs, resize them to a standard format, and add their local paths back to the DataFrame
        image_paths = []

        i = 0
        for index, row in df.iterrows():
            if i % print_every == 0:
                print(f"Downloading image {i}/{len(df)}")
            i += 1

            sku = row["sku"]
            image_url = row["image"]
            image_path = os.path.join(self.image_dir, f"{sku}.jpg")

            if os.path.exists(image_path) and not self.overwrite:
                print(f"Image {sku} is already in the path.")
                image_paths.append(image_path)
                continue

            try:
                response = requests.get(image_url)
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                img = img.resize(self.image_size, Image.Resampling.LANCZOS)
                img.save(image_path)
                # print(f"Downloaded image for SKU: {sku}")
                image_paths.append(image_path)
            except Exception as e:
                print(f"Could not download image for SKU: {sku}. Error: {e}")
                image_paths.append(np.nan)

        df["image_path"] = image_paths
        return df
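# 💬 NOTE: A minimal sketch (hypothetical URL); download_images expects "sku" and "image" columns:
#   df = pd.DataFrame({"sku": ["A1"], "image": ["https://example.com/a1.jpg"]})
#   df = ImageDownloader(image_dir="data/images/").download_images(df)
#   df["image_path"]  # -> "data/images/A1.jpg", or NaN if the download failed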
def train_test_split_and_feature_extraction(df, test_size=0.3, random_state=42):
    """
    Split the data into train and test sets and extract features and labels.

    Args:
    - df (pd.DataFrame): DataFrame containing the data.

    Keyword Args:
    - test_size (float): Size of the test set.
    - random_state (int): Random state for reproducibility.

    Returns:
    pd.DataFrame: Train DataFrame.
    pd.DataFrame: Test DataFrame.
    list: List of columns with text embeddings.
    list: List of columns with image embeddings.
    list: List of columns with class labels.

    Example:
    train_df, test_df, text_columns, image_columns, label_columns = train_test_split_and_feature_extraction(df)
    """

    # Split the data into train and test sets using the test_size and random_state parameters
    train_df, test_df = train_test_split(
        df, test_size=test_size, random_state=random_state
    )

    # Select the names of the columns with the text embeddings and return them as a list (even if there is only one column)
    text_columns = [col for col in df.columns if col.startswith("text_")]

    # Select the names of the columns with the image embeddings and return them as a list (even if there is only one column)
    image_columns = [col for col in df.columns if col.startswith("image_")]

    # Select the name of the column with the class labels and return it as a list (even if there is only one column)
    label_columns = ["class_id"]

    return train_df, test_df, text_columns, image_columns, label_columns
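💬 NOTE: A minimal end-to-end sketch of how these utilities compose (the CSV paths are placeholders, and the label column is assumed to already be named "class_id"):

    import pandas as pd

    text_df = pd.read_csv("output/product_embeddings.csv")  # produced by GPT.get_embedding_df
    image_df = pd.read_csv("Embeddings/sample/Embeddings_resnet50.csv")  # produced by get_embeddings_df below

    df = preprocess_data(text_df, image_df)
    train_df, test_df, text_cols, image_cols, label_cols = train_test_split_and_feature_extraction(df)
    X_train, y_train = train_df[text_cols + image_cols], train_df[label_cols]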
src/vision_embeddings_tf.py
ADDED
@@ -0,0 +1,470 @@
import os
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications import (
    DenseNet121,
    DenseNet169,
    InceptionV3,
    ResNet50,
    ResNet101,
)
from tensorflow.keras.layers import GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from transformers import TFConvNextV2Model, TFSwinModel, TFViTModel

# 💬 NOTE: Suppress TensorFlow warnings
warnings.filterwarnings("ignore")
tf.get_logger().setLevel("ERROR")
def load_and_preprocess_image(image_path, target_size=(224, 224)):
    """
    Load and preprocess an image.

    Args:
    - image_path (str): Path to the image file.
    - target_size (tuple): Desired image size.

    Returns:
    - np.array: Preprocessed image.
    """
    # Open the image using PIL Image.open and convert it to RGB format
    img = Image.open(image_path).convert("RGB")

    # Resize the image to the target size
    img = img.resize(target_size)

    # Convert the image to a numpy array and scale the pixel values to [0, 1]
    img = np.array(img, dtype=np.float32) / 255.0

    return img
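# 💬 NOTE: A quick sanity check (hypothetical path):
#   img = load_and_preprocess_image("data/images/A1.jpg")
#   img.shape, img.min(), img.max()  # -> (224, 224, 3), values within [0.0, 1.0]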
class FoundationalCVModel:
    """
    A Keras module for loading and using foundational computer vision models.

    This class allows you to load and use various foundational computer vision models for tasks like image classification
    or feature extraction. The user can choose between evaluation mode (non-trainable model) and fine-tuning mode (trainable model).

    Attributes:
    ----------
    backbone_name : str
        The name of the foundational CV model to load (e.g., 'resnet50', 'vit_base').
    model : keras.Model
        The compiled Keras model with the selected backbone.

    Parameters:
    ----------
    backbone : str
        The name of the foundational CV model to load. The available backbones include:
        - ResNet variants: 'resnet50', 'resnet101'
        - DenseNet variants: 'densenet121', 'densenet169'
        - InceptionV3: 'inception_v3'
        - ConvNextV2 variants: 'convnextv2_tiny', 'convnextv2_base', 'convnextv2_large'
        - Swin Transformer variants: 'swin_tiny', 'swin_small', 'swin_base'
        - Vision Transformer (ViT) variants: 'vit_base', 'vit_large'

    mode : str, optional
        The mode of the model, either 'eval' for evaluation or 'fine_tune' for fine-tuning. Default is 'eval'.

    Methods:
    -------
    __init__(self, backbone, mode='eval'):
        Initializes the model with the specified backbone and mode.

    predict(self, images):
        Given a batch of images, performs a forward pass through the model and returns predictions.
        Parameters:
        ----------
        images : numpy.ndarray
            A batch of images to perform prediction on, with shape (batch_size, 224, 224, 3).

        Returns:
        -------
        numpy.ndarray
            Model predictions or extracted features for the provided images.
    """

    def __init__(self, backbone, mode="eval", input_shape=(224, 224, 3)):
        self.backbone_name = backbone

        # Select the backbone from the possible foundational models
        input_layer = Input(shape=input_shape)

        if backbone == "resnet50":
            # Load the ResNet50 model from tensorflow.keras.applications
            self.base_model = ResNet50(
                include_top=False, weights="imagenet", input_tensor=input_layer
            )
        elif backbone == "resnet101":
            # Load the ResNet101 model from tensorflow.keras.applications
            self.base_model = ResNet101(
                include_top=False, weights="imagenet", input_tensor=input_layer
            )
        elif backbone == "densenet121":
            # Load the DenseNet121 model from tensorflow.keras.applications
            self.base_model = DenseNet121(
                include_top=False, weights="imagenet", input_tensor=input_layer
            )
        elif backbone == "densenet169":
            # Load the DenseNet169 model from tensorflow.keras.applications
            self.base_model = DenseNet169(
                include_top=False, weights="imagenet", input_tensor=input_layer
            )
        elif backbone == "inception_v3":
            # Load the InceptionV3 model from tensorflow.keras.applications
            self.base_model = InceptionV3(
                include_top=False, weights="imagenet", input_tensor=input_layer
            )
        elif backbone == "convnextv2_tiny":
            # Load the ConvNeXtV2 Tiny model from transformers
            self.base_model = TFConvNextV2Model.from_pretrained(
                "facebook/convnextv2-tiny-22k-224"
            )
        elif backbone == "convnextv2_base":
            # Load the ConvNeXtV2 Base model from transformers
            self.base_model = TFConvNextV2Model.from_pretrained(
                "facebook/convnextv2-base-22k-224"
            )
        elif backbone == "convnextv2_large":
            # Load the ConvNeXtV2 Large model from transformers
            self.base_model = TFConvNextV2Model.from_pretrained(
                "facebook/convnextv2-large-22k-224"
            )
        elif backbone == "swin_tiny":
            # Load the Swin Transformer Tiny model from transformers
            self.base_model = TFSwinModel.from_pretrained(
                "microsoft/swin-tiny-patch4-window7-224"
            )
        elif backbone == "swin_small":
            # Load the Swin Transformer Small model from transformers
            self.base_model = TFSwinModel.from_pretrained(
                "microsoft/swin-small-patch4-window7-224"
            )
        elif backbone == "swin_base":
            # Load the Swin Transformer Base model from transformers
            self.base_model = TFSwinModel.from_pretrained(
                "microsoft/swin-base-patch4-window7-224"
            )
        elif backbone in ["vit_base", "vit_large"]:
            # Load the Vision Transformer (ViT) model from transformers
            backbone_path = {
                "vit_base": "google/vit-base-patch16-224",
                "vit_large": "google/vit-large-patch16-224",
            }
            self.base_model = TFViTModel.from_pretrained(backbone_path[backbone])
        else:
            raise ValueError(f"Unsupported backbone model: {backbone}")

        if mode == "eval":
            # Set the model to evaluation mode (non-trainable)
            self.base_model.trainable = False
        elif mode == "fine_tune":
            self.base_model.trainable = True

        # 💬 NOTE: Take into account the model's input requirements. Models from transformers expect channels-first
        # input, while models from keras.applications expect channels-last input. Additionally, the output differs
        # in both cases: we need to get the pooled output of the final layer.

        # If it is a model from transformers:
        if backbone in [
            "vit_base",
            "vit_large",
            "convnextv2_tiny",
            "convnextv2_base",
            "convnextv2_large",
            "swin_tiny",
            "swin_small",
            "swin_base",
        ]:
            # Adjust the input for channels-first models within the model
            input_layer_transposed = tf.transpose(input_layer, perm=[0, 3, 1, 2])
            hf_outputs = self.base_model(input_layer_transposed)

            # Get the pooled output of the model ("pooler_output")
            outputs = hf_outputs.pooler_output  # shape (batch_size, hidden_size)
        # If it is a model from keras.applications
        else:
            # The pooling layer is not included in the model, so apply a pooling layer such as GlobalAveragePooling2D
            x = self.base_model.output
            outputs = GlobalAveragePooling2D()(x)

        # Create the final model with the input layer and the pooling output
        self.model = Model(inputs=input_layer, outputs=outputs)

    def get_output_shape(self):
        """
        Get the output shape of the model.

        Returns:
        -------
        tuple
            The shape of the model's output tensor.
        """
        return self.model.output_shape

    def predict(self, images):
        """
        Predict on a batch of images.

        Parameters:
        ----------
        images : numpy.ndarray
            A batch of images of shape (batch_size, 224, 224, 3).

        Returns:
        -------
        numpy.ndarray
            Predictions or features from the model for the given images.
        """
        # Perform a forward pass through the model and return the predictions
        images = tf.convert_to_tensor(images, dtype=tf.float32)

        # Forward pass (no training)
        predictions = self.model(images, training=False)

        # Convert back to numpy for usability
        return predictions.numpy()
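# 💬 NOTE: A minimal sketch with a random stand-in batch; features come out as (batch_size, feature_dim):
#   model = FoundationalCVModel(backbone="resnet50", mode="eval")
#   batch = np.random.rand(4, 224, 224, 3).astype(np.float32)
#   feats = model.predict(batch)
#   model.get_output_shape()  # e.g., (None, 2048) for resnet50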
class ImageFolderDataset:
    """
    A custom dataset class for loading and preprocessing images from a folder.

    This class helps in loading images from a given folder, automatically filtering valid image files and
    preprocessing them to a specified shape. It also handles any unreadable or corrupted images by excluding them.

    Attributes:
    ----------
    folder_path : str
        The path to the folder containing the images.
    shape : tuple
        The desired shape (width, height) to which the images will be resized.
    image_files : list
        A list of valid image file names that can be processed.

    Parameters:
    ----------
    folder_path : str
        The path to the folder containing image files.
    shape : tuple, optional
        The target shape to resize the images to. The default value is (224, 224).
    image_files : list, optional
        A pre-provided list of image file names. If not provided, it will automatically detect valid image files
        (with extensions '.jpg', '.jpeg', '.png', '.gif') in the specified folder.

    Methods:
    -------
    clean_unidentified_images():
        Cleans the dataset by removing images that cause an `UnidentifiedImageError` during loading. This helps ensure
        that only valid, readable images are kept in the dataset.

    __len__():
        Returns the number of valid images in the dataset after cleaning.

    __getitem__(idx):
        Given an index `idx`, retrieves the image file at that index, loads and preprocesses it, and returns the image
        along with its filename.
    """

    def __init__(self, folder_path, shape=(224, 224), image_files=None):
        """
        Initializes the dataset object by setting the folder path and target image shape.
        It also optionally accepts a list of image files to be processed; otherwise it detects valid images in the folder.

        Parameters:
        ----------
        folder_path : str
            The directory containing the images.
        shape : tuple, optional
            The target shape to resize the images to. Default is (224, 224).
        image_files : list, optional
            A list of image files to load. If not provided, valid images are auto-detected from the folder.
        """
        self.folder_path = folder_path
        self.shape = shape

        # If image files are provided, use them; otherwise, detect image files in the folder
        if image_files:
            self.image_files = image_files
        else:
            # List all files in the folder and filter only image files
            self.image_files = [
                f
                for f in os.listdir(folder_path)
                if f.lower().endswith(("jpg", "jpeg", "png", "gif"))
            ]

        # Clean the dataset by removing images that cause errors during loading
        self.clean_unidentified_images()

    def clean_unidentified_images(self):
        """
        Clean the dataset by removing images that cannot be opened due to errors (e.g., `UnidentifiedImageError`).

        This method iterates over the list of detected image files and attempts to open and convert each image to RGB.
        If an image cannot be opened (e.g., due to corruption or an unsupported format), it is excluded from the dataset.

        Any image that causes an error will be skipped, and a message will be printed to indicate which file was skipped.
        """
        cleaned_files = []
        # Iterate over the image files and check if they can be opened
        for img_name in self.image_files:
            img_path = os.path.join(self.folder_path, img_name)
            try:
                # Try to open the image and convert it to RGB format
                Image.open(img_path).convert("RGB")
                # If successful, add the image to the cleaned list
                cleaned_files.append(img_name)
            except Exception as e:
                print(f"Skipping {img_name} due to error: {e}")

        # Update the list of image files with only the cleaned files
        self.image_files = cleaned_files

    def __len__(self):
        """
        Returns the number of valid images in the dataset after cleaning.

        Returns:
        -------
        int
            The number of images in the cleaned dataset.
        """
        return len(self.image_files)

    def __getitem__(self, idx):
        """
        Retrieves the image and its filename at the specified index.

        Parameters:
        ----------
        idx : int
            The index of the image to retrieve.

        Returns:
        -------
        tuple
            A tuple containing the image filename and the preprocessed image as a NumPy array or Tensor.

        Raises:
        ------
        IndexError
            If the index is out of bounds for the dataset.
        """
        # Get an item from the list of image files
        img_name = self.image_files[idx]
        # Load and preprocess the image
        img_path = os.path.join(self.folder_path, img_name)
        img = load_and_preprocess_image(img_path, self.shape)
        # Return the image filename and the preprocessed image
        return img_name, img
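# 💬 NOTE: A minimal sketch (hypothetical folder):
#   dataset = ImageFolderDataset("data/images")  # unreadable files are dropped up front
#   len(dataset)                                 # number of valid images
#   name, img = dataset[0]                       # ("A1.jpg", float32 array of shape (224, 224, 3))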
def get_embeddings_df(
    batch_size=32,
    path="data/images",
    dataset_name="",
    backbone="resnet50",
    directory="Embeddings",
    image_files=None,
):
    """
    Generates embeddings for images in a dataset using a specified backbone model and saves them to a CSV file.

    This function processes images from a given folder in batches, extracts features (embeddings) using a specified
    pre-trained computer vision model, and stores the results in a CSV file. The embeddings can be used for
    downstream tasks such as image retrieval or clustering.

    Parameters:
    ----------
    batch_size : int, optional
        The number of images to process in each batch. Default is 32.
    path : str, optional
        The folder path containing the images. Default is "data/images".
    dataset_name : str, optional
        The name of the dataset, used to create subdirectories for saving embeddings. Default is an empty string.
    backbone : str, optional
        The name of the backbone model to use for generating embeddings. The default is 'resnet50'.
        Other possible options include models like 'convnextv2_tiny', 'vit_base', etc.
    directory : str, optional
        The root directory where the embeddings CSV file will be saved. Default is 'Embeddings'.
    image_files : list, optional
        A pre-defined list of image file names to process. If not provided, the function will automatically detect
        image files in the `path` directory.

    Returns:
    -------
    None
        The function does not return any value. It saves a CSV file containing image names and their embeddings.

    Side Effects:
    ------------
    - Saves a CSV file in the specified directory containing image file names and their corresponding embeddings.

    Notes:
    ------
    - The images are loaded and preprocessed using the `ImageFolderDataset` class.
    - The embeddings are generated using a pre-trained model from the `FoundationalCVModel` class.
    - The embeddings are saved as a CSV file with the following structure:
        - `ImageName`: The name of the image file.
        - Columns corresponding to the embedding vector (one column per feature).

    Example:
    --------
    >>> get_embeddings_df(batch_size=16, path="data/images", dataset_name='sample_dataset', backbone="resnet50")

    This would generate a CSV file with image embeddings from the 'resnet50' backbone model for images in the "data/images" directory.
    """

    # Create an instance of the ImageFolderDataset class
    dataset = ImageFolderDataset(folder_path=path, image_files=image_files)
    # Create an instance of the FoundationalCVModel class
    model = FoundationalCVModel(backbone)

    img_names = []
    features = []
    # Calculate the number of batches based on the dataset size and batch size
    num_batches = len(dataset) // batch_size + (
        1 if len(dataset) % batch_size != 0 else 0
    )

    # Process images in batches and extract features
    for i in range(0, len(dataset), batch_size):
        # Get the image files and images for the current batch
        batch_files = dataset.image_files[i : i + batch_size]
        batch_imgs = np.array(
            [dataset[j][1] for j in range(i, min(i + batch_size, len(dataset)))]
        )

        # Generate embeddings for the batch of images
        batch_features = model.predict(batch_imgs)

        # Append the image names and features to the lists
        img_names.extend(batch_files)
        features.extend(batch_features)

        if (i // batch_size + 1) % 10 == 0:
            print(f"Batch {i // batch_size + 1}/{num_batches} done")

    # Create a DataFrame with the image names and embeddings
    df = pd.DataFrame({"ImageName": img_names, "Embeddings": features})

    # Split the embeddings into separate columns
    df_aux = pd.DataFrame(df["Embeddings"].tolist())
    df = pd.concat([df["ImageName"], df_aux], axis=1)

    # Save the DataFrame to a CSV file
    if not os.path.exists(directory):
        os.makedirs(directory)

    if not os.path.exists(f"{directory}/{dataset_name}"):
        os.makedirs(f"{directory}/{dataset_name}")

    df.to_csv(f"{directory}/{dataset_name}/Embeddings_{backbone}.csv", index=False)
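💬 NOTE: A minimal end-to-end sketch for this module (folder and dataset names are placeholders). The resulting CSV feeds directly into rename_image_embeddings / preprocess_data in src/utils.py, since its feature columns are written as 0, 1, 2, ...:

    get_embeddings_df(
        batch_size=16,
        path="data/images",
        dataset_name="sample_dataset",
        backbone="resnet50",
    )
    # -> Embeddings/sample_dataset/Embeddings_resnet50.csv with columns: ImageName, 0, 1, ..., feature_dim-1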