File size: 7,140 Bytes
fcaa164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import asyncio
from enum import Enum
from typing import Dict, List, Optional

from camel.tasks import Task


class PacketStatus(Enum):
    r"""Lifecycle state of a :class:`Packet` held in the task channel.

    Members:

    - ``SENT``: the packet has been handed to a worker for processing.
    - ``RETURNED``: the worker has handed the packet back, i.e. the status
      of the wrapped task has been updated.
    - ``ARCHIVED``: the packet is frozen; the wrapped task's content will
      no longer change and the task serves as a dependency.
    """

    # Dispatched to a worker, not yet handed back.
    SENT = "SENT"
    # Handed back by the worker with an updated task status.
    RETURNED = "RETURNED"
    # Frozen; the task is kept as a dependency.
    ARCHIVED = "ARCHIVED"


class Packet:
    r"""The basic element inside the channel. A task is wrapped inside a
    packet. The packet will contain the task, along with the task's
    publisher, its assignee, and the packet's status.

    Args:
        task (Task): The task that is wrapped inside the packet.
        publisher_id (str): The ID of the workforce that published the task.
        assignee_id (Optional[str], optional): The ID of the workforce that
            is assigned to the task. Defaults to None, meaning that the task
            is posted as a dependency in the channel.
        status (PacketStatus, optional): The initial status of the packet.
            Defaults to ``PacketStatus.SENT``.

    Attributes:
        task (Task): The task that is wrapped inside the packet.
        publisher_id (str): The ID of the workforce that published the task.
        assignee_id (Optional[str]): The ID of the workforce that is
            assigned to the task. Would be None if the task is a dependency.
        status (PacketStatus): The current status of the packet.
    """

    def __init__(
        self,
        task: Task,
        publisher_id: str,
        assignee_id: Optional[str] = None,
        status: PacketStatus = PacketStatus.SENT,
    ) -> None:
        self.task = task
        self.publisher_id = publisher_id
        self.assignee_id = assignee_id
        self.status = status

    def __repr__(self) -> str:
        # The task itself is deliberately left out to keep the
        # representation short; the channel already keys packets by task ID.
        return (
            f"Packet(publisher_id={self.publisher_id}, assignee_id="
            f"{self.assignee_id}, status={self.status})"
        )


class TaskChannel:
    r"""An internal class used by Workforce to manage tasks.

    Each task in the channel is wrapped in a :class:`Packet` and stored
    under the task's ID. Every method holds a single
    :class:`asyncio.Condition` while it touches the channel, and every
    method that changes the channel calls ``notify_all`` so that coroutines
    blocked in :meth:`get_returned_task_by_publisher` or
    :meth:`get_assigned_task_by_assignee` re-check for a matching packet.
    """

    def __init__(self) -> None:
        self._condition = asyncio.Condition()
        # Task ID -> packet. Dicts preserve insertion order, so this single
        # mapping is both the lookup table and the posting order. Posting an
        # ID that is already present replaces its packet in place instead of
        # creating a second entry that could drift out of sync.
        self._task_dict: Dict[str, Packet] = {}

    async def get_returned_task_by_publisher(self, publisher_id: str) -> Task:
        r"""Get a task that has been returned to the given publisher.

        Waits until the channel holds a packet published by
        ``publisher_id`` whose status is ``RETURNED``. The packet itself is
        not modified, so calling this again yields the same task until its
        status changes (e.g. via :meth:`archive_task`) or it is removed.

        Args:
            publisher_id (str): The ID of the publisher to look for.

        Returns:
            Task: The first matching task, in posting order.
        """
        async with self._condition:
            while True:
                for packet in self._task_dict.values():
                    if (
                        packet.publisher_id == publisher_id
                        and packet.status == PacketStatus.RETURNED
                    ):
                        return packet.task
                # Releases the lock until another coroutine changes the
                # channel and calls notify_all.
                await self._condition.wait()

    async def get_assigned_task_by_assignee(self, assignee_id: str) -> Task:
        r"""Get a task from the channel that has been assigned to the
        assignee.

        Waits until the channel holds a packet assigned to ``assignee_id``
        whose status is ``SENT``. The packet itself is not modified, so
        calling this again yields the same task until its status changes
        (e.g. via :meth:`return_task`) or it is removed.

        Args:
            assignee_id (str): The ID of the assignee to look for.

        Returns:
            Task: The first matching task, in posting order.
        """
        async with self._condition:
            while True:
                for packet in self._task_dict.values():
                    if (
                        packet.status == PacketStatus.SENT
                        and packet.assignee_id == assignee_id
                    ):
                        return packet.task
                await self._condition.wait()

    async def post_task(
        self, task: Task, publisher_id: str, assignee_id: str
    ) -> None:
        r"""Send a task to the channel with the specified publisher and
        assignee.

        The task is wrapped in a packet with the default ``SENT`` status.
        If a packet with the same task ID is already in the channel, it is
        replaced (keeping its original position in posting order).

        Args:
            task (Task): The task to send.
            publisher_id (str): The ID of the publisher of the task.
            assignee_id (str): The ID of the worker the task is assigned to.
        """
        async with self._condition:
            self._task_dict[task.id] = Packet(task, publisher_id, assignee_id)
            self._condition.notify_all()

    async def post_dependency(
        self, dependency: Task, publisher_id: str
    ) -> None:
        r"""Post a dependency to the channel. A dependency is a task that is
        archived, and will be referenced by other tasks.

        The packet has no assignee and status ``ARCHIVED``. If a packet with
        the same task ID is already in the channel, it is replaced.

        Args:
            dependency (Task): The task to post as a dependency.
            publisher_id (str): The ID of the publisher of the dependency.
        """
        async with self._condition:
            self._task_dict[dependency.id] = Packet(
                dependency, publisher_id, status=PacketStatus.ARCHIVED
            )
            self._condition.notify_all()

    async def return_task(self, task_id: str) -> None:
        r"""Return a task to the sender, indicating that the task has been
        processed by the worker.

        Args:
            task_id (str): The ID of the task to mark as ``RETURNED``.

        Raises:
            KeyError: If no task with ``task_id`` is in the channel.
        """
        async with self._condition:
            self._task_dict[task_id].status = PacketStatus.RETURNED
            self._condition.notify_all()

    async def archive_task(self, task_id: str) -> None:
        r"""Archive a task in the channel, turning it into a dependency.

        Args:
            task_id (str): The ID of the task to mark as ``ARCHIVED``.

        Raises:
            KeyError: If no task with ``task_id`` is in the channel.
        """
        async with self._condition:
            self._task_dict[task_id].status = PacketStatus.ARCHIVED
            self._condition.notify_all()

    async def remove_task(self, task_id: str) -> None:
        r"""Remove a task from the channel.

        Args:
            task_id (str): The ID of the task to remove.

        Raises:
            ValueError: If no task with ``task_id`` is in the channel.
        """
        async with self._condition:
            if task_id not in self._task_dict:
                raise ValueError(f"Task {task_id} not found.")
            del self._task_dict[task_id]
            self._condition.notify_all()

    async def get_dependency_ids(self) -> List[str]:
        r"""Get the IDs of all dependencies in the channel.

        Returns:
            List[str]: IDs of every ``ARCHIVED`` packet, in posting order.
        """
        async with self._condition:
            return [
                task_id
                for task_id, packet in self._task_dict.items()
                if packet.status == PacketStatus.ARCHIVED
            ]

    async def get_task_by_id(self, task_id: str) -> Task:
        r"""Get a task from the channel by its ID.

        Args:
            task_id (str): The ID of the task to look up.

        Returns:
            Task: The task stored under ``task_id``.

        Raises:
            ValueError: If no task with ``task_id`` is in the channel.
        """
        async with self._condition:
            packet = self._task_dict.get(task_id)
            if packet is None:
                raise ValueError(f"Task {task_id} not found.")
            return packet.task

    async def get_channel_debug_info(self) -> str:
        r"""Get the debug information of the channel.

        Returns:
            str: The packet mapping, a newline, then the list of task IDs
            in posting order.
        """
        async with self._condition:
            return (
                str(self._task_dict) + '\n' + str(list(self._task_dict))
            )