Lefei commited on
Commit
34eb624
·
verified ·
1 Parent(s): b8584bb

update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -7
app.py CHANGED
@@ -16,7 +16,10 @@ from visionts import VisionTSpp, freq_to_seasonality_list
16
  # ========================
17
  # 1. Configuration
18
  # ========================
 
19
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
20
  REPO_ID = "Lefei/VisionTSpp"
21
  LOCAL_DIR = "./hf_models/VisionTSpp"
22
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
@@ -121,9 +124,15 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
121
 
122
  print(f"{len(pred_quantiles_list) = }")
123
  print(f"{len(model_quantiles) = }")
 
124
  print(f"{pred_quantiles_list[0].shape = }")
125
 
126
- sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
 
 
 
 
 
127
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
128
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
129
 
@@ -149,6 +158,7 @@ def visual_ts_with_quantiles(true_data, pred_median, pred_quantiles_list, model_
149
  handles, labels = axes[0].get_legend_handles_labels()
150
  unique_labels = dict(zip(labels, handles))
151
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
 
152
  plt.tight_layout(rect=[0, 0, 1, 0.95])
153
  plt.close(fig)
154
 
@@ -227,12 +237,9 @@ def predict_at_index(df, index, context_len, pred_len):
227
  print(f"{reconstructed_image.shape = }")
228
  print(f"{len(y_pred_quantile_list) = }")
229
 
230
- # print(f"{input_image = }")
231
- print(f"{input_image[0,0,0, :, 0] = }")
232
- print(f"{input_image[0,0,0, 50:70, 0] = }")
233
- print(f"{input_image[0,0,0, 100:120, 0] = }")
234
- # print(f"{input_image[0] = }")
235
- # print(f"{reconstructed_image = }")
236
 
237
  all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
238
 
 
16
  # ========================
17
  # 1. Configuration
18
  # ========================
19
+
20
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+ # DEVICE = 'cpu'
22
+
23
  REPO_ID = "Lefei/VisionTSpp"
24
  LOCAL_DIR = "./hf_models/VisionTSpp"
25
  CKPT_PATH = os.path.join(LOCAL_DIR, "visiontspp_model.ckpt")
 
124
 
125
  print(f"{len(pred_quantiles_list) = }")
126
  print(f"{len(model_quantiles) = }")
127
+ print(f"{model_quantiles = }")
128
  print(f"{pred_quantiles_list[0].shape = }")
129
 
130
+ # sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list + [pred_median]), key=lambda x: x[0])
131
+ # sorted_quantiles = sorted(zip(model_quantiles, pred_quantiles_list), key=lambda x: x[0])
132
+
133
+ pred_quantiles_list.insert(len(QUANTILES)//2, pred_median)
134
+ sorted_quantiles = sorted(zip(QUANTILES, pred_quantiles_list), key=lambda x: x[0])
135
+
136
  quantile_preds = [item[1] for item in sorted_quantiles if item[0] != 0.5]
137
  quantile_vals = [item[0] for item in sorted_quantiles if item[0] != 0.5]
138
 
 
158
  handles, labels = axes[0].get_legend_handles_labels()
159
  unique_labels = dict(zip(labels, handles))
160
  fig.legend(unique_labels.values(), unique_labels.keys(), loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=num_bands + 2)
161
+
162
  plt.tight_layout(rect=[0, 0, 1, 0.95])
163
  plt.close(fig)
164
 
 
237
  print(f"{reconstructed_image.shape = }")
238
  print(f"{len(y_pred_quantile_list) = }")
239
 
240
+ # print(f"{input_image[0,0,0, :, 0] = }")
241
+ # print(f"{input_image[0,0,0, 50:70, 0] = }")
242
+ # print(f"{input_image[0,0,0, 100:120, 0] = }")
 
 
 
243
 
244
  all_y_pred_list = copy.deepcopy(y_pred_quantile_list)
245