Fahimeh Orvati Nia commited on
Commit
b989947
·
1 Parent(s): 8b0c81c
Files changed (1) hide show
  1. sorghum_pipeline/pipeline.py +16 -2
sorghum_pipeline/pipeline.py CHANGED
@@ -99,7 +99,14 @@ class SorghumPipeline:
99
  green_band = None
100
  spectral = pdata.get('spectral_stack', {})
101
  if 'green' in spectral:
102
- green_band = spectral['green'].squeeze(-1).astype(np.float64)
 
 
 
 
 
 
 
103
  if mask is not None:
104
  valid = np.where(mask > 0, green_band, np.nan)
105
  else:
@@ -138,7 +145,14 @@ class SorghumPipeline:
138
  if not all(b in spectral for b in bands):
139
  continue
140
 
141
- arrays = [np.asarray(spectral[b].squeeze(-1), dtype=np.float64) for b in bands]
 
 
 
 
 
 
 
142
  values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
143
  binary_mask = (mask > 0)
144
  masked_values = np.where(binary_mask, values, np.nan)
 
99
  green_band = None
100
  spectral = pdata.get('spectral_stack', {})
101
  if 'green' in spectral:
102
+ gb = spectral['green']
103
+ gb = np.asarray(gb)
104
+ # robustly collapse to 2D if it arrived as 3-channel
105
+ if gb.ndim == 3 and gb.shape[2] > 1:
106
+ gb = gb[..., 0]
107
+ elif gb.ndim == 3 and gb.shape[2] == 1:
108
+ gb = gb.squeeze(-1)
109
+ green_band = gb.astype(np.float64)
110
  if mask is not None:
111
  valid = np.where(mask > 0, green_band, np.nan)
112
  else:
 
145
  if not all(b in spectral for b in bands):
146
  continue
147
 
148
+ arrays = []
149
+ for b in bands:
150
+ arr = np.asarray(spectral[b])
151
+ if arr.ndim == 3 and arr.shape[2] > 1:
152
+ arr = arr[..., 0]
153
+ elif arr.ndim == 3 and arr.shape[2] == 1:
154
+ arr = arr.squeeze(-1)
155
+ arrays.append(arr.astype(np.float64))
156
  values = self.vegetation_extractor.index_formulas[name](*arrays).astype(np.float64)
157
  binary_mask = (mask > 0)
158
  masked_values = np.where(binary_mask, values, np.nan)