# Spaces: Sleeping  (hosting-page status banner captured with the file; not part of the module)
import json
from collections import OrderedDict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Callable, List, Optional
@dataclass
class Action:
    """A single tool invocation, read as an assignment.

    Conceptually ``outputs = tool_name(**inputs)``.

    Examples:
        >>> mask = segmentation_by_mask(image=image, prompt_mask=prompt_mask)
        >>> image = image_inpainting(image=image, mask=mask)

    Attributes:
        tool_name: Name of the tool being called.
        inputs: Mapping of argument name to argument value.
        outputs: Names (or other printable handles) the results are bound to.
    """

    # Defaults are a plain None / fresh empty containers so that a
    # default-constructed Action can still be printed and serialized.
    tool_name: Optional[str] = None
    inputs: dict = field(default_factory=dict)
    outputs: List[str] = field(default_factory=list)

    def __str__(self) -> str:
        """Render the action as ``out1, out2 = tool(arg=value, ...)``."""
        args = ", ".join(f"{k}={v}" for k, v in self.inputs.items())
        # Outputs are stringified here for the same reason dict() does it:
        # they are not guaranteed to be str (e.g. pathlib.Path).
        targets = ", ".join(str(o) for o in self.outputs)
        return f"{targets} = {self.tool_name}({args})"

    def dict(self):
        """Return a JSON-friendly dict with keys ``tool``, ``inputs``, ``outputs``.

        Input keys and values are converted with ``str``; outputs that are not
        already strings are converted with ``str`` as well.
        """
        args = {str(k): str(v) for k, v in self.inputs.items()}
        # args = {str(item["name"]): str(item["value"]) for item in self.inputs}
        rets = [o if isinstance(o, str) else str(o) for o in self.outputs]
        return {
            "tool": self.tool_name,
            "inputs": args,
            "outputs": rets,
        }
class DataType(Enum):
    """Kinds of data exchanged between tools.

    Each member's value is a short string tag. Members compare equal to
    that raw string as well as to themselves, and hash like the string, so
    a member and its tag are interchangeable as dict keys / set elements.
    """

    TEXT = "text"
    TAGS = "tags"
    TITLE = "title"
    # HTML = "text.html"
    HTML = "html"
    LOCATION = "location"
    WEATHER = "weather"
    TIME = "time"
    IMAGE = "image"
    VIDEO = "video"
    AUDIO = "audio"
    ANY = "any"
    NONE = "none"
    SEGMENTATION = "image.segmentation"
    EDGE = "image.edge"
    LINE = "image.line"
    HED = "image.hed"
    CANNY = "image.canny"
    SCRIBBLE = "image.scribble"
    POSE = "image.pose"
    DEPTH = "image.depth"
    NORMAL = "image.normal"
    MASK = "image.mask"  # SAM mask
    POINT = "point"
    BBOX = "bbox"  # {'label': 'dog', 'box': [1,2,3,4], 'score': 0.9}
    CATEGORY = "category"
    LIST = "list"

    def __str__(self):
        """Return the raw string tag (e.g. ``"image.mask"``)."""
        return self.value

    def __eq__(self, other):
        """Compare by value against a plain string or another member."""
        if isinstance(other, str):
            return self.value == other
        if isinstance(other, self.__class__):
            return self.value == other.value
        return False

    def __hash__(self):
        """Hash by value.

        Defining ``__eq__`` alone sets ``__hash__`` to ``None``, which would
        make members unusable in sets and as dict keys. Hashing the value
        keeps ``a == b`` implying ``hash(a) == hash(b)``, including when
        ``b`` is the plain string tag.
        """
        return hash(self.value)
@dataclass
class Resource:
    """A named value tagged with its data type.

    Attributes:
        name: Identifier of the resource.
        type: The kind of data held in ``value``.
        value: The payload itself; no particular Python type is required.
    """

    name: str
    type: DataType
    value: Any
    # description: str = None

    def dict(self):
        """Return a dict with keys ``name``, ``type``, ``value``.

        ``type`` and ``value`` are both converted with ``str``; ``name`` is
        passed through unchanged.
        """
        return {
            "name": self.name,
            "type": str(self.type),
            "value": str(self.value),
            # "description": self.description,
        }
@dataclass
class Tool:
    """Description of a callable tool together with its signature.

    Attributes:
        name: Tool identifier.
        description: Human-readable description of the tool.
        domain: The ``Tool.Domain`` the tool belongs to.
        model: The callable that implements the tool.
        usages: Usage strings for the tool (empty by default).
        args: Declared input arguments (empty by default).
        returns: Declared return values (empty by default).
    """

    class Domain(Enum):
        """Domain tag for a tool; ``str()`` yields the raw value."""

        IMAGE_PERCEPTION = "image-perception"
        IMAGE_GENERATION = "image-generation"
        IMAGE_EDITING = "image-editing"
        IMAGE_PROCESSING = "image-processing"
        AUDIO_PERCEPTION = "audio-perception"
        AUDIO_GENERATION = "audio-generation"
        VIDEO_PERCEPTION = "video-perception"
        VIDEO_GENERATION = "video-generation"
        VIDEO_PROCESSING = "video-processing"
        VIDEO_EDITING = "video-editing"
        VIDEO_CUTTING = "video-cutting"
        NATURAL_LANGUAGE_PROCESSING = "natural-language-processing"
        CODE_GENERATION = "code-generation"
        VISUAL_QUESTION_ANSWERING = "visual-question-answering"
        QUESTION_ANSWERING = "question-answering"
        GENERAL = "general"

        def __str__(self):
            return self.value

    @dataclass
    class Argument:
        """One named, typed argument or return value of a tool."""

        name: str
        type: DataType
        description: str

        def dict(self):
            """Return a dict with keys ``name``, ``type``, ``description``.

            ``type`` is converted with ``str``.
            """
            return {
                "name": self.name,
                "type": str(self.type),
                "description": self.description,
            }

    name: str
    description: str
    domain: Domain
    model: Callable
    # default_factory gives every Tool its own list instead of a shared one.
    usages: List[str] = field(default_factory=list)
    args: List[Argument] = field(default_factory=list)
    returns: List[Argument] = field(default_factory=list)

    def dict(self):
        """Return a dict describing the tool.

        Keys are ``name``, ``description``, ``domain`` (as a string),
        ``args`` and ``returns`` (each a list of ``Argument.dict()``
        results). ``model`` and ``usages`` are not included.
        """
        return {
            "name": self.name,
            "description": self.description,
            "domain": str(self.domain),
            "args": [a.dict() for a in self.args],
            "returns": [r.dict() for r in self.returns],
        }
# Data types grouped under the "non-file" label.
# NOTE(review): presumably these are the types whose values are carried
# inline rather than as files on disk; confirm against the callers.
NON_FILE_TYPES = [
    DataType.TAGS,
    DataType.TEXT,
    DataType.TITLE,
    DataType.BBOX,
    DataType.CATEGORY,
    DataType.LIST,
    DataType.LOCATION,
    DataType.POINT,
    DataType.WEATHER,
    DataType.TIME,
]
if __name__ == "__main__":
    # Manual smoke test: serialize a nested list of actions to JSON.
    # json.dumps cannot encode Action itself, so `default` falls back to
    # each object's own dict() method.
    actions = [
        [Action("a", {"aa": [Path("/a/d/e/t.txt")]}, [Path("/a/aa.txt")])],
        Action("b", {"bb": "bbb"}, ["bbb"]),
    ]
    print(json.dumps(actions, indent=4, default=lambda o: o.dict()))