Spaces:
Running
Running
update app.py
Browse files
app.py
CHANGED
|
@@ -16,7 +16,10 @@ from visionts import VisionTSpp, freq_to_seasonality_list
|
|
| 16 |
# ========================
|
| 17 |
# 1. Configuration
|
| 18 |
# ========================
|
|
|
|
| 19 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
| 20 |
REPO_ID = "Lefei/VisionTSpp"
|
| 21 |
LOCAL_DIR = "./hf_models/VisionTSpp"
|
| 22 |
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
|
|
@@ -121,9 +124,15 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
|
|
| 121 |
|
| 122 |
print(f"{len(pred_quantiles_list) = }")
|
| 123 |
print(f"{len(model_quantiles) = }")
|
|
|
|
| 124 |
print(f"{pred_quantiles_list[0].shape = }")
|
| 125 |
|
| 126 |
-
sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
|
| 128 |
quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
|
| 129 |
|
|
@@ -149,6 +158,7 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
|
|
| 149 |
handles, labels = axes[0].get_legend_handles_labels()
|
| 150 |
unique_labels = dict(zip(labels, handles))
|
| 151 |
fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
|
|
|
|
| 152 |
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
| 153 |
plt.close(fig)
|
| 154 |
|
|
@@ -227,12 +237,9 @@ def predict_at_index(df, index, context_len, pred_len):
|
|
| 227 |
print(f"{reconstructed_image.shape = }")
|
| 228 |
print(f"{len(y_pred_quantile_list) = }")
|
| 229 |
|
| 230 |
-
# print(f"{input_image = }")
|
| 231 |
-
print(f"{input_image[0,0,0,
|
| 232 |
-
print(f"{input_image[0,0,0,
|
| 233 |
-
print(f"{input_image[0,0,0, 100:120, 0] = }")
|
| 234 |
-
# print(f"{input_image[0] = }")
|
| 235 |
-
# print(f"{reconstructed_image = }")
|
| 236 |
|
| 237 |
all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
|
| 238 |
|
|
|
|
| 16 |
# ========================
|
| 17 |
# 1. Configuration
|
| 18 |
# ========================
|
| 19 |
+
|
| 20 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 21 |
+
# DEVICE = 'cpu'
|
| 22 |
+
|
| 23 |
REPO_ID = "Lefei/VisionTSpp"
|
| 24 |
LOCAL_DIR = "./hf_models/VisionTSpp"
|
| 25 |
CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
|
|
|
|
| 124 |
|
| 125 |
print(f"{len(pred_quantiles_list) = }")
|
| 126 |
print(f"{len(model_quantiles) = }")
|
| 127 |
+
print(f"{model_quantiles = }")
|
| 128 |
print(f"{pred_quantiles_list[0].shape = }")
|
| 129 |
|
| 130 |
+
# sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
|
| 131 |
+
# sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list), key=lambda x: x[0])
|
| 132 |
+
|
| 133 |
+
pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
|
| 134 |
+
sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
|
| 135 |
+
|
| 136 |
quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
|
| 137 |
quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
|
| 138 |
|
|
|
|
| 158 |
handles, labels = axes[0].get_legend_handles_labels()
|
| 159 |
unique_labels = dict(zip(labels, handles))
|
| 160 |
fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
|
| 161 |
+
|
| 162 |
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
| 163 |
plt.close(fig)
|
| 164 |
|
|
|
|
| 237 |
print(f"{reconstructed_image.shape = }")
|
| 238 |
print(f"{len(y_pred_quantile_list) = }")
|
| 239 |
|
| 240 |
+
# print(f"{input_image[0,0,0, :, 0] = }")
|
| 241 |
+
# print(f"{input_image[0,0,0, 50:70, 0] = }")
|
| 242 |
+
# print(f"{input_image[0,0,0, 100:120, 0] = }")
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
|
| 245 |
|