Spaces:
Runtime error
| import gradio as gr | |
| from transformers import pipeline, ImageClassificationPipeline | |
class MultiClassLabel(ImageClassificationPipeline):
    """Image-classification pipeline that scores every label independently.

    Post-processing applies an element-wise sigmoid to the logits (rather than
    a softmax), so label scores do not compete with each other and several
    labels can score high for the same image.
    """

    def postprocess(self, model_outputs, top_k=5):
        """Convert raw model logits into the ``top_k`` highest-scoring labels.

        Args:
            model_outputs: Forward-pass output exposing a ``logits`` attribute;
                only the first batch item (index 0) is used.
            top_k: Maximum number of labels to return. ``None`` or a value
                larger than the model's label count returns every label.

        Returns:
            A list of ``{"score": float, "label": str}`` dicts ordered by
            descending score.

        Raises:
            ValueError: If the pipeline framework is neither ``"pt"`` nor ``"tf"``.
        """
        num_labels = self.model.config.num_labels
        # Clamp to the label count; None means "return every label".
        if top_k is None or top_k > num_labels:
            top_k = num_labels

        if self.framework == "pt":
            probs = model_outputs.logits.sigmoid()[0]
            scores, ids = probs.topk(top_k)
        elif self.framework == "tf":
            # Mirror the PyTorch branch (sigmoid, not softmax). NumPy is used so
            # this branch does not depend on TensorFlow helpers that were never
            # imported into this module.
            import numpy as np

            logits = np.asarray(model_outputs.logits)[0]
            # Numerically stable sigmoid: 1 / (1 + e^-x) == exp(-log(1 + e^-x)).
            probs = np.exp(-np.logaddexp(0, -logits))
            # Stable sort on negated scores: descending order, lower index first on ties.
            ids = np.argsort(-probs, kind="stable")[:top_k]
            scores = probs[ids]
        else:
            raise ValueError(f"Unsupported framework: {self.framework}")

        scores = scores.tolist()
        ids = ids.tolist()
        return [
            {"score": score, "label": self.model.config.id2label[_id]}
            for score, _id in zip(scores, ids)
        ]
# Build the image-classification pipeline from the local "./sonic" checkpoint,
# routing post-processing through the MultiClassLabel subclass.
pipe_aesthetic = pipeline(
    task="image-classification",
    model="./sonic",
    pipeline_class=MultiClassLabel,
)
def aesthetic(input_img):
    """Classify one image and map each predicted label to its score.

    Args:
        input_img: Image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        A ``{label: score}`` dict for up to 5 top labels, in the shape
        ``gr.Label`` expects; an empty dict when no image was provided.
    """
    # An empty Gradio image input arrives as None; the pipeline is not expected
    # to handle that, so return an empty result instead of raising.
    if input_img is None:
        return {}
    predictions = pipe_aesthetic(input_img, top_k=5)
    return {pred["label"]: pred["score"] for pred in predictions}
# Single-image UI: upload a picture and display the highest-scoring labels.
demo_aesthetic = gr.Interface(
    fn=aesthetic,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(label="characters"),
)
# Wrapping a single interface in gr.Parallel adds nothing (and gr.Parallel was
# removed in Gradio 4), so launch the interface directly.
demo_aesthetic.launch()