|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio |
|
|
from enum import Enum |
|
|
from typing import Dict, List, Optional |
|
|
|
|
|
from camel.tasks import Task |
|
|
|
|
|
|
|
|
class PacketStatus(Enum): |
|
|
r"""The status of a packet. The packet can be in one of the following |
|
|
states: |
|
|
|
|
|
- ``SENT``: The packet has been sent to a worker. |
|
|
- ``RETURNED``: The packet has been returned by the worker, meaning that |
|
|
the status of the task inside has been updated. |
|
|
- ``ARCHIVED``: The packet has been archived, meaning that the content of |
|
|
the task inside will not be changed. The task is considered |
|
|
as a dependency. |
|
|
""" |
|
|
|
|
|
SENT = "SENT" |
|
|
RETURNED = "RETURNED" |
|
|
ARCHIVED = "ARCHIVED" |
|
|
|
|
|
|
|
|
class Packet: |
|
|
r"""The basic element inside the channel. A task is wrapped inside a |
|
|
packet. The packet will contain the task, along with the task's assignee, |
|
|
and the task's status. |
|
|
|
|
|
Args: |
|
|
task (Task): The task that is wrapped inside the packet. |
|
|
publisher_id (str): The ID of the workforce that published the task. |
|
|
assignee_id (str): The ID of the workforce that is assigned |
|
|
to the task. Defaults to None, meaning that the task is posted as |
|
|
a dependency in the channel. |
|
|
|
|
|
Attributes: |
|
|
task (Task): The task that is wrapped inside the packet. |
|
|
publisher_id (str): The ID of the workforce that published the task. |
|
|
assignee_id (Optional[str], optional): The ID of the workforce that is |
|
|
assigned to the task. Would be None if the task is a dependency. |
|
|
Defaults to None. |
|
|
status (PacketStatus): The status of the task. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
task: Task, |
|
|
publisher_id: str, |
|
|
assignee_id: Optional[str] = None, |
|
|
status: PacketStatus = PacketStatus.SENT, |
|
|
) -> None: |
|
|
self.task = task |
|
|
self.publisher_id = publisher_id |
|
|
self.assignee_id = assignee_id |
|
|
self.status = status |
|
|
|
|
|
def __repr__(self): |
|
|
return ( |
|
|
f"Packet(publisher_id={self.publisher_id}, assignee_id=" |
|
|
f"{self.assignee_id}, status={self.status})" |
|
|
) |
|
|
|
|
|
|
|
|
class TaskChannel: |
|
|
r"""An internal class used by Workforce to manage tasks.""" |
|
|
|
|
|
def __init__(self) -> None: |
|
|
self._task_id_list: List[str] = [] |
|
|
self._condition = asyncio.Condition() |
|
|
self._task_dict: Dict[str, Packet] = {} |
|
|
|
|
|
async def get_returned_task_by_publisher(self, publisher_id: str) -> Task: |
|
|
r"""Get a task from the channel that has been returned by the |
|
|
publisher. |
|
|
""" |
|
|
async with self._condition: |
|
|
while True: |
|
|
for task_id in self._task_id_list: |
|
|
packet = self._task_dict[task_id] |
|
|
if packet.publisher_id != publisher_id: |
|
|
continue |
|
|
if packet.status != PacketStatus.RETURNED: |
|
|
continue |
|
|
return packet.task |
|
|
await self._condition.wait() |
|
|
|
|
|
async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task: |
|
|
r"""Get a task from the channel that has been assigned to the |
|
|
assignee. |
|
|
""" |
|
|
async with self._condition: |
|
|
while True: |
|
|
for task_id in self._task_id_list: |
|
|
packet = self._task_dict[task_id] |
|
|
if ( |
|
|
packet.status == PacketStatus.SENT |
|
|
and packet.assignee_id == assignee_id |
|
|
): |
|
|
return packet.task |
|
|
await self._condition.wait() |
|
|
|
|
|
async def post_task( |
|
|
self, task: Task, publisher_id: str, assignee_id: str |
|
|
) -> None: |
|
|
r"""Send a task to the channel with specified publisher and assignee, |
|
|
along with the dependency of the task.""" |
|
|
async with self._condition: |
|
|
self._task_id_list.append(task.id) |
|
|
packet = Packet(task, publisher_id, assignee_id) |
|
|
self._task_dict[packet.task.id] = packet |
|
|
self._condition.notify_all() |
|
|
|
|
|
async def post_dependency( |
|
|
self, dependency: Task, publisher_id: str |
|
|
) -> None: |
|
|
r"""Post a dependency to the channel. A dependency is a task that is |
|
|
archived, and will be referenced by other tasks.""" |
|
|
async with self._condition: |
|
|
self._task_id_list.append(dependency.id) |
|
|
packet = Packet( |
|
|
dependency, publisher_id, status=PacketStatus.ARCHIVED |
|
|
) |
|
|
self._task_dict[packet.task.id] = packet |
|
|
self._condition.notify_all() |
|
|
|
|
|
async def return_task(self, task_id: str) -> None: |
|
|
r"""Return a task to the sender, indicating that the task has been |
|
|
processed by the worker.""" |
|
|
async with self._condition: |
|
|
packet = self._task_dict[task_id] |
|
|
packet.status = PacketStatus.RETURNED |
|
|
self._condition.notify_all() |
|
|
|
|
|
async def archive_task(self, task_id: str) -> None: |
|
|
r"""Archive a task in channel, making it to become a dependency.""" |
|
|
async with self._condition: |
|
|
packet = self._task_dict[task_id] |
|
|
packet.status = PacketStatus.ARCHIVED |
|
|
self._condition.notify_all() |
|
|
|
|
|
async def remove_task(self, task_id: str) -> None: |
|
|
r"""Remove a task from the channel.""" |
|
|
async with self._condition: |
|
|
self._task_id_list.remove(task_id) |
|
|
self._task_dict.pop(task_id) |
|
|
self._condition.notify_all() |
|
|
|
|
|
async def get_dependency_ids(self) -> List[str]: |
|
|
r"""Get the IDs of all dependencies in the channel.""" |
|
|
async with self._condition: |
|
|
dependency_ids = [] |
|
|
for task_id in self._task_id_list: |
|
|
packet = self._task_dict[task_id] |
|
|
if packet.status == PacketStatus.ARCHIVED: |
|
|
dependency_ids.append(task_id) |
|
|
return dependency_ids |
|
|
|
|
|
async def get_task_by_id(self, task_id: str) -> Task: |
|
|
r"""Get a task from the channel by its ID.""" |
|
|
async with self._condition: |
|
|
if task_id not in self._task_id_list: |
|
|
raise ValueError(f"Task {task_id} not found.") |
|
|
return self._task_dict[task_id].task |
|
|
|
|
|
async def get_channel_debug_info(self) -> str: |
|
|
r"""Get the debug information of the channel.""" |
|
|
async with self._condition: |
|
|
return str(self._task_dict) + '\n' + str(self._task_id_list) |
|
|
|