Update pipeline.py
Browse files- pipeline.py +12 -9
pipeline.py
CHANGED
|
@@ -18,16 +18,17 @@ class PreTrainedPipeline(Pipeline):
|
|
| 18 |
|
| 19 |
|
| 20 |
# Reload Keras SavedModel
|
| 21 |
-
self.model =
|
| 22 |
|
| 23 |
# Number of labels
|
| 24 |
self.num_labels = self.model.output_shape[1]
|
| 25 |
|
| 26 |
# Config is required to know the mapping to label.
|
| 27 |
-
config_file = hf_hub_download(model_id, filename=CONFIG_FILENAME)
|
| 28 |
-
with open(config_file) as config:
|
| 29 |
-
|
| 30 |
-
|
|
|
|
| 31 |
self.id2label = config.get(
|
| 32 |
"id2label", {str(i): f"LABEL_{i}" for i in range(self.num_labels)}
|
| 33 |
)
|
|
@@ -59,12 +60,14 @@ class PreTrainedPipeline(Pipeline):
|
|
| 59 |
self.single_output_unit = (
|
| 60 |
self.model.output_shape[1] == 1
|
| 61 |
) # if there are two classes
|
| 62 |
-
|
|
|
|
| 63 |
if self.single_output_unit:
|
| 64 |
score = predictions[0][0]
|
| 65 |
-
labels = [
|
| 66 |
-
|
| 67 |
-
|
|
|
|
| 68 |
]
|
| 69 |
else:
|
| 70 |
labels = [
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
# Reload Keras SavedModel
|
| 21 |
+
self.model = keras.models.load_model('./model.h5')
|
| 22 |
|
| 23 |
# Number of labels
|
| 24 |
self.num_labels = self.model.output_shape[1]
|
| 25 |
|
| 26 |
# Config is required to know the mapping to label.
|
| 27 |
+
#config_file = hf_hub_download(model_id, filename=CONFIG_FILENAME)
|
| 28 |
+
#with open(config_file) as config:
|
| 29 |
+
# config = json.load(config)
|
| 30 |
+
|
| 31 |
+
self.num_labels = 3
|
| 32 |
self.id2label = config.get(
|
| 33 |
"id2label", {str(i): f"LABEL_{i}" for i in range(self.num_labels)}
|
| 34 |
)
|
|
|
|
| 60 |
self.single_output_unit = (
|
| 61 |
self.model.output_shape[1] == 1
|
| 62 |
) # if there are two classes
|
| 63 |
+
|
| 64 |
+
|
| 65 |
if self.single_output_unit:
|
| 66 |
score = predictions[0][0]
|
| 67 |
+
labels = [{"label":"pet", "score":1.0}, {"label":"other", "score":1.0}]
|
| 68 |
+
#labels = [
|
| 69 |
+
# {"label": str(self.id2label["1"]), "score": float(score)},
|
| 70 |
+
# {"label": str(self.id2label["0"]), "score": float(1 - score)},
|
| 71 |
]
|
| 72 |
else:
|
| 73 |
labels = [
|