diff --git a/metagpt/__init__.py b/metagpt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ddd1affca7590934bc054ccf47b325f6d29e2d
--- /dev/null
+++ b/metagpt/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Time : 2023/4/24 22:26
+# @Author : alexanderwu
+# @File : __init__.py
+
+from metagpt import _compat as _  # noqa: F401
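The package `__init__.py` above imports `metagpt._compat` purely for its import-time side effects (the Windows asyncio patches shown in the next hunk). A minimal sketch of how that effect could be checked, assuming CPython on Windows, Python >= 3.9, and an installed `metagpt` package:

```python
import asyncio
import sys

import metagpt  # noqa: F401 -- importing the package runs metagpt._compat's patches

if sys.platform == "win32":
    # _compat switches the default policy to the proactor event loop on Windows
    print(type(asyncio.get_event_loop_policy()).__name__)
    # expected (assumption): "WindowsProactorEventLoopPolicy"
```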
diff --git a/metagpt/_compat.py b/metagpt/_compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..c442bd7ded67f56c5b76d27e0828702d5c6ced5b
--- /dev/null
+++ b/metagpt/_compat.py
@@ -0,0 +1,23 @@
+import platform
+import sys
+import warnings
+
+if sys.implementation.name == "cpython" and platform.system() == "Windows":
+    import asyncio
+
+    if sys.version_info[:2] == (3, 9):
+        from asyncio.proactor_events import _ProactorBasePipeTransport
+
+        # https://github.com/python/cpython/pull/92842
+        def patch_del(self, _warn=warnings.warn):
+            if self._sock is not None:
+                _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
+                self._sock.close()
+
+        _ProactorBasePipeTransport.__del__ = patch_del
+
+    if sys.version_info >= (3, 9, 0):
+        from semantic_kernel.orchestration import sk_function as _  # noqa: F401
+
+        # caused by https://github.com/microsoft/semantic-kernel/pull/1416
+        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..495ed403133200363b6b13b7736c7399441acf90
--- /dev/null
+++ b/metagpt/actions/__init__.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/5/11 17:44
+@Author : alexanderwu
+@File : __init__.py
+"""
+from enum import Enum
+
+from metagpt.actions.action import Action
+from metagpt.actions.action_output import ActionOutput
+from metagpt.actions.add_requirement import UserRequirement
+from metagpt.actions.debug_error import DebugError
+from metagpt.actions.design_api import WriteDesign
+from metagpt.actions.design_api_review import DesignReview
+from metagpt.actions.project_management import WriteTasks
+from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
+from metagpt.actions.run_code import RunCode
+from metagpt.actions.search_and_summarize import SearchAndSummarize
+from metagpt.actions.write_code import WriteCode
+from metagpt.actions.write_code_review import WriteCodeReview
+from metagpt.actions.write_prd import WritePRD
+from metagpt.actions.write_prd_review import WritePRDReview
+from metagpt.actions.write_test import WriteTest
+from metagpt.actions.di.execute_nb_code import ExecuteNbCode
+from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
+from metagpt.actions.di.write_plan import WritePlan
+
+
+class ActionType(Enum):
+    """All types of Actions, used for indexing."""
+
+    ADD_REQUIREMENT = UserRequirement
+    WRITE_PRD = WritePRD
+    WRITE_PRD_REVIEW = WritePRDReview
+    WRITE_DESIGN = WriteDesign
+    DESIGN_REVIEW = DesignReview
+    WRITE_CODE = WriteCode
+    WRITE_CODE_REVIEW = WriteCodeReview
+    WRITE_TEST = WriteTest
+    RUN_CODE = RunCode
+    DEBUG_ERROR = DebugError
+    WRITE_TASKS = WriteTasks
+    SEARCH_AND_SUMMARIZE = SearchAndSummarize
+    COLLECT_LINKS = CollectLinks
+    WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
+    CONDUCT_RESEARCH = ConductResearch
+    EXECUTE_NB_CODE = ExecuteNbCode
+    WRITE_ANALYSIS_CODE = WriteAnalysisCode
+    WRITE_PLAN = WritePlan
+
+
+__all__ = [
+    "ActionType",
+    "Action",
+    "ActionOutput",
+]
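Because each `ActionType` member's value is the corresponding `Action` subclass, the enum doubles as a lookup table from a symbolic name to a class. A short usage sketch (assuming `metagpt` and its dependencies are importable):

```python
from metagpt.actions import ActionType

# Resolve an action class by symbolic name; .value holds the class itself.
action_cls = ActionType.WRITE_PRD.value
print(action_cls.__name__)  # "WritePRD"

# The same lookup, driven by a runtime string:
action_cls = ActionType["SEARCH_AND_SUMMARIZE"].value
```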
diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba8a94803a84598b467bc51201b8890d9ac991fc
--- /dev/null
+++ b/metagpt/actions/action.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/5/11 14:43
+@Author : alexanderwu
+@File : action.py
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from metagpt.actions.action_node import ActionNode
+from metagpt.configs.models_config import ModelsConfig
+from metagpt.context_mixin import ContextMixin
+from metagpt.provider.llm_provider_registry import create_llm_instance
+from metagpt.schema import (
+    CodePlanAndChangeContext,
+    CodeSummarizeContext,
+    CodingContext,
+    RunCodeContext,
+    SerializationMixin,
+    TestingContext,
+)
+
+
+class Action(SerializationMixin, ContextMixin, BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    name: str = ""
+    i_context: Union[
+        dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, CodePlanAndChangeContext, str, None
+    ] = ""
+    prefix: str = ""  # prepended in aask* calls, used as the system_message
+    desc: str = ""  # for skill manager
+    node: ActionNode = Field(default=None, exclude=True)
+    # The model name or API type of LLM of the `models` in the `config2.yaml`;
+    # Using `None` to use the `llm` configuration in the `config2.yaml`.
+    llm_name_or_type: Optional[str] = None
+
+    @model_validator(mode="after")
+    @classmethod
+    def _update_private_llm(cls, data: Any) -> Any:
+        config = ModelsConfig.default().get(data.llm_name_or_type)
+        if config:
+            llm = create_llm_instance(config)
+            llm.cost_manager = data.llm.cost_manager
+            data.llm = llm
+        return data
+
+    @property
+    def prompt_schema(self):
+        return self.config.prompt_schema
+
+    @property
+    def project_name(self):
+        return self.config.project_name
+
+    @project_name.setter
+    def project_name(self, value):
+        self.config.project_name = value
+
+    @property
+    def project_path(self):
+        return self.config.project_path
+
+    @model_validator(mode="before")
+    @classmethod
+    def set_name_if_empty(cls, values):
+        if "name" not in values or not values["name"]:
+            values["name"] = cls.__name__
+        return values
+
+    @model_validator(mode="before")
+    @classmethod
+    def _init_with_instruction(cls, values):
+        if "instruction" in values:
+            name = values["name"]
+            i = values.pop("instruction")
+            values["node"] = ActionNode(key=name, expected_type=str, instruction=i, example="", schema="raw")
+        return values
+
+    def set_prefix(self, prefix):
+        """Set prefix for later usage"""
+        self.prefix = prefix
+        self.llm.system_prompt = prefix
+        if self.node:
+            self.node.llm = self.llm
+        return self
+
+    def __str__(self):
+        return self.__class__.__name__
+
+    def __repr__(self):
+        return self.__str__()
+
+    async def _aask(self, prompt: str, system_msgs: Optional[list[str]] = None) -> str:
+        """Append default prefix"""
+        return await self.llm.aask(prompt, system_msgs)
+
+    async def _run_action_node(self, *args, **kwargs):
+        """Run action node"""
+        msgs = args[0]
+        context = "## History Messages\n"
+        context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))])
+        return await self.node.fill(req=context, llm=self.llm)
+
+    async def run(self, *args, **kwargs):
+        """Run action"""
+        if self.node:
+            return await self._run_action_node(*args, **kwargs)
+        raise NotImplementedError("The run method should be implemented in a subclass.")
+
+    def override_context(self):
+        """Set `private_context` and `context` to the same `Context` object."""
+        if not self.private_context:
+            self.private_context = self.context
diff --git a/metagpt/actions/action_graph.py b/metagpt/actions/action_graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..893bc6d4c27c5b619b8a86797bba8bc4927879ea
--- /dev/null
+++ b/metagpt/actions/action_graph.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2024/1/30 13:52
+@Author : alexanderwu
+@File : action_graph.py
+"""
+from __future__ import annotations
+
+# from metagpt.actions.action_node import ActionNode
+
+
+class ActionGraph:
+    """ActionGraph: a directed graph to represent the dependency between actions."""
+
+    def __init__(self):
+        self.nodes = {}
+        self.edges = {}
+        self.execution_order = []
+
+    def add_node(self, node):
+        """Add a node to the graph"""
+        self.nodes[node.key] = node
+
+    def add_edge(self, from_node: "ActionNode", to_node: "ActionNode"):
+        """Add an edge to the graph"""
+        if from_node.key not in self.edges:
+            self.edges[from_node.key] = []
+        self.edges[from_node.key].append(to_node.key)
+        from_node.add_next(to_node)
+        to_node.add_prev(from_node)
+
+    def topological_sort(self):
+        """Topological sort the graph"""
+        visited = set()
+        stack = []
+
+        def visit(k):
+            if k not in visited:
+                visited.add(k)
+                if k in self.edges:
+                    for next_node in self.edges[k]:
+                        visit(next_node)
+                stack.insert(0, k)
+
+        for key in self.nodes:
+            visit(key)
+
+        self.execution_order = stack
diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..7109f287e29bb12e2ac3539b4cb8f62a38c0d18c
--- /dev/null
+++ b/metagpt/actions/action_node.py
@@ -0,0 +1,876 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/11 18:45
+@Author : alexanderwu
+@File : action_node.py
+
+NOTE: You should use typing.List instead of list to do type annotation. Because in the markdown extraction process,
+    we can use typing to extract the type of the node, but we cannot use built-in list to extract.
+"""
+import json
+import re
+import typing
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import BaseModel, Field, create_model, model_validator
+from tenacity import retry, stop_after_attempt, wait_random_exponential
+
+from metagpt.actions.action_outcls_registry import register_action_outcls
+from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
+from metagpt.exp_pool import exp_cache
+from metagpt.exp_pool.serializers import ActionNodeSerializer
+from metagpt.llm import BaseLLM
+from metagpt.logs import logger
+from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
+from metagpt.utils.common import OutputParser, general_after_log
+from metagpt.utils.human_interaction import HumanInteraction
+from metagpt.utils.sanitize import sanitize
+
+
+class ReviewMode(Enum):
+    HUMAN = "human"
+    AUTO = "auto"
+
+
+class ReviseMode(Enum):
+    HUMAN = "human"  # human revise
+    HUMAN_REVIEW = "human_review"  # human-review and auto-revise
+    AUTO = "auto"  # auto-review and auto-revise
+
+
+TAG = "CONTENT"
+
+
+class FillMode(Enum):
+    CODE_FILL = "code_fill"
+    XML_FILL = "xml_fill"
+    SINGLE_FILL = "single_fill"
+
+
+LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
+FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
+
+
+SIMPLE_TEMPLATE = """
+## context
+{context}
+
+-----
+
+## format example
+{example}
+
+## nodes: "<node>: <type> # <instruction>"
+{instruction}
+
+## constraint
+{constraint}
+
+## action
+Follow instructions of nodes, generate output and make sure it follows the format example.
+"""
+
+REVIEW_TEMPLATE = """
+## context
+Compare the key's value of nodes_output and the corresponding requirements one by one. If a key's value that does not match the requirement is found, provide the comment content on how to modify it. No output is required for matching keys.
+
+### nodes_output
+{nodes_output}
+
+-----
+
+## format example
+[{tag}]
+{{
+    "key1": "comment1",
+    "key2": "comment2",
+    "keyn": "commentn"
+}}
+[/{tag}]
+
+## nodes: "<node>: <type> # <instruction>"
+- key1: <class 'str'> # the first key name of mismatch key
+- key2: <class 'str'> # the second key name of mismatch key
+- keyn: <class 'str'> # the last key name of mismatch key
+
+## constraint
+{constraint}
+
+## action
+Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
+"""
+
+REVISE_TEMPLATE = """
+## context
+change the nodes_output key's value to meet its comment and no need to add extra comment.
+
+### nodes_output
+{nodes_output}
+
+-----
+
+## format example
+{example}
+
+## nodes: "<node>: <type> # <instruction>"
+{instruction}
+
+## constraint
+{constraint}
+
+## action
+Follow format example's {prompt_schema} format, generate output and make sure it follows the format example.
+"""
+
+
+def dict_to_markdown(d, prefix=MARKDOWN_TITLE_PREFIX, kv_sep="\n", postfix="\n"):
+    markdown_str = ""
+    for key, value in d.items():
+        markdown_str += f"{prefix}{key}{kv_sep}{value}{postfix}"
+    return markdown_str
+
+
+class ActionNode:
+    """ActionNode is a tree of nodes."""
+
+    schema: str  # raw/json/markdown, default: ""
+
+    # Action Context
+    context: str  # all the context, including all necessary info
+    llm: BaseLLM  # LLM with aask interface
+    children: dict[str, "ActionNode"]
+
+    # Action Input
+    key: str  # Product Requirement / File list / Code
+    func: typing.Callable  # the function or LLM call associated with the node
+    params: Dict[str, Type]  # dict of input parameters: keys are parameter names, values are parameter types
+    expected_type: Type  # such as str / int / float etc.
+    # context: str  # everything in the history.
+    instruction: str  # the instructions should be followed.
+    example: Any  # example for In Context-Learning.
+
+    # Action Output
+    content: str
+    instruct_content: BaseModel
+
+    # For ActionGraph
+    prevs: List["ActionNode"]  # previous nodes
+    nexts: List["ActionNode"]  # next nodes
+
+    def __init__(
+        self,
+        key: str,
+        expected_type: Type,
+        instruction: str,
+        example: Any,
+        content: str = "",
+        children: dict[str, "ActionNode"] = None,
+        schema: str = "",
+    ):
+        self.key = key
+        self.expected_type = expected_type
+        self.instruction = instruction
+        self.example = example
+        self.content = content
+        self.children = children if children is not None else {}
+        self.schema = schema
+        self.prevs = []
+        self.nexts = []
+
+    def __str__(self):
+        return (
+            f"{self.key}, {repr(self.expected_type)}, {self.instruction}, {self.example}"
+            f", {self.content}, {self.children}"
+        )
+
+    def __repr__(self):
+        return self.__str__()
+
+    def add_prev(self, node: "ActionNode"):
+        """Add a predecessor ActionNode"""
+        self.prevs.append(node)
+
+    def add_next(self, node: "ActionNode"):
+        """Add a successor ActionNode"""
+        self.nexts.append(node)
+
+    def add_child(self, node: "ActionNode"):
+        """Add a child ActionNode"""
+        self.children[node.key] = node
+
+    def get_child(self, key: str) -> Union["ActionNode", None]:
+        return self.children.get(key, None)
+
+    def add_children(self, nodes: List["ActionNode"]):
+        """Add child ActionNodes in batch"""
+        for node in nodes:
+            self.add_child(node)
+
+    @classmethod
+    def from_children(cls, key, nodes: List["ActionNode"]):
+        """Initialize directly from a list of child nodes"""
+        obj = cls(key, str, "", "")
+        obj.add_children(nodes)
+        return obj
+
+    def _get_children_mapping(self, exclude=None) -> Dict[str, Any]:
+        """Get the dict of child ActionNodes, indexed by key, supporting nested structures."""
+        exclude = exclude or []
+
+        def _get_mapping(node: "ActionNode") -> Dict[str, Any]:
+            mapping = {}
+            for key, child in node.children.items():
+                if key in exclude:
+                    continue
+                # For nested child nodes, recurse via _get_mapping
+                if child.children:
+                    mapping[key] = _get_mapping(child)
+                else:
+                    mapping[key] = (child.expected_type, Field(default=child.example, description=child.instruction))
+            return mapping
+
+        return _get_mapping(self)
+
+    def _get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]:
+        """get self key: type mapping"""
+        return {self.key: (self.expected_type, ...)}
+
+    def get_mapping(self, mode="children", exclude=None) -> Dict[str, Tuple[Type, Any]]:
+        """get key: type mapping under mode"""
+        if mode == "children" or (mode == "auto" and self.children):
+            return self._get_children_mapping(exclude=exclude)
+        return {} if exclude and self.key in exclude else self._get_self_mapping()
+
+    @classmethod
+    @register_action_outcls
+    def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
+        """Dynamically generate a pydantic v2 model class, used to verify the type correctness of results"""
+
+        def check_fields(cls, values):
+            all_fields = set(mapping.keys())
+            required_fields = set()
+            for k, v in mapping.items():
+                type_v, field_info = v
+                if ActionNode.is_optional_type(type_v):
+                    continue
+                required_fields.add(k)
+
+            missing_fields = required_fields - set(values.keys())
+            if missing_fields:
+                raise ValueError(f"Missing fields: {missing_fields}")
+
+            unrecognized_fields = set(values.keys()) - all_fields
+            if unrecognized_fields:
+                logger.warning(f"Unrecognized fields: {unrecognized_fields}")
+            return values
+
+        validators = {"check_missing_fields_validator": model_validator(mode="before")(check_fields)}
+
+        new_fields = {}
+        for field_name, field_value in mapping.items():
+            if isinstance(field_value, dict):
+                # For nested structures, create the model class recursively
+                nested_class_name = f"{class_name}_{field_name}"
+                nested_class = cls.create_model_class(nested_class_name, field_value)
+                new_fields[field_name] = (nested_class, ...)
+            else:
+                new_fields[field_name] = field_value
+
+        new_class = create_model(class_name, __validators__=validators, **new_fields)
+        return new_class
+
+    def create_class(self, mode: str = "auto", class_name: str = None, exclude=None):
+        class_name = class_name if class_name else f"{self.key}_AN"
+        mapping = self.get_mapping(mode=mode, exclude=exclude)
+        return self.create_model_class(class_name, mapping)
+
+    def _create_children_class(self, exclude=None):
+        """Generate a model_class directly from the fields the object already has"""
+        class_name = f"{self.key}_AN"
+        mapping = self._get_children_mapping(exclude=exclude)
+        return self.create_model_class(class_name, mapping)
+
+    def to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
+        """Organize the current node and its children into a dict in node: format style"""
+        nodes = self._to_dict(format_func=format_func, mode=mode, exclude=exclude)
+        if not isinstance(nodes, dict):
+            nodes = {self.key: nodes}
+        return nodes
+
+    def _to_dict(self, format_func=None, mode="auto", exclude=None) -> Dict:
+        """Organize the current node and its children into a dict in node: format style"""
+
+        # If no format function is provided, use the default one
+        if format_func is None:
+            format_func = lambda node: node.instruction
+
+        # Format the current node's value using the provided format function
+        formatted_value = format_func(self)
+
+        # Create the key-value pair for the current node
+        if (mode == "children" or mode == "auto") and self.children:
+            node_value = {}
+        else:
+            node_value = formatted_value
+
+        if mode == "root":
+            return {self.key: node_value}
+
+        # Recursively process child nodes
+        exclude = exclude or []
+        for child_key, child_node in self.children.items():
+            if child_key in exclude:
+                continue
+            # Recursively call to_dict and update the node dict
+            child_dict = child_node._to_dict(format_func, mode, exclude)
+            node_value[child_key] = child_dict
+
+        return node_value
+
+    def update_instruct_content(self, incre_data: dict[str, Any]):
+        assert self.instruct_content
+        origin_sc_dict = self.instruct_content.model_dump()
+        origin_sc_dict.update(incre_data)
+        output_class = self.create_class()
+        self.instruct_content = output_class(**origin_sc_dict)
+
+    def keys(self, mode: str = "auto") -> list:
+        if mode == "children" or (mode == "auto" and self.children):
+            keys = []
+        else:
+            keys = [self.key]
+        if mode == "root":
+            return keys
+
+        for _, child_node in self.children.items():
+            keys.append(child_node.key)
+        return keys
+
+    def compile_to(self, i: Dict, schema, kv_sep) -> str:
+        if schema == "json":
+            return json.dumps(i, indent=4, ensure_ascii=False)
+        elif schema == "markdown":
+            return dict_to_markdown(i, kv_sep=kv_sep)
+        else:
+            return str(i)
+
+    def tagging(self, text, schema, tag="") -> str:
+        if not tag:
+            return text
+        return f"[{tag}]\n{text}\n[/{tag}]"
+
+    def _compile_f(self, schema, mode, tag, format_func, kv_sep, exclude=None) -> str:
+        nodes = self.to_dict(format_func=format_func, mode=mode, exclude=exclude)
+        text = self.compile_to(nodes, schema, kv_sep)
+        return self.tagging(text, schema, tag)
+
+    def compile_instruction(self, schema="markdown", mode="children", tag="", exclude=None) -> str:
+        """compile to raw/json/markdown template with all/root/children nodes"""
+        format_func = lambda i: f"{i.expected_type} # {i.instruction}"
+        return self._compile_f(schema, mode, tag, format_func, kv_sep=": ", exclude=exclude)
+
+    def compile_example(self, schema="json", mode="children", tag="", exclude=None) -> str:
+        """compile to raw/json/markdown examples with all/root/children nodes"""
+
+        # An f-string cannot be used here: after conversion to str, json.dumps adds extra quotes,
+        # which makes the example invalid.
+        # Bad example: "File list": "['main.py', 'const.py', 'game.py']" -- note the value is a str here, not a list
+        format_func = lambda i: i.example
+        return self._compile_f(schema, mode, tag, format_func, kv_sep="\n", exclude=exclude)
+
+    def compile(self, context, schema="json", mode="children", template=SIMPLE_TEMPLATE, exclude=[]) -> str:
+        """
+        mode: all/root/children
+            mode="children": compile all child nodes into one unified template, including instruction and example
+            mode="all": NotImplemented
+            mode="root": NotImplemented
+        schema: raw/json/markdown
+            schema="raw": no compilation; context, lang_constraint, instruction
+            schema="json": compile context, example(json), instruction(markdown), constraint, action
+            schema="markdown": compile context, example(markdown), instruction(markdown), constraint, action
+        """
+        if schema == "raw":
+            return f"{context}\n\n## Actions\n{LANGUAGE_CONSTRAINT}\n{self.instruction}"
+
+        ### Generate instruction and example directly from the pydantic BaseModel, JSON only
+        # child_class = self._create_children_class()
+        # node_schema = child_class.model_json_schema()
+        # defaults = {
+        #     k: str(v)
+        #     for k, v in child_class.model_fields.items()
+        #     if k not in exclude
+        # }
+        # instruction = node_schema
+        # example = json.dumps(defaults, indent=4)
+
+        # FIXME: json instruction causes format issues, e.g.: "Project name": "web_2048 # use underscores in the project name",
+        # compile example does not support markdown for now
+        instruction = self.compile_instruction(schema="markdown", mode=mode, exclude=exclude)
+        example = self.compile_example(schema=schema, tag=TAG, mode=mode, exclude=exclude)
+        # nodes = ", ".join(self.to_dict(mode=mode).keys())
+        constraints = [LANGUAGE_CONSTRAINT, FORMAT_CONSTRAINT]
+        constraint = "\n".join(constraints)
+
+        prompt = template.format(
+            context=context,
+            example=example,
+            instruction=instruction,
+            constraint=constraint,
+        )
+        return prompt
+
+    @retry(
+        wait=wait_random_exponential(min=1, max=20),
+        stop=stop_after_attempt(6),
+        after=general_after_log(logger),
+    )
+    async def _aask_v1(
+        self,
+        prompt: str,
+        output_class_name: str,
+        output_data_mapping: dict,
+        images: Optional[Union[str, list[str]]] = None,
+        system_msgs: Optional[list[str]] = None,
+        schema="markdown",  # compatible to original format
+        timeout=USE_CONFIG_TIMEOUT,
+    ) -> (str, BaseModel):
+        """Use ActionOutput to wrap the output of aask"""
+        content = await self.llm.aask(prompt, system_msgs, images=images, timeout=timeout)
+        logger.debug(f"llm raw output:\n{content}")
+        output_class = self.create_model_class(output_class_name, output_data_mapping)
+
+        if schema == "json":
+            parsed_data = llm_output_postprocess(
+                output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
+            )
+        else:  # using markdown parser
+            parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
+
+        logger.debug(f"parsed_data:\n{parsed_data}")
+        instruct_content = output_class(**parsed_data)
+        return content, instruct_content
+
+    def get(self, key):
+        return self.instruct_content.model_dump()[key]
+
+    def set_recursive(self, name, value):
+        setattr(self, name, value)
+        for _, i in self.children.items():
+            i.set_recursive(name, value)
+
+    def set_llm(self, llm):
+        self.set_recursive("llm", llm)
+
+    def set_context(self, context):
+        self.set_recursive("context", context)
+
+    async def simple_fill(
+        self, schema, mode, images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=None
+    ):
+        prompt = self.compile(context=self.context, schema=schema, mode=mode, exclude=exclude)
+        if schema != "raw":
+            mapping = self.get_mapping(mode, exclude=exclude)
+            class_name = f"{self.key}_AN"
+            content, scontent = await self._aask_v1(
+                prompt, class_name, mapping, images=images, schema=schema, timeout=timeout
+            )
+            self.content = content
+            self.instruct_content = scontent
+        else:
+            self.content = await self.llm.aask(prompt)
+            self.instruct_content = None
+
+        return self
+
+    def get_field_name(self):
+        """
+        Get the field name from the Pydantic model associated with this ActionNode.
+        """
+        model_class = self.create_class()
+        fields = model_class.model_fields
+
+        # Assuming there's only one field in the model
+        if len(fields) == 1:
+            return next(iter(fields))
+
+        # If there are multiple fields, we might want to use self.key to find the right one
+        return self.key
+
+    def get_field_names(self):
+        """
+        Get the field names associated with this ActionNode's Pydantic model.
+        """
+        model_class = self.create_class()
+        return model_class.model_fields.keys()
+
+    def get_field_types(self):
+        """
+        Get the field types associated with this ActionNode's Pydantic model.
+        """
+        model_class = self.create_class()
+        return {field_name: field.annotation for field_name, field in model_class.model_fields.items()}
+
+    def xml_compile(self, context):
+        """
+        Compile the prompt to make it easier for the model to understand the xml format.
+        """
+        field_names = self.get_field_names()
+        # Construct the example using the field names
+        examples = []
+        for field_name in field_names:
+            examples.append(f"<{field_name}>content</{field_name}>")
+
+        # Join all examples into a single string
+        example_str = "\n".join(examples)
+        # Add the example to the context
+        context += f"""
+### Response format (must be strictly followed): All content must be enclosed in the given XML tags, ensuring each opening <tag> has a corresponding closing </tag>, with no incomplete or self-closing tags allowed.\n
+{example_str}
+"""
+        return context
+
+    async def code_fill(
+        self, context: str, function_name: Optional[str] = None, timeout: int = USE_CONFIG_TIMEOUT
+    ) -> Dict[str, str]:
+        """
+        Fill CodeBlock Using ``` ```
+        """
+        field_name = self.get_field_name()
+        prompt = context
+        content = await self.llm.aask(prompt, timeout=timeout)
+        extracted_code = sanitize(code=content, entrypoint=function_name)
+        result = {field_name: extracted_code}
+        return result
+
+    async def single_fill(self, context: str, images: Optional[Union[str, list[str]]] = None) -> Dict[str, str]:
+        field_name = self.get_field_name()
+        prompt = context
+        content = await self.llm.aask(prompt, images=images)
+        result = {field_name: content}
+        return result
+
+    async def xml_fill(self, context: str, images: Optional[Union[str, list[str]]] = None) -> Dict[str, Any]:
+        """
+        Fill context with XML tags and convert according to field types, including string, integer, boolean, list and dict types
+        """
+        field_names = self.get_field_names()
+        field_types = self.get_field_types()
+
+        extracted_data: Dict[str, Any] = {}
+        content = await self.llm.aask(context, images=images)
+
+        for field_name in field_names:
+            pattern = rf"<{field_name}>(.*?)</{field_name}>"
+            match = re.search(pattern, content, re.DOTALL)
+            if match:
+                raw_value = match.group(1).strip()
+                field_type = field_types.get(field_name)
+
+                if field_type == str:
+                    extracted_data[field_name] = raw_value
+                elif field_type == int:
+                    try:
+                        extracted_data[field_name] = int(raw_value)
+                    except ValueError:
+                        extracted_data[field_name] = 0  # or another default value
+                elif field_type == bool:
+                    extracted_data[field_name] = raw_value.lower() in ("true", "yes", "1", "on")
+                elif field_type == list:
+                    try:
+                        extracted_data[field_name] = eval(raw_value)
+                        if not isinstance(extracted_data[field_name], list):
+                            raise ValueError
+                    except:
+                        extracted_data[field_name] = []  # default to an empty list
+                elif field_type == dict:
+                    try:
+                        extracted_data[field_name] = eval(raw_value)
+                        if not isinstance(extracted_data[field_name], dict):
+                            raise ValueError
+                    except:
+                        extracted_data[field_name] = {}  # default to an empty dict
+
+        return extracted_data
+
+    @exp_cache(serializer=ActionNodeSerializer())
+    async def fill(
+        self,
+        *,
+        req,
+        llm,
+        schema="json",
+        mode="auto",
+        strgy="simple",
+        images: Optional[Union[str, list[str]]] = None,
+        timeout=USE_CONFIG_TIMEOUT,
+        exclude=[],
+        function_name: str = None,
+    ):
+        """Fill the node(s) with mode.
+
+        :param req: Everything we should know when filling node.
+        :param llm: Large Language Model with pre-defined system message.
+        :param schema: json/markdown, determine example and output format.
+         - raw: free form text
+         - json: it's easy to open source LLM with json format
+         - markdown: when generating code, markdown is always better
+        :param mode: auto/children/root
+         - auto: automated fill children's nodes and gather outputs, if no children, fill itself
+         - children: fill children's nodes and gather outputs
+         - root: fill root's node and gather output
+        :param strgy: simple/complex
+         - simple: run only once
+         - complex: run each node
+        :param images: the list of image url or base64 for gpt4-v
+        :param timeout: Timeout for llm invocation.
+        :param exclude: The keys of ActionNode to exclude.
+        :return: self
+        """
+        self.set_llm(llm)
+        self.set_context(req)
+        if self.schema:
+            schema = self.schema
+
+        if mode == FillMode.CODE_FILL.value:
+            result = await self.code_fill(self.context, function_name, timeout)
+            self.instruct_content = self.create_class()(**result)
+            return self
+
+        elif mode == FillMode.XML_FILL.value:
+            context = self.xml_compile(context=self.context)
+            result = await self.xml_fill(context, images=images)
+            self.instruct_content = self.create_class()(**result)
+            return self
+
+        elif mode == FillMode.SINGLE_FILL.value:
+            result = await self.single_fill(self.context, images=images)
+            self.instruct_content = self.create_class()(**result)
+            return self
+
+        if strgy == "simple":
+            return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
+        elif strgy == "complex":
+            # This implicitly assumes the node has children
+            tmp = {}
+            for _, i in self.children.items():
+                if exclude and i.key in exclude:
+                    continue
+                child = await i.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
+                tmp.update(child.instruct_content.model_dump())
+            cls = self._create_children_class()
+            self.instruct_content = cls(**tmp)
+            return self
+
+    async def human_review(self) -> dict[str, str]:
+        review_comments = HumanInteraction().interact_with_instruct_content(
+            instruct_content=self.instruct_content, interact_type="review"
+        )
+
+        return review_comments
+
+    def _makeup_nodes_output_with_req(self) -> dict[str, str]:
+        instruct_content_dict = self.instruct_content.model_dump()
+        nodes_output = {}
+        for key, value in instruct_content_dict.items():
+            child = self.get_child(key)
+            nodes_output[key] = {"value": value, "requirement": child.instruction if child else self.instruction}
+        return nodes_output
+
+    async def auto_review(self, template: str = REVIEW_TEMPLATE) -> dict[str, str]:
+        """use key's output value and its instruction to review the modification comment"""
+        nodes_output = self._makeup_nodes_output_with_req()
+        """nodes_output format:
+        {
+            "key": {"value": "output value", "requirement": "key instruction"}
+        }
+        """
+        if not nodes_output:
+            return dict()
+
+        prompt = template.format(
+            nodes_output=json.dumps(nodes_output, ensure_ascii=False),
+            tag=TAG,
+            constraint=FORMAT_CONSTRAINT,
+            prompt_schema="json",
+        )
+
+        content = await self.llm.aask(prompt)
+        # Extract the dict of mismatch key and its comment. Due to the mismatch keys are unknown, here use the keys
+        # of ActionNode to judge if exist in `content` and then follow the `data_mapping` method to create model class.
+        keys = self.keys()
+        include_keys = []
+        for key in keys:
+            if f'"{key}":' in content:
+                include_keys.append(key)
+        if not include_keys:
+            return dict()
+
+        exclude_keys = list(set(keys).difference(include_keys))
+        output_class_name = f"{self.key}_AN_REVIEW"
+        output_class = self.create_class(class_name=output_class_name, exclude=exclude_keys)
+        parsed_data = llm_output_postprocess(
+            output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
+        )
+        instruct_content = output_class(**parsed_data)
+        return instruct_content.model_dump()
+
+    async def simple_review(self, review_mode: ReviewMode = ReviewMode.AUTO):
+        # generate review comments
+        if review_mode == ReviewMode.HUMAN:
+            review_comments = await self.human_review()
+        else:
+            review_comments = await self.auto_review()
+
+        if not review_comments:
+            logger.warning("There are no review comments")
+        return review_comments
+
+    async def review(self, strgy: str = "simple", review_mode: ReviewMode = ReviewMode.AUTO):
+        """only give the review comment of each exist and mismatch key
+
+        :param strgy: simple/complex
+         - simple: run only once
+         - complex: run each node
+        """
+        if not hasattr(self, "llm"):
+            raise RuntimeError("use `review` after `fill`")
+        assert review_mode in ReviewMode
+        assert self.instruct_content, 'review only support with `schema != "raw"`'
+
+        if strgy == "simple":
+            review_comments = await self.simple_review(review_mode)
+        elif strgy == "complex":
+            # review each child node one-by-one
+            review_comments = {}
+            for _, child in self.children.items():
+                child_review_comment = await child.simple_review(review_mode)
+                review_comments.update(child_review_comment)
+
+        return review_comments
+
+    async def human_revise(self) -> dict[str, str]:
+        review_contents = HumanInteraction().interact_with_instruct_content(
+            instruct_content=self.instruct_content, mapping=self.get_mapping(mode="auto"), interact_type="revise"
+        )
+        # re-fill the ActionNode
+        self.update_instruct_content(review_contents)
+        return review_contents
+
+    def _makeup_nodes_output_with_comment(self, review_comments: dict[str, str]) -> dict[str, str]:
+        instruct_content_dict = self.instruct_content.model_dump()
+        nodes_output = {}
+        for key, value in instruct_content_dict.items():
+            if key in review_comments:
+                nodes_output[key] = {"value": value, "comment": review_comments[key]}
+        return nodes_output
+
+    async def auto_revise(
+        self, revise_mode: ReviseMode = ReviseMode.AUTO, template: str = REVISE_TEMPLATE
+    ) -> dict[str, str]:
+        """revise the value of incorrect keys"""
+        # generate review comments
+        if revise_mode == ReviseMode.AUTO:
+            review_comments: dict = await self.auto_review()
+        elif revise_mode == ReviseMode.HUMAN_REVIEW:
+            review_comments: dict = await self.human_review()
+
+        include_keys = list(review_comments.keys())
+
+        # generate revise content, two-steps
+        # step1, find the needed revise keys from review comments to makeup prompt template
+        nodes_output = self._makeup_nodes_output_with_comment(review_comments)
+        keys = self.keys()
+        exclude_keys = list(set(keys).difference(include_keys))
+        example = self.compile_example(schema="json", mode="auto", tag=TAG, exclude=exclude_keys)
+        instruction = self.compile_instruction(schema="markdown", mode="auto", exclude=exclude_keys)
+
+        prompt = template.format(
+            nodes_output=json.dumps(nodes_output, ensure_ascii=False),
+            example=example,
+            instruction=instruction,
+            constraint=FORMAT_CONSTRAINT,
+            prompt_schema="json",
+        )
+
+        # step2, use `_aask_v1` to get revise structure result
+ output_mapping = self.get_mapping(mode="auto", exclude=exclude_keys)
+ output_class_name = f"{self.key}_AN_REVISE"
+ content, scontent = await self._aask_v1(
+ prompt=prompt, output_class_name=output_class_name, output_data_mapping=output_mapping, schema="json"
+ )
+
+ # re-fill the ActionNode
+ sc_dict = scontent.model_dump()
+ self.update_instruct_content(sc_dict)
+ return sc_dict
+
+ async def simple_revise(self, revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
+ if revise_mode == ReviseMode.HUMAN:
+ revise_contents = await self.human_revise()
+ else:
+ revise_contents = await self.auto_revise(revise_mode)
+
+ return revise_contents
+
+ async def revise(self, strgy: str = "simple", revise_mode: ReviseMode = ReviseMode.AUTO) -> dict[str, str]:
+ """revise the content of the ActionNode and update the instruct_content
+
+ :param strgy: simple/complex
+ - simple: run only once
+ - complex: run each node
+ """
+ if not hasattr(self, "llm"):
+ raise RuntimeError("use `revise` after `fill`")
+ assert revise_mode in ReviseMode
+ assert self.instruct_content, 'revise is only supported when `schema != "raw"`'
+
+ if strgy == "simple":
+ revise_contents = await self.simple_revise(revise_mode)
+ elif strgy == "complex":
+ # revise each child node one by one
+ revise_contents = {}
+ for _, child in self.children.items():
+ child_revise_content = await child.simple_revise(revise_mode)
+ revise_contents.update(child_revise_content)
+ self.update_instruct_content(revise_contents)
+
+ return revise_contents
+
+ @classmethod
+ def from_pydantic(cls, model: Type[BaseModel], key: str = None):
+ """
+ Creates an ActionNode tree from a Pydantic model.
+
+ Args:
+ model (Type[BaseModel]): The Pydantic model to convert.
+ key (str, optional): The key of the root node; defaults to the model's class name.
+
+ Returns:
+ ActionNode: The root node of the created ActionNode tree.
+ """
+ key = key or model.__name__
+ root_node = cls(key=key, expected_type=Type[model], instruction="", example="")
+
+ for field_name, field_info in model.model_fields.items():
+ field_type = field_info.annotation
+ description = field_info.description
+ default = field_info.default
+
+ # Recursively handle nested models if needed
+ if not isinstance(field_type, typing._GenericAlias) and issubclass(field_type, BaseModel):
+ child_node = cls.from_pydantic(field_type, key=field_name)
+ else:
+ child_node = cls(key=field_name, expected_type=field_type, instruction=description, example=default)
+
+ root_node.add_child(child_node)
+
+ return root_node
+
+ @staticmethod
+ def is_optional_type(tp) -> bool:
+ """Return True if `tp` is `typing.Optional[...]`"""
+ if typing.get_origin(tp) is Union:
+ args = typing.get_args(tp)
+ non_none_types = [arg for arg in args if arg is not type(None)]
+ return len(non_none_types) == 1 and len(args) == 2
+ return False
diff --git a/metagpt/actions/action_outcls_registry.py b/metagpt/actions/action_outcls_registry.py
new file mode 100644 index 0000000000000000000000000000000000000000..6baa4cea926a80251ace6ddfc28f745482bddcdf
--- /dev/null
+++ b/metagpt/actions/action_outcls_registry.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : registry that stores the dynamic models created by ActionNode.create_model_class, so that
+# the same class name and mapping always resolve to the same class
+
+from functools import wraps
+
+action_outcls_registry = dict()
+
+
+def register_action_outcls(func):
+ """
+ `create_model` returns a different class on every call, even for the same class name and mapping.
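+ For example (illustrative names), two separate calls to create_model("task_an", field=(str, ...))
+ produce two distinct classes that compare unequal even though they are structurally identical.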
+ To make comparison possible, an outcls_id (class name plus field definition) is used to identify and reuse the same class.
+ """
+
+ @wraps(func)
+ def decorator(*args, **kwargs):
+ """
+ arr example
+ [<cls>, 'test', {'field': (str, Ellipsis)}]
+ """
+ arr = list(args) + list(kwargs.values())
+ """
+ outcls_id example
+ "<cls>_test_{'field': (str, Ellipsis)}"
+ """
+ for idx, item in enumerate(arr):
+ if isinstance(item, dict):
+ arr[idx] = dict(sorted(item.items()))
+ outcls_id = "_".join([str(i) for i in arr])
+ # eliminate the influence of typing aliases
+ outcls_id = outcls_id.replace("typing.List", "list").replace("typing.Dict", "dict")
+
+ if outcls_id in action_outcls_registry:
+ return action_outcls_registry[outcls_id]
+
+ out_cls = func(*args, **kwargs)
+ action_outcls_registry[outcls_id] = out_cls
+ return out_cls
+
+ return decorator
diff --git a/metagpt/actions/action_output.py b/metagpt/actions/action_output.py
new file mode 100644 index 0000000000000000000000000000000000000000..6be8dac50e4fd6f6bfeb37aab23b3405d6b18814
--- /dev/null
+++ b/metagpt/actions/action_output.py
@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+# coding: utf-8
+"""
+@Time : 2023/7/11 10:03
+@Author : chengmaoyu
+@File : action_output
+"""
+
+from pydantic import BaseModel
+
+
+class ActionOutput:
+ content: str
+ instruct_content: BaseModel
+
+ def __init__(self, content: str, instruct_content: BaseModel):
+ self.content = content
+ self.instruct_content = instruct_content
diff --git a/metagpt/actions/add_requirement.py b/metagpt/actions/add_requirement.py
new file mode 100644 index 0000000000000000000000000000000000000000..5d2a489b2c10dd0ad9bab1ad90aa591ad42356a7
--- /dev/null
+++ b/metagpt/actions/add_requirement.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/5/20 17:46
+@Author : alexanderwu
+@File : add_requirement.py
+"""
+from metagpt.actions import Action
+
+
+class UserRequirement(Action):
+ """User Requirement without any implementation details"""
diff --git a/metagpt/actions/analyze_requirements.py b/metagpt/actions/analyze_requirements.py
new file mode 100644 index 0000000000000000000000000000000000000000..86088d824eb251d915d4dc9ecf337d7ed2046a47
--- /dev/null
+++ b/metagpt/actions/analyze_requirements.py
@@ -0,0 +1,76 @@
+from metagpt.actions import Action
+
+ANALYZE_REQUIREMENTS = """
+# Example
+{examples}
+
+----------------
+
+# Requirements
+{requirements}
+
+# Instructions
+{instructions}
+
+# Output Format
+{output_format}
+
+Follow the instructions and output format. Do not include any additional content.
+"""
+
+EXAMPLES = """
+Example 1
+Requirements:
+创建一个贪吃蛇,只需要给出设计文档和代码
+Outputs:
+[User Restrictions] : 只需要给出设计文档和代码.
+[Language Restrictions] : The response, message and instruction must be in Chinese.
+[Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js)
+
+Example 2
+Requirements:
+Create 2048 game using Python. Do not write PRD.
+Outputs:
+[User Restrictions] : Do not write PRD.
+[Language Restrictions] : The response, message and instruction must be in English.
+[Programming Language] : Python
+
+Example 3
+Requirements:
+You must ignore create PRD and TRD. Help me write a schedule display program for the Paris Olympics.
+Outputs:
+[User Restrictions] : You must ignore create PRD and TRD.
+[Language Restrictions] : The response, message and instruction must be in English.
+[Programming Language] : HTML (*.html), CSS (*.css), and JavaScript (*.js)
+"""
+
+INSTRUCTIONS = """
+You must output in the same language as the Requirements.
+First, determine the natural language you must respond in; it should be consistent with the language used in the requirements description. If the requirements specify a particular language, follow that instruction. The default response language is English.
+Second, extract the restrictions stated in the requirements, particularly restrictions on the steps. Do not include detailed demand descriptions; focus only on the restrictions.
+Third, if the requirements describe software development, extract the programming language. If no specific programming language is required, use HTML (*.html), CSS (*.css), and JavaScript (*.js).
+
+Note:
+1. If there are no restrictions, [User Restrictions] must be ""
+2. If the requirements are not about software development, [Programming Language] must be ""
+"""
+
+OUTPUT_FORMAT = """
+[User Restrictions] : the restrictions in the requirements
+[Language Restrictions] : The response, message and instruction must be in {{language}}
+[Programming Language] : Your program must use ...
+"""
+
+
+class AnalyzeRequirementsRestrictions(Action):
+ """Analyze the restrictions and the language used in the given requirements."""
+
+ name: str = "AnalyzeRequirementsRestrictions"
+
+ async def run(self, requirements, instructions=INSTRUCTIONS, output_format=OUTPUT_FORMAT):
+ """Analyze the constraints and the language used in the requirements."""
+ prompt = ANALYZE_REQUIREMENTS.format(
+ examples=EXAMPLES, requirements=requirements, instructions=instructions, output_format=output_format
+ )
+ rsp = await self.llm.aask(prompt)
+ return rsp
diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py
new file mode 100644 index 0000000000000000000000000000000000000000..8f0f52266fd3e8c89b805799fcdc72429397c1c9
--- /dev/null
+++ b/metagpt/actions/debug_error.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/5/11 17:46
+@Author : alexanderwu
+@File : debug_error.py
+@Modified By: mashenquan, 2023/11/27.
+ 1. Divide the context into three components: legacy code, unit test code, and console log.
+ 2. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
+"""
+import re
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+from metagpt.actions.action import Action
+from metagpt.logs import logger
+from metagpt.schema import RunCodeContext, RunCodeResult
+from metagpt.utils.common import CodeParser
+from metagpt.utils.project_repo import ProjectRepo
+
+PROMPT_TEMPLATE = """
+NOTICE
+1. Role: You are a Development Engineer or QA engineer;
+2. Task: You received this message from another Development Engineer or QA engineer who ran or tested your code.
+Based on the message, first, figure out your own role, i.e. Engineer or QaEngineer,
+then rewrite the development code or the test code based on your role, the error, and the summary, such that all bugs are fixed and the code performs well.
+Attention: Use '##' to split sections, not '#', and '## <SECTION_NAME>' SHOULD BE WRITTEN BEFORE the test case or script and the triple quotes.
+The message is as follows:
+# Legacy Code
+```python
+{code}
+```
+---
+# Unit Test Code
+```python
+{test_code}
+```
+---
+# Console logs
+```text
+{logs}
+```
+---
+Now you should start rewriting the code:
+## file name of the code to rewrite: Write code with triple quotes. Do your best to implement THIS IN ONLY ONE FILE.
+""" + + +class DebugError(Action): + i_context: RunCodeContext = Field(default_factory=RunCodeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run(self, *args, **kwargs) -> str: + output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename) + if not output_doc: + return "" + output_detail = RunCodeResult.loads(output_doc.content) + pattern = r"Ran (\d+) tests in ([\d.]+)s\n\nOK" + matches = re.search(pattern, output_detail.stderr) + if matches: + return "" + + logger.info(f"Debug and rewrite {self.i_context.test_filename}") + code_doc = await self.repo.srcs.get(filename=self.i_context.code_filename) + if not code_doc: + return "" + test_doc = await self.repo.tests.get(filename=self.i_context.test_filename) + if not test_doc: + return "" + prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr) + + rsp = await self._aask(prompt) + code = CodeParser.parse_code(text=rsp) + + return code diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py new file mode 100644 index 0000000000000000000000000000000000000000..68a66d5a490ceea55b3973c3e9a747e3d05dee4b --- /dev/null +++ b/metagpt/actions/design_api.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 19:26 +@Author : alexanderwu +@File : design_api.py +@Modified By: mashenquan, 2023/11/27. + 1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name. + 2. According to the design in Section 2.2.3.5.3 of RFC 135, add incremental iteration functionality. +@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. +""" +import json +from pathlib import Path +from typing import List, Optional, Union + +from pydantic import BaseModel, Field + +from metagpt.actions import Action +from metagpt.actions.design_api_an import ( + DATA_STRUCTURES_AND_INTERFACES, + DESIGN_API_NODE, + PROGRAM_CALL_FLOW, + REFINED_DATA_STRUCTURES_AND_INTERFACES, + REFINED_DESIGN_NODE, + REFINED_PROGRAM_CALL_FLOW, +) +from metagpt.const import DATA_API_DESIGN_FILE_REPO, SEQ_FLOW_FILE_REPO +from metagpt.logs import logger +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + aread, + awrite, + rectify_pathname, + save_json_to_markdown, + to_markdown_code_block, +) +from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import DocsReporter, GalleryReporter + +NEW_REQ_TEMPLATE = """ +### Legacy Content +{old_design} + +### New Requirements +{context} +""" + + +@register_tool(include_functions=["run"]) +class WriteDesign(Action): + name: str = "" + i_context: Optional[str] = None + desc: str = ( + "Based on the PRD, think about the system design, and design the corresponding APIs, " + "data structures, library tables, processes, and paths. Please provide your design, feedback " + "clearly and in detail." 
+ )
+ repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
+ input_args: Optional[BaseModel] = Field(default=None, exclude=True)
+
+ async def run(
+ self,
+ with_messages: List[Message] = None,
+ *,
+ user_requirement: str = "",
+ prd_filename: str = "",
+ legacy_design_filename: str = "",
+ extra_info: str = "",
+ output_pathname: str = "",
+ **kwargs,
+ ) -> Union[AIMessage, str]:
+ """
+ Write a system design.
+
+ Args:
+ user_requirement (str): The user's requirements for the system design.
+ prd_filename (str, optional): The filename of the Product Requirement Document (PRD).
+ legacy_design_filename (str, optional): The filename of the legacy design document.
+ extra_info (str, optional): Additional information to be included in the system design.
+ output_pathname (str, optional): The output file path of the document.
+
+ Returns:
+ str: The file path of the generated system design.
+
+ Example:
+ # Write a new system design and save to the path name.
+ >>> user_requirement = "Write system design for a snake game"
+ >>> extra_info = "Your extra information"
+ >>> output_pathname = "snake_game/docs/system_design.json"
+ >>> action = WriteDesign()
+ >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, output_pathname=output_pathname)
+ >>> print(result)
+ System Design filename: "/absolute/path/to/snake_game/docs/system_design.json"
+
+ # Rewrite an existing system design and save to the path name.
+ >>> user_requirement = "Write system design for a snake game, include new features such as a web UI"
+ >>> extra_info = "Your extra information"
+ >>> legacy_design_filename = "/absolute/path/to/snake_game/docs/system_design.json"
+ >>> output_pathname = "/absolute/path/to/snake_game/docs/system_design_new.json"
+ >>> action = WriteDesign()
+ >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname)
+ >>> print(result)
+ System Design filename: "/absolute/path/to/snake_game/docs/system_design_new.json"
+
+ # Write a new system design with the given PRD (Product Requirement Document) and save to the path name.
+ >>> user_requirement = "Write system design for a snake game based on the PRD at /absolute/path/to/snake_game/docs/prd.json"
+ >>> extra_info = "Your extra information"
+ >>> prd_filename = "/absolute/path/to/snake_game/docs/prd.json"
+ >>> output_pathname = "/absolute/path/to/snake_game/docs/system_design.json"
+ >>> action = WriteDesign()
+ >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, output_pathname=output_pathname)
+ >>> print(result)
+ System Design filename: "/absolute/path/to/snake_game/docs/system_design.json"
+
+ # Rewrite an existing system design with the given PRD (Product Requirement Document) and save to the path name.
+ >>> user_requirement = "Write system design for a snake game, include new features such as a web UI" + >>> extra_info = "Your extra information" + >>> prd_filename = "/absolute/path/to/snake_game/docs/prd.json" + >>> legacy_design_filename = "/absolute/path/to/snake_game/docs/system_design.json" + >>> output_pathname = "/absolute/path/to/snake_game/docs/system_design_new.json" + >>> action = WriteDesign() + >>> result = await action.run(user_requirement=user_requirement, extra_info=extra_info, prd_filename=prd_filename, legacy_design_filename=legacy_design_filename, output_pathname=output_pathname) + >>> print(result) + System Design filename: "/absolute/path/to/snake_game/docs/system_design_new.json" + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, + prd_filename=prd_filename, + legacy_design_filename=legacy_design_filename, + extra_info=extra_info, + output_pathname=output_pathname, + ) + + self.input_args = with_messages[-1].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_prds = self.input_args.changed_prd_filenames + changed_system_designs = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] + + # For those PRDs and design documents that have undergone changes, regenerate the design content. + changed_files = Documents() + for filename in changed_prds: + doc = await self._update_system_design(filename=filename) + changed_files.docs[filename] = doc + + for filename in changed_system_designs: + if filename in changed_files.docs: + continue + doc = await self._update_system_design(filename=filename) + changed_files.docs[filename] = doc + if not changed_files.docs: + logger.info("Nothing has changed.") + # Wait until all files under `docs/system_designs/` are processed before sending the publish message, + # leaving room for global optimization in subsequent steps. + kvs = self.input_args.model_dump() + kvs["changed_system_design_filenames"] = [ + str(self.repo.docs.system_design.workdir / i) + for i in list(self.repo.docs.system_design.changed_files.keys()) + ] + return AIMessage( + content="Designing is complete. 
" + + "\n".join( + list(self.repo.docs.system_design.changed_files.keys()) + + list(self.repo.resources.data_api_design.changed_files.keys()) + + list(self.repo.resources.seq_flow.changed_files.keys()) + ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteDesignOutput"), + cause_by=self, + ) + + async def _new_system_design(self, context): + node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) + return node + + async def _merge(self, prd_doc, system_design_doc): + context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) + node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) + system_design_doc.content = node.instruct_content.model_dump_json() + return system_design_doc + + async def _update_system_design(self, filename) -> Document: + root_relative_path = Path(filename).relative_to(self.repo.workdir) + prd = await Document.load(filename=filename, project_path=self.repo.workdir) + old_system_design_doc = await self.repo.docs.system_design.get(root_relative_path.name) + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "design"}, "meta") + if not old_system_design_doc: + system_design = await self._new_system_design(context=prd.content) + doc = await self.repo.docs.system_design.save( + filename=prd.filename, + content=system_design.instruct_content.model_dump_json(), + dependencies={prd.root_relative_path}, + ) + else: + doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc) + await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path}) + await self._save_data_api_design(doc) + await self._save_seq_flow(doc) + md = await self.repo.resources.system_design.save_pdf(doc=doc) + await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") + return doc + + async def _save_data_api_design(self, design_doc, output_filename: Path = None): + m = json.loads(design_doc.content) + data_api_design = m.get(DATA_STRUCTURES_AND_INTERFACES.key) or m.get(REFINED_DATA_STRUCTURES_AND_INTERFACES.key) + if not data_api_design: + return + pathname = output_filename or self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path( + design_doc.filename + ).with_suffix("") + await self._save_mermaid_file(data_api_design, pathname) + logger.info(f"Save class view to {str(pathname)}") + + async def _save_seq_flow(self, design_doc, output_filename: Path = None): + m = json.loads(design_doc.content) + seq_flow = m.get(PROGRAM_CALL_FLOW.key) or m.get(REFINED_PROGRAM_CALL_FLOW.key) + if not seq_flow: + return + pathname = output_filename or self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path( + design_doc.filename + ).with_suffix("") + await self._save_mermaid_file(seq_flow, pathname) + logger.info(f"Saving sequence flow to {str(pathname)}") + + async def _save_mermaid_file(self, data: str, pathname: Path): + pathname.parent.mkdir(parents=True, exist_ok=True) + await mermaid_to_file(self.config.mermaid.engine, data, pathname) + image_path = pathname.parent / f"{pathname.name}.svg" + if image_path.exists(): + await GalleryReporter().async_report(image_path, "path") + + async def _execute_api( + self, + user_requirement: str = "", + prd_filename: str = "", + legacy_design_filename: str = "", + extra_info: str = "", + output_pathname: str = "", + ) -> str: + prd_content = "" + if prd_filename: + prd_filename = rectify_pathname(path=prd_filename, default_filename="prd.json") 
+ prd_content = await aread(filename=prd_filename)
+ context = "### User Requirements\n{user_requirement}\n### Extra_info\n{extra_info}\n### PRD\n{prd}\n".format(
+ user_requirement=to_markdown_code_block(user_requirement),
+ extra_info=to_markdown_code_block(extra_info),
+ prd=to_markdown_code_block(prd_content),
+ )
+ async with DocsReporter(enable_llm_stream=True) as reporter:
+ await reporter.async_report({"type": "design"}, "meta")
+ if not legacy_design_filename:
+ node = await self._new_system_design(context=context)
+ design = Document(content=node.instruct_content.model_dump_json())
+ else:
+ old_design_content = await aread(filename=legacy_design_filename)
+ design = await self._merge(
+ prd_doc=Document(content=context), system_design_doc=Document(content=old_design_content)
+ )
+
+ if not output_pathname:
+ output_pathname = self.config.workspace.path / "docs" / "system_design.json"
+ elif not Path(output_pathname).is_absolute():
+ output_pathname = self.config.workspace.path / output_pathname
+ output_pathname = rectify_pathname(path=output_pathname, default_filename="system_design.json")
+ await awrite(filename=output_pathname, data=design.content)
+ output_filename = output_pathname.parent / f"{output_pathname.stem}-class-diagram"
+ await self._save_data_api_design(design_doc=design, output_filename=output_filename)
+ output_filename = output_pathname.parent / f"{output_pathname.stem}-sequence-diagram"
+ await self._save_seq_flow(design_doc=design, output_filename=output_filename)
+ md_output_filename = output_pathname.with_suffix(".md")
+ await save_json_to_markdown(content=design.content, output_filename=md_output_filename)
+ await reporter.async_report(md_output_filename, "path")
+ return f'System Design filename: "{str(output_pathname)}". \n The System Design has been completed.'
diff --git a/metagpt/actions/design_api_an.py b/metagpt/actions/design_api_an.py
new file mode 100644 index 0000000000000000000000000000000000000000..4707b53536ac57a2c4d5504468d9844d8a9a06ca
--- /dev/null
+++ b/metagpt/actions/design_api_an.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/12/12 22:24
+@Author : alexanderwu
+@File : design_api_an.py
+"""
+from typing import List, Optional
+
+from metagpt.actions.action_node import ActionNode
+from metagpt.utils.mermaid import MMC1, MMC2
+
+IMPLEMENTATION_APPROACH = ActionNode(
+ key="Implementation approach",
+ expected_type=str,
+ instruction="Analyze the difficult points of the requirements, select the appropriate open-source framework.",
+ example="We will ...",
+)
+
+REFINED_IMPLEMENTATION_APPROACH = ActionNode(
+ key="Refined Implementation Approach",
+ expected_type=str,
+ instruction="Update and extend the original implementation approach to reflect the evolving challenges and "
+ "requirements due to incremental development. Outline the steps involved in the implementation process with the "
+ "detailed strategies.",
+ example="We will refine ...",
+)
+
+PROJECT_NAME = ActionNode(
+ key="Project name", expected_type=str, instruction="The project name with underline", example="game_2048"
+)
+
+FILE_LIST = ActionNode(
+ key="File list",
+ expected_type=List[str],
+ instruction="Only need relative paths. 
Succinctly designate the correct entry file for your project based on the programming language: use main.js for JavaScript, main.py for Python, and so on for other languages.",
+ example=["a.js", "b.py", "c.css", "d.html"],
+)
+
+REFINED_FILE_LIST = ActionNode(
+ key="Refined File list",
+ expected_type=List[str],
+ instruction="Update and expand the original file list including only relative paths. Up to 2 files can be added. "
+ "Ensure that the refined file list reflects the evolving structure of the project.",
+ example=["main.py", "game.py", "new_feature.py"],
+)
+
+# Optional, because class diagrams are reproduced with low success in non-Python projects.
+DATA_STRUCTURES_AND_INTERFACES = ActionNode(
+ key="Data structures and interfaces",
+ expected_type=Optional[str],
+ instruction="Use mermaid classDiagram code syntax, including classes, method(__init__ etc.) and functions with type"
+ " annotations, CLEARLY MARK the RELATIONSHIPS between classes, and comply with PEP8 standards. "
+ "The data structures SHOULD BE VERY DETAILED and the API should be comprehensive with a complete design.",
+ example=MMC1,
+)
+
+REFINED_DATA_STRUCTURES_AND_INTERFACES = ActionNode(
+ key="Refined Data structures and interfaces",
+ expected_type=str,
+ instruction="Update and extend the existing mermaid classDiagram code syntax to incorporate new classes, "
+ "methods (including __init__), and functions with precise type annotations. Delineate additional "
+ "relationships between classes, ensuring clarity and adherence to PEP8 standards. "
+ "Retain content that is not related to incremental development but important for consistency and clarity.",
+ example=MMC1,
+)
+
+PROGRAM_CALL_FLOW = ActionNode(
+ key="Program call flow",
+ expected_type=Optional[str],
+ instruction="Use sequenceDiagram code syntax, COMPLETE and VERY DETAILED, using CLASSES AND API DEFINED ABOVE "
+ "accurately, covering the CRUD AND INIT of each object, SYNTAX MUST BE CORRECT.",
+ example=MMC2,
+)
+
+REFINED_PROGRAM_CALL_FLOW = ActionNode(
+ key="Refined Program call flow",
+ expected_type=str,
+ instruction="Extend the existing sequenceDiagram code syntax with detailed information, accurately covering the "
+ "CRUD and initialization of each object. Ensure correct syntax usage and reflect the incremental changes introduced "
+ "in the classes and API defined above. 
" + "Retain content that is not related to incremental development but important for consistency and clarity.", + example=MMC2, +) + +ANYTHING_UNCLEAR = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention unclear project aspects, then try to clarify it.", + example="Clarification needed on third-party API integration, ...", +) + +NODES = [ + IMPLEMENTATION_APPROACH, + # PROJECT_NAME, + FILE_LIST, + DATA_STRUCTURES_AND_INTERFACES, + PROGRAM_CALL_FLOW, + ANYTHING_UNCLEAR, +] + +REFINED_NODES = [ + REFINED_IMPLEMENTATION_APPROACH, + REFINED_FILE_LIST, + REFINED_DATA_STRUCTURES_AND_INTERFACES, + REFINED_PROGRAM_CALL_FLOW, + ANYTHING_UNCLEAR, +] + +DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES) +REFINED_DESIGN_NODE = ActionNode.from_children("RefinedDesignAPI", REFINED_NODES) diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd01a4c321f381286c1b95f0d153245d23c7e66 --- /dev/null +++ b/metagpt/actions/design_api_review.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 19:31 +@Author : alexanderwu +@File : design_api_review.py +""" + +from typing import Optional + +from metagpt.actions.action import Action + + +class DesignReview(Action): + name: str = "DesignReview" + i_context: Optional[str] = None + + async def run(self, prd, api_design): + prompt = ( + f"Here is the Product Requirement Document (PRD):\n\n{prd}\n\nHere is the list of APIs designed " + f"based on this PRD:\n\n{api_design}\n\nPlease review whether this API design meets the requirements" + f" of the PRD, and whether it complies with good design practices." + ) + + api_review = await self._aask(prompt) + return api_review diff --git a/metagpt/actions/di/__init__.py b/metagpt/actions/di/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/actions/di/__pycache__/__init__.cpython-310.pyc b/metagpt/actions/di/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..208fac966604e5b5fcdcb03c2a04b448cd40efe1 Binary files /dev/null and b/metagpt/actions/di/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/actions/di/__pycache__/__init__.cpython-39.pyc b/metagpt/actions/di/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bebaa1313ebc897211ec25a61982412bfbebe5e Binary files /dev/null and b/metagpt/actions/di/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/actions/di/__pycache__/ask_review.cpython-310.pyc b/metagpt/actions/di/__pycache__/ask_review.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd52c441984b0356de7ca77e0110493c46959c94 Binary files /dev/null and b/metagpt/actions/di/__pycache__/ask_review.cpython-310.pyc differ diff --git a/metagpt/actions/di/__pycache__/ask_review.cpython-39.pyc b/metagpt/actions/di/__pycache__/ask_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d40598441b94d4853d3518295c477a87c185def6 Binary files /dev/null and b/metagpt/actions/di/__pycache__/ask_review.cpython-39.pyc differ diff --git a/metagpt/actions/di/__pycache__/execute_nb_code.cpython-310.pyc b/metagpt/actions/di/__pycache__/execute_nb_code.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9eddd7b4686b4ac290a1d3b5682f3fd5139e3532 Binary files /dev/null and b/metagpt/actions/di/__pycache__/execute_nb_code.cpython-310.pyc differ diff --git a/metagpt/actions/di/__pycache__/execute_nb_code.cpython-39.pyc b/metagpt/actions/di/__pycache__/execute_nb_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c766f48d35dfc7e2488a9419993e3219eff78e9 Binary files /dev/null and b/metagpt/actions/di/__pycache__/execute_nb_code.cpython-39.pyc differ diff --git a/metagpt/actions/di/__pycache__/run_command.cpython-310.pyc b/metagpt/actions/di/__pycache__/run_command.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b45a98d727054e4db7fb8aa28b9ab6a72d895c0 Binary files /dev/null and b/metagpt/actions/di/__pycache__/run_command.cpython-310.pyc differ diff --git a/metagpt/actions/di/__pycache__/run_command.cpython-39.pyc b/metagpt/actions/di/__pycache__/run_command.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c09532d15a6887a2b00609c3992fec7d926101d Binary files /dev/null and b/metagpt/actions/di/__pycache__/run_command.cpython-39.pyc differ diff --git a/metagpt/actions/di/__pycache__/write_analysis_code.cpython-310.pyc b/metagpt/actions/di/__pycache__/write_analysis_code.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1db1289c1f16fc4a61214fb6c42a36b310cc446 Binary files /dev/null and b/metagpt/actions/di/__pycache__/write_analysis_code.cpython-310.pyc differ diff --git a/metagpt/actions/di/__pycache__/write_analysis_code.cpython-39.pyc b/metagpt/actions/di/__pycache__/write_analysis_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685709846de2c8c1634c4a2cb396b01d7da478ea Binary files /dev/null and b/metagpt/actions/di/__pycache__/write_analysis_code.cpython-39.pyc differ diff --git a/metagpt/actions/di/__pycache__/write_plan.cpython-310.pyc b/metagpt/actions/di/__pycache__/write_plan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..748a74a300f5774f44aa0a036aeb46f006874575 Binary files /dev/null and b/metagpt/actions/di/__pycache__/write_plan.cpython-310.pyc differ diff --git a/metagpt/actions/di/__pycache__/write_plan.cpython-39.pyc b/metagpt/actions/di/__pycache__/write_plan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8a2fd597baa074e82cd02d93a31394ad735f73d Binary files /dev/null and b/metagpt/actions/di/__pycache__/write_plan.cpython-39.pyc differ diff --git a/metagpt/actions/di/ask_review.py b/metagpt/actions/di/ask_review.py new file mode 100644 index 0000000000000000000000000000000000000000..ecbbd992ea3a12a417d57d9da121a71202076ce4 --- /dev/null +++ b/metagpt/actions/di/ask_review.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Tuple + +from metagpt.actions import Action +from metagpt.logs import get_human_input, logger +from metagpt.schema import Message, Plan + + +class ReviewConst: + TASK_REVIEW_TRIGGER = "task" + CODE_REVIEW_TRIGGER = "code" + CONTINUE_WORDS = ["confirm", "continue", "c", "yes", "y"] + CHANGE_WORDS = ["change"] + EXIT_WORDS = ["exit"] + TASK_REVIEW_INSTRUCTION = ( + f"If you want to change, add, delete a task or merge tasks in the plan, say '{CHANGE_WORDS[0]} task task_id or current task, ... 
(things to change)' " + f"If you confirm the output from the current task and wish to continue, type: {CONTINUE_WORDS[0]}" + ) + CODE_REVIEW_INSTRUCTION = ( + f"If you want the codes to be rewritten, say '{CHANGE_WORDS[0]} ... (your change advice)' " + f"If you want to leave it as is, type: {CONTINUE_WORDS[0]} or {CONTINUE_WORDS[1]}" + ) + EXIT_INSTRUCTION = f"If you want to terminate the process, type: {EXIT_WORDS[0]}" + + +class AskReview(Action): + async def run( + self, context: list[Message] = [], plan: Plan = None, trigger: str = ReviewConst.TASK_REVIEW_TRIGGER + ) -> Tuple[str, bool]: + if plan: + logger.info("Current overall plan:") + logger.info( + "\n".join( + [f"{task.task_id}: {task.instruction}, is_finished: {task.is_finished}" for task in plan.tasks] + ) + ) + + logger.info("Most recent context:") + latest_action = context[-1].cause_by if context and context[-1].cause_by else "" + review_instruction = ( + ReviewConst.TASK_REVIEW_INSTRUCTION + if trigger == ReviewConst.TASK_REVIEW_TRIGGER + else ReviewConst.CODE_REVIEW_INSTRUCTION + ) + prompt = ( + f"This is a <{trigger}> review. Please review output from {latest_action}\n" + f"{review_instruction}\n" + f"{ReviewConst.EXIT_INSTRUCTION}\n" + "Please type your review below:\n" + ) + + rsp = await get_human_input(prompt) + + if rsp.lower() in ReviewConst.EXIT_WORDS: + exit() + + # Confirmation can be one of "confirm", "continue", "c", "yes", "y" exactly, or sentences containing "confirm". + # One could say "confirm this task, but change the next task to ..." + confirmed = rsp.lower() in ReviewConst.CONTINUE_WORDS or ReviewConst.CONTINUE_WORDS[0] in rsp.lower() + + return rsp, confirmed diff --git a/metagpt/actions/di/execute_nb_code.py b/metagpt/actions/di/execute_nb_code.py new file mode 100644 index 0000000000000000000000000000000000000000..01019b49312d161dd08066d4faa5dd8b2e9733b5 --- /dev/null +++ b/metagpt/actions/di/execute_nb_code.py @@ -0,0 +1,328 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/17 14:22:15 +@Author : orange-crow +@File : execute_nb_code.py +""" +from __future__ import annotations + +import asyncio +import base64 +import re +from typing import Literal, Tuple + +import nbformat +from nbclient import NotebookClient +from nbclient.exceptions import CellExecutionComplete, CellTimeoutError, DeadKernelError +from nbclient.util import ensure_async +from nbformat import NotebookNode +from nbformat.v4 import new_code_cell, new_markdown_cell, new_output, output_from_msg +from rich.box import MINIMAL +from rich.console import Console, Group +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.syntax import Syntax + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.utils.report import NotebookReporter + +INSTALL_KEEPLEN = 500 +INI_CODE = """import warnings +import logging + +root_logger = logging.getLogger() +root_logger.setLevel(logging.ERROR) +warnings.filterwarnings('ignore')""" + + +class RealtimeOutputNotebookClient(NotebookClient): + """Realtime output of Notebook execution.""" + + def __init__(self, *args, notebook_reporter=None, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.notebook_reporter = notebook_reporter or NotebookReporter() + + async def _async_poll_output_msg(self, parent_msg_id: str, cell: NotebookNode, cell_index: int) -> None: + """Implement a feature to enable sending messages.""" + assert self.kc is not None + while True: + msg = await ensure_async(self.kc.iopub_channel.get_msg(timeout=None)) 
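+ # forward every iopub message to the reporter first, so partial stdout/stderr
+ # streams to the caller in real time, before the completion check below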
+ await self._send_msg(msg) + + if msg["parent_header"].get("msg_id") == parent_msg_id: + try: + # Will raise CellExecutionComplete when completed + self.process_message(msg, cell, cell_index) + except CellExecutionComplete: + return + + async def _send_msg(self, msg: dict): + msg_type = msg.get("header", {}).get("msg_type") + if msg_type not in ["stream", "error", "execute_result"]: + return + + await self.notebook_reporter.async_report(output_from_msg(msg), "content") + + +class ExecuteNbCode(Action): + """execute notebook code block, return result to llm, and display it.""" + + nb: NotebookNode + nb_client: RealtimeOutputNotebookClient = None + console: Console + interaction: str + timeout: int = 600 + + def __init__(self, nb=nbformat.v4.new_notebook(), timeout=600): + super().__init__( + nb=nb, + timeout=timeout, + console=Console(), + interaction=("ipython" if self.is_ipython() else "terminal"), + ) + self.reporter = NotebookReporter() + self.set_nb_client() + self.init_called = False + + async def init_code(self): + if not self.init_called: + await self.run(INI_CODE) + self.init_called = True + + def set_nb_client(self): + self.nb_client = RealtimeOutputNotebookClient( + self.nb, + timeout=self.timeout, + resources={"metadata": {"path": self.config.workspace.path}}, + notebook_reporter=self.reporter, + coalesce_streams=True, + ) + + async def build(self): + if self.nb_client.kc is None or not await self.nb_client.kc.is_alive(): + self.nb_client.create_kernel_manager() + self.nb_client.start_new_kernel() + self.nb_client.start_new_kernel_client() + + async def terminate(self): + """kill NotebookClient""" + if self.nb_client.km is not None and await self.nb_client.km.is_alive(): + await self.nb_client.km.shutdown_kernel(now=True) + await self.nb_client.km.cleanup_resources() + + channels = [ + self.nb_client.kc.stdin_channel, # The channel for handling standard input to the kernel. + self.nb_client.kc.hb_channel, # The channel for heartbeat communication between the kernel and client. + self.nb_client.kc.control_channel, # The channel for controlling the kernel. 
+ ]
+
+ # Stop all the running channels for this kernel
+ for channel in channels:
+ if channel.is_alive():
+ channel.stop()
+
+ self.nb_client.kc = None
+ self.nb_client.km = None
+
+ async def reset(self):
+ """reset NotebookClient"""
+ await self.terminate()
+
+ # sleep 1s to wait for the kernel to be cleaned up completely
+ await asyncio.sleep(1)
+ await self.build()
+ self.set_nb_client()
+
+ def add_code_cell(self, code: str):
+ self.nb.cells.append(new_code_cell(source=code))
+
+ def add_markdown_cell(self, markdown: str):
+ self.nb.cells.append(new_markdown_cell(source=markdown))
+
+ def _display(self, code: str, language: Literal["python", "markdown"] = "python"):
+ if language == "python":
+ code = Syntax(code, "python", theme="paraiso-dark", line_numbers=True)
+ self.console.print(code)
+ elif language == "markdown":
+ display_markdown(code)
+ else:
+ raise ValueError(f"Only python and markdown are supported, but got {language}")
+
+ def add_output_to_cell(self, cell: NotebookNode, output: str):
+ """add outputs of code execution to a notebook cell."""
+ if "outputs" not in cell:
+ cell["outputs"] = []
+ cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))
+
+ def parse_outputs(self, outputs: list[dict], keep_len: int = 5000) -> Tuple[bool, str]:
+ """Parses the outputs received from notebook execution."""
+ assert isinstance(outputs, list)
+ parsed_output, is_success = [], True
+ for i, output in enumerate(outputs):
+ output_text = ""
+ if output["output_type"] == "stream" and not any(
+ tag in output["text"]
+ for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt", "DEBUG"]
+ ):
+ output_text = output["text"]
+ elif output["output_type"] == "display_data":
+ if "image/png" in output["data"]:
+ self.show_bytes_figure(output["data"]["image/png"], self.interaction)
+ else:
+ logger.info(
+ f"{i}th output['data'] from nbclient outputs doesn't have image/png, continue to next output ..."
+ )
+ elif output["output_type"] == "execute_result":
+ output_text = output["data"]["text/plain"]
+ elif output["output_type"] == "error":
+ output_text, is_success = "\n".join(output["traceback"]), False
+
+ # handle coroutines that are not executed asynchronously
+ if output_text.strip().startswith("<coroutine object"):
+ is_success = False
+ if "<coroutine object" not in output_text:
+ # keep the head of normal output and the tail of error output, where the useful information lives
+ output_text = output_text[:keep_len] if is_success else output_text[-keep_len:]
+
+ parsed_output.append(output_text)
+ return is_success, ",".join(parsed_output)
+
+ def show_bytes_figure(self, image_base64: str, interaction_type: Literal["ipython", None]):
+ image_bytes = base64.b64decode(image_base64)
+ if interaction_type == "ipython":
+ from IPython.display import Image, display
+
+ display(Image(data=image_bytes))
+ else:
+ import io
+
+ from PIL import Image
+
+ image = Image.open(io.BytesIO(image_bytes))
+ image.show()
+
+ def is_ipython(self) -> bool:
+ try:
+ # when running inside a Jupyter Notebook, the __file__ variable does not exist
+ from IPython import get_ipython
+
+ if get_ipython() is not None and "IPKernelApp" in get_ipython().config:
+ return True
+ else:
+ return False
+ except NameError:
+ return False
+
+ async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]:
+ """Run the cell with a timeout;
+ returns the success or failure of the cell execution, and an optional error message.
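+ On CellTimeoutError the kernel is interrupted; on DeadKernelError the client is fully reset;
+ any other exception falls back to parsing whatever outputs the cell produced.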
+ """ + await self.reporter.async_report(cell, "content") + + try: + await self.nb_client.async_execute_cell(cell, cell_index) + return self.parse_outputs(self.nb.cells[-1].outputs) + except CellTimeoutError: + assert self.nb_client.km is not None + await self.nb_client.km.interrupt_kernel() + await asyncio.sleep(1) + error_msg = "Cell execution timed out: Execution exceeded the time limit and was stopped; consider optimizing your code for better performance." + return False, error_msg + except DeadKernelError: + await self.reset() + return False, "DeadKernelError" + except Exception: + return self.parse_outputs(self.nb.cells[-1].outputs) + + async def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]: + """ + return the output of code execution, and a success indicator (bool) of code execution. + """ + self._display(code, language) + + async with self.reporter: + if language == "python": + # add code to the notebook + self.add_code_cell(code=code) + + # build code executor + await self.build() + + # run code + cell_index = len(self.nb.cells) - 1 + success, outputs = await self.run_cell(self.nb.cells[-1], cell_index) + + if "!pip" in code: + success = False + outputs = outputs[-INSTALL_KEEPLEN:] + elif "git clone" in code: + outputs = outputs[:INSTALL_KEEPLEN] + "..." + outputs[-INSTALL_KEEPLEN:] + + elif language == "markdown": + # add markdown content to markdown cell in a notebook. + self.add_markdown_cell(code) + # return True, beacuse there is no execution failure for markdown cell. + outputs, success = code, True + else: + raise ValueError(f"Only support for language: python, markdown, but got {language}, ") + + file_path = self.config.workspace.path / "code.ipynb" + nbformat.write(self.nb, file_path) + await self.reporter.async_report(file_path, "path") + + return outputs, success + + +def remove_log_and_warning_lines(input_str: str) -> str: + delete_lines = ["[warning]", "warning:", "[cv]", "[info]"] + result = "\n".join( + [line for line in input_str.split("\n") if not any(dl in line.lower() for dl in delete_lines)] + ).strip() + return result + + +def remove_escape_and_color_codes(input_str: str): + # 使用正则表达式去除jupyter notebook输出结果中的转义字符和颜色代码 + # Use regular expressions to get rid of escape characters and color codes in jupyter notebook output. + pattern = re.compile(r"\x1b\[[0-9;]*[mK]") + result = pattern.sub("", input_str) + return result + + +def display_markdown(content: str): + # Use regular expressions to match blocks of code one by one. + matches = re.finditer(r"```(.+?)```", content, re.DOTALL) + start_index = 0 + content_panels = [] + # Set the text background color and text color. + style = "black on white" + # Print the matching text and code one by one. + for match in matches: + text_content = content[start_index : match.start()].strip() + code_content = match.group(0).strip()[3:-3] # Remove triple backticks + + if text_content: + content_panels.append(Panel(Markdown(text_content), style=style, box=MINIMAL)) + + if code_content: + content_panels.append(Panel(Markdown(f"```{code_content}"), style=style, box=MINIMAL)) + start_index = match.end() + + # Print remaining text (if any). + remaining_text = content[start_index:].strip() + if remaining_text: + content_panels.append(Panel(Markdown(remaining_text), style=style, box=MINIMAL)) + + # Display all panels in Live mode. 
+ with Live(auto_refresh=False, console=Console(), vertical_overflow="visible") as live: + live.update(Group(*content_panels)) + live.refresh() diff --git a/metagpt/actions/di/run_command.py b/metagpt/actions/di/run_command.py new file mode 100644 index 0000000000000000000000000000000000000000..510bb5d9201fcb6ae3b88786f9c57ad53d5ec7e8 --- /dev/null +++ b/metagpt/actions/di/run_command.py @@ -0,0 +1,5 @@ +from metagpt.actions import Action + + +class RunCommand(Action): + """A dummy RunCommand action used as a symbol only""" diff --git a/metagpt/actions/di/write_analysis_code.py b/metagpt/actions/di/write_analysis_code.py new file mode 100644 index 0000000000000000000000000000000000000000..80e2c5ddce3c0871c052e69f9acb705d930bd5f1 --- /dev/null +++ b/metagpt/actions/di/write_analysis_code.py @@ -0,0 +1,74 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/20 13:19:39 +@Author : orange-crow +@File : write_analysis_code.py +""" +from __future__ import annotations + +from metagpt.actions import Action +from metagpt.prompts.di.write_analysis_code import ( + CHECK_DATA_PROMPT, + DEBUG_REFLECTION_EXAMPLE, + INTERPRETER_SYSTEM_MSG, + REFLECTION_PROMPT, + REFLECTION_SYSTEM_MSG, + STRUCTUAL_PROMPT, +) +from metagpt.schema import Message, Plan +from metagpt.utils.common import CodeParser, remove_comments + + +class WriteAnalysisCode(Action): + async def _debug_with_reflection(self, context: list[Message], working_memory: list[Message]): + reflection_prompt = REFLECTION_PROMPT.format( + debug_example=DEBUG_REFLECTION_EXAMPLE, + context=context, + previous_impl=working_memory, + ) + + rsp = await self._aask(reflection_prompt, system_msgs=[REFLECTION_SYSTEM_MSG]) + # reflection = json.loads(CodeParser.parse_code(block=None, text=rsp)) + # return reflection["improved_impl"] + reflection = CodeParser.parse_code(block=None, text=rsp) + return reflection + + async def run( + self, + user_requirement: str, + plan_status: str = "", + tool_info: str = "", + working_memory: list[Message] = None, + use_reflection: bool = False, + memory: list[Message] = None, + **kwargs, + ) -> str: + structual_prompt = STRUCTUAL_PROMPT.format( + user_requirement=user_requirement, + plan_status=plan_status, + tool_info=tool_info, + ) + + working_memory = working_memory or [] + memory = memory or [] + context = self.llm.format_msg(memory + [Message(content=structual_prompt, role="user")] + working_memory) + + # LLM call + if use_reflection: + code = await self._debug_with_reflection(context=context, working_memory=working_memory) + else: + rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs) + code = CodeParser.parse_code(text=rsp, lang="python") + + return code + + +class CheckData(Action): + async def run(self, plan: Plan) -> dict: + finished_tasks = plan.get_finished_tasks() + code_written = [remove_comments(task.code) for task in finished_tasks] + code_written = "\n\n".join(code_written) + prompt = CHECK_DATA_PROMPT.format(code_written=code_written) + rsp = await self._aask(prompt) + code = CodeParser.parse_code(text=rsp) + return code diff --git a/metagpt/actions/di/write_plan.py b/metagpt/actions/di/write_plan.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2e517121b03a72ac4ea743cbae414608b0ab09 --- /dev/null +++ b/metagpt/actions/di/write_plan.py @@ -0,0 +1,88 @@ +# -*- encoding: utf-8 -*- +""" +@Date : 2023/11/20 11:24:03 +@Author : orange-crow +@File : plan.py +""" +from __future__ import annotations + +import json +from copy import deepcopy +from typing import 
Tuple
+
+from metagpt.actions import Action
+from metagpt.logs import logger
+from metagpt.schema import Message, Plan, Task
+from metagpt.strategy.task_type import TaskType
+from metagpt.utils.common import CodeParser
+
+PROMPT_TEMPLATE: str = """
+# Context:
+{context}
+# Available Task Types:
+{task_type_desc}
+# Task:
+Based on the context, write a plan or modify an existing plan of what you should do to achieve the goal. A plan consists of one to {max_tasks} tasks.
+If you are modifying an existing plan, carefully follow the instruction, don't make unnecessary changes. Give the whole plan unless instructed to modify only one task of the plan.
+If you encounter errors on the current task, revise and output the current single task only.
+Output a list of JSON objects following the format:
+```json
+[
+ {{
+ "task_id": str = "unique identifier for a task in plan, can be an ordinal",
+ "dependent_task_ids": list[str] = "ids of tasks prerequisite to this task",
+ "instruction": "what you should do in this task, one short phrase or sentence.",
+ "task_type": "type of this task, should be one of Available Task Types.",
+ }},
+ ...
+]
+```
+"""
+
+
+class WritePlan(Action):
+ async def run(self, context: list[Message], max_tasks: int = 5) -> str:
+ task_type_desc = "\n".join([f"- **{tt.type_name}**: {tt.value.desc}" for tt in TaskType])
+ prompt = PROMPT_TEMPLATE.format(
+ context="\n".join([str(ct) for ct in context]), max_tasks=max_tasks, task_type_desc=task_type_desc
+ )
+ rsp = await self._aask(prompt)
+ rsp = CodeParser.parse_code(text=rsp)
+ return rsp
+
+
+def update_plan_from_rsp(rsp: str, current_plan: Plan):
+ rsp = json.loads(rsp)
+ tasks = [Task(**task_config) for task_config in rsp]
+
+ if len(tasks) == 1 or tasks[0].dependent_task_ids:
+ if tasks[0].dependent_task_ids and len(tasks) > 1:
+ # a non-empty tasks[0].dependent_task_ids means the generated tasks are not a complete plan,
+ # because they depend on tasks in the current plan; in this case, only one task can be updated at a time
+ logger.warning(
+ "Current plan will take only the first generated task if the generated tasks are not a complete plan"
+ )
+ # handle a single task
+ if current_plan.has_task_id(tasks[0].task_id):
+ # replace an existing task
+ current_plan.replace_task(
+ tasks[0].task_id, tasks[0].dependent_task_ids, tasks[0].instruction, tasks[0].assignee
+ )
+ else:
+ # append one task
+ current_plan.append_task(
+ tasks[0].task_id, tasks[0].dependent_task_ids, tasks[0].instruction, tasks[0].assignee
+ )
+
+ else:
+ # add tasks in general
+ current_plan.add_tasks(tasks)
+
+
+def precheck_update_plan_from_rsp(rsp: str, current_plan: Plan) -> Tuple[bool, str]:
+ temp_plan = deepcopy(current_plan)
+ try:
+ update_plan_from_rsp(rsp, temp_plan)
+ return True, ""
+ except Exception as e:
+ return False, str(e)
diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py
new file mode 100644 index 0000000000000000000000000000000000000000..1cc3bd699a540d4e4a83108910523d04dd1ba73c
--- /dev/null
+++ b/metagpt/actions/execute_task.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time : 2023/9/13 12:26
+@Author : femto Zheng
+@File : execute_task.py
+"""
+
+
+from metagpt.actions import Action
+from metagpt.schema import Message
+
+
+class ExecuteTask(Action):
+ name: str = "ExecuteTask"
+ i_context: list[Message] = []
+
+ async def run(self, *args, **kwargs):
+ pass
diff --git a/metagpt/actions/extract_readme.py b/metagpt/actions/extract_readme.py
new file mode 100644 index
0000000000000000000000000000000000000000..69f5503a9aeeb0f168c99493a295daf1aa3103f9
--- /dev/null
+++ b/metagpt/actions/extract_readme.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Module Description: This script defines the ExtractReadMe class, an action that learns from the contents of
+ a README.md file.
+Author: mashenquan
+Date: 2024-3-20
+"""
+from pathlib import Path
+from typing import Optional
+
+from pydantic import Field
+
+from metagpt.actions import Action
+from metagpt.const import GRAPH_REPO_FILE_REPO
+from metagpt.schema import Message
+from metagpt.utils.common import aread
+from metagpt.utils.di_graph_repository import DiGraphRepository
+from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
+
+
+class ExtractReadMe(Action):
+ """
+ An action to extract the summary, installation, configuration, and usages from the contents of a README.md file.
+
+ Attributes:
+ graph_db (Optional[GraphRepository]): A graph database repository.
+ install_to_path (Optional[str]): The path to install the repository to.
+ """
+
+ graph_db: Optional[GraphRepository] = None
+ install_to_path: Optional[str] = Field(default="/TO/PATH")
+ _readme: Optional[str] = None
+ _filename: Optional[str] = None
+
+ async def run(self, with_messages=None, **kwargs):
+ """
+ Implementation of `Action`'s `run` method.
+
+ Args:
+ with_messages (Optional[Type]): An optional argument specifying messages to react to.
+ """
+ graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name
+ self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
+ summary = await self._summarize()
+ await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_SUMMARY, object_=summary)
+ install = await self._extract_install()
+ await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_INSTALL, object_=install)
+ conf = await self._extract_configuration()
+ await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_CONFIG, object_=conf)
+ usage = await self._extract_usage()
+ await self.graph_db.insert(subject=self._filename, predicate=GraphKeyword.HAS_USAGE, object_=usage)
+
+ await self.graph_db.save()
+
+ return Message(content="", cause_by=self)
+
+ async def _summarize(self) -> str:
+ readme = await self._get()
+ summary = await self.llm.aask(
+ readme,
+ system_msgs=[
+ "You are a tool that summarizes a git repository's README.md file.",
+ "Return a summary of what the repository is about.",
+ ],
+ stream=False,
+ )
+ return summary
+
+ async def _extract_install(self) -> str:
+ await self._get()
+ install = await self.llm.aask(
+ self._readme,
+ system_msgs=[
+ "You are a tool that installs a git repository according to its README.md file.",
+ "Return a markdown bash code block including how to:\n"
+ f"1. git clone the repository to the directory `{self.install_to_path}`;\n"
+ f"2. cd `{self.install_to_path}`;\n"
+ f"3. 
install the repository.", + ], + stream=False, + ) + return install + + async def _extract_configuration(self) -> str: + await self._get() + configuration = await self.llm.aask( + self._readme, + system_msgs=[ + "You are a tool that can configure a git repository according to its README.md file.", + "Return a markdown bash code block to configure the repository if necessary; otherwise return" + " an empty markdown bash code block", + ], + stream=False, + ) + return configuration + + async def _extract_usage(self) -> str: + await self._get() + usage = await self.llm.aask( + self._readme, + system_msgs=[ + "You are a tool that can summarize all usages of a git repository according to its README.md file.", + "Return a list of markdown code blocks that demonstrate the usage of the repository.", + ], + stream=False, + ) + return usage + + async def _get(self) -> str: + if self._readme is not None: + return self._readme + root = Path(self.i_context).resolve() + filename = None + for file_path in root.iterdir(): + if file_path.is_file() and file_path.stem == "README": + filename = file_path + break + if not filename: + return "" + self._readme = await aread(filename=filename, encoding="utf-8") + self._filename = str(filename) + return self._readme diff --git a/metagpt/actions/fix_bug.py b/metagpt/actions/fix_bug.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5df6dc6035fed968a801f5e1b93b12e004db4a --- /dev/null +++ b/metagpt/actions/fix_bug.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +""" +@Time : 2023-12-12 +@Author : mashenquan +@File : fix_bug.py +""" +from metagpt.actions import Action + + +class FixBug(Action): + """Fix bug action without any implementation details""" + + name: str = "FixBug" diff --git a/metagpt/actions/generate_questions.py b/metagpt/actions/generate_questions.py new file mode 100644 index 0000000000000000000000000000000000000000..bf0ba62773bb243b20e1087f3fc00bcd20306396 --- /dev/null +++ b/metagpt/actions/generate_questions.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@File : generate_questions.py +""" +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="Task: Refer to the context to further inquire about the details that interest you, within a word limit" + " of 150 words. Please provide the specific details you would like to inquire about here", + example=["1. What ...", "2. How ...", "3. ..."], +) + + +class GenerateQuestions(Action): + """This class allows LLM to further mine noteworthy details based on a specific "##TOPIC" (discussion topic) and + "##RECORD" (discussion records), thereby deepening the discussion.""" + + name: str = "GenerateQuestions" + + async def run(self, context) -> ActionNode: + return await QUESTIONS.fill(req=context, llm=self.llm) diff --git a/metagpt/actions/import_repo.py b/metagpt/actions/import_repo.py new file mode 100644 index 0000000000000000000000000000000000000000..82aa916f46fb9a2a0b187eaabe20122b26e1a147 --- /dev/null +++ b/metagpt/actions/import_repo.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" + +This script defines an action to import a Git repository into the MetaGPT project format, enabling incremental + appending of requirements. 
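The README lookup in `ExtractReadMe._get` above reduces to a case-sensitive scan of the project root for a file whose stem is `README`. A minimal standalone sketch of that lookup (the directory argument is illustrative):

```python
from pathlib import Path
from typing import Optional


def find_readme(root: str) -> Optional[Path]:
    """Return the first top-level file whose stem is exactly 'README', if any.

    Mirrors the scan in ExtractReadMe._get: only the project root is searched,
    so 'README.md', 'README.rst', or a bare 'README' all match, but nested
    README files do not.
    """
    for file_path in Path(root).resolve().iterdir():
        if file_path.is_file() and file_path.stem == "README":
            return file_path
    return None


if __name__ == "__main__":
    print(find_readme("."))  # e.g. ./README.md, or None when absent
```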
+The MetaGPT project format encompasses a structured representation of project data compatible with MetaGPT's + capabilities, facilitating the integration of Git repositories into MetaGPT workflows while allowing for the gradual + addition of requirements. + +""" +import json +import re +from pathlib import Path +from typing import List, Optional + +from pydantic import BaseModel + +from metagpt.actions import Action +from metagpt.actions.extract_readme import ExtractReadMe +from metagpt.actions.rebuild_class_view import RebuildClassView +from metagpt.actions.rebuild_sequence_view import RebuildSequenceView +from metagpt.const import GRAPH_REPO_FILE_REPO +from metagpt.logs import logger +from metagpt.schema import Message +from metagpt.tools.libs.git import git_clone +from metagpt.utils.common import ( + aread, + awrite, + list_files, + parse_json_code_block, + split_namespace, +) +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.file_repository import FileRepository +from metagpt.utils.git_repository import GitRepository +from metagpt.utils.graph_repository import GraphKeyword, GraphRepository +from metagpt.utils.project_repo import ProjectRepo + + +class ImportRepo(Action): + """ + An action to import a Git repository into a graph database and create related artifacts. + + Attributes: + repo_path (str): The URL of the Git repository to import. + graph_db (Optional[GraphRepository]): The output graph database of the Git repository. + rid (str): The output requirement ID. + """ + + repo_path: str # input, git repo url. + graph_db: Optional[GraphRepository] = None # output. graph db of the git repository + rid: str = "" # output, requirement ID. + + async def run(self, with_messages: List[Message] = None, **kwargs) -> Message: + """ + Runs the import process for the Git repository. + + Args: + with_messages (List[Message], optional): Additional messages to include. + **kwargs: Additional keyword arguments. + + Returns: + Message: A message indicating the completion of the import process. 
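Assuming a default `Context` with a configured LLM, workspace, and git access, the intended entry point for this action appears to be a plain `run()` call; a hedged usage sketch, not part of the diff (the repository URL is illustrative):

```python
# Hypothetical driver for ImportRepo. Assumes a MetaGPT Context that supplies
# config.workspace.path, an LLM, and git access.
import asyncio

from metagpt.actions.import_repo import ImportRepo
from metagpt.context import Context


async def main():
    action = ImportRepo(
        repo_path="https://github.com/git/git.git",  # any public repo URL
        context=Context(),
    )
    # Clones the repo, then derives PRD and system-design documents from it.
    await action.run()


if __name__ == "__main__":
    asyncio.run(main())
```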
+ """ + await self._create_repo() + await self._create_prd() + await self._create_system_design() + self.context.git_repo.archive(comments="Import") + + async def _create_repo(self): + path = await git_clone(url=self.repo_path, output_dir=self.config.workspace.path) + self.repo_path = str(path) + self.config.project_path = path + self.context.git_repo = GitRepository(local_path=path, auto_init=True) + self.context.repo = ProjectRepo(self.context.git_repo) + self.context.src_workspace = await self._guess_src_workspace() + await awrite( + filename=self.context.repo.workdir / ".src_workspace", + data=str(self.context.src_workspace.relative_to(self.context.repo.workdir)), + ) + + async def _create_prd(self): + action = ExtractReadMe(i_context=str(self.context.repo.workdir), context=self.context) + await action.run() + graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_SUMMARY) + prd = {"Project Name": self.context.repo.workdir.name} + for r in rows: + if Path(r.subject).stem == "README": + prd["Original Requirements"] = r.object_ + break + self.rid = FileRepository.new_filename() + await self.repo.docs.prd.save(filename=self.rid + ".json", content=json.dumps(prd)) + + async def _create_system_design(self): + action = RebuildClassView( + name="ReverseEngineering", i_context=str(self.context.src_workspace), context=self.context + ) + await action.run() + rows = await action.graph_db.select(predicate="hasMermaidClassDiagramFile") + class_view_filename = rows[0].object_ + logger.info(f"class view:{class_view_filename}") + + rows = await action.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO) + tag = "__name__:__main__" + entries = [] + src_workspace = self.context.src_workspace.relative_to(self.context.repo.workdir) + for r in rows: + if tag in r.subject: + path = split_namespace(r.subject)[0] + elif tag in r.object_: + path = split_namespace(r.object_)[0] + else: + continue + if Path(path).is_relative_to(src_workspace): + entries.append(Path(path)) + main_entry = await self._guess_main_entry(entries) + full_path = RebuildSequenceView.get_full_filename(self.context.repo.workdir, main_entry) + action = RebuildSequenceView(context=self.context, i_context=str(full_path)) + try: + await action.run() + except Exception as e: + logger.warning(f"{e}, use the last successful version.") + files = list_files(self.context.repo.resources.data_api_design.workdir) + pattern = re.compile(r"[^a-zA-Z0-9]") + name = re.sub(pattern, "_", str(main_entry)) + filename = Path(name).with_suffix(".sequence_diagram.mmd") + postfix = str(filename) + sequence_files = [i for i in files if postfix in str(i)] + content = await aread(filename=sequence_files[0]) + await self.context.repo.resources.data_api_design.save( + filename=self.repo.workdir.stem + ".sequence_diagram.mmd", content=content + ) + await self._save_system_design() + + async def _save_system_design(self): + class_view = await self.context.repo.resources.data_api_design.get( + filename=self.repo.workdir.stem + ".class_diagram.mmd" + ) + sequence_view = await self.context.repo.resources.data_api_design.get( + filename=self.repo.workdir.stem + ".sequence_diagram.mmd" + ) + file_list = self.context.git_repo.get_files(relative_path=".", root_relative_path=self.context.src_workspace) + data = { + "Data structures and interfaces": 
class_view.content, + "Program call flow": sequence_view.content, + "File list": [str(i) for i in file_list], + } + await self.context.repo.docs.system_design.save(filename=self.rid + ".json", content=json.dumps(data)) + + async def _guess_src_workspace(self) -> Path: + files = list_files(self.context.repo.workdir) + dirs = [i.parent for i in files if i.name == "__init__.py"] + distinct = set() + for i in dirs: + done = False + for j in distinct: + if i.is_relative_to(j): + done = True + break + if j.is_relative_to(i): + break + if not done: + distinct = {j for j in distinct if not j.is_relative_to(i)} + distinct.add(i) + if len(distinct) == 1: + return list(distinct)[0] + prompt = "\n".join([f"- {str(i)}" for i in distinct]) + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a tool to choose the source code path from a list of paths based on the directory name.", + "You should identify the source code path among paths such as unit test path, examples path, etc.", + "Return a markdown JSON object containing:\n" + '- a "src" field containing the source code path;\n' + '- a "reason" field containing explaining why other paths is not the source code path\n', + ], + ) + logger.debug(rsp) + json_blocks = parse_json_code_block(rsp) + + class Data(BaseModel): + src: str + reason: str + + data = Data.model_validate_json(json_blocks[0]) + logger.info(f"src_workspace: {data.src}") + return Path(data.src) + + async def _guess_main_entry(self, entries: List[Path]) -> Path: + if len(entries) == 1: + return entries[0] + + file_list = "## File List\n" + file_list += "\n".join([f"- {i}" for i in entries]) + + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_USAGE) + usage = "## Usage\n" + for r in rows: + if Path(r.subject).stem == "README": + usage += r.object_ + + prompt = file_list + "\n---\n" + usage + rsp = await self.llm.aask( + prompt, + system_msgs=[ + 'You are a tool to choose the source file path from "File List" which is used in "Usage".', + 'You choose the source file path based on the name of file and the class name and package name used in "Usage".', + "Return a markdown JSON object containing:\n" + '- a "filename" field containing the chosen source file path from "File List" which is used in "Usage";\n' + '- a "reason" field explaining why.', + ], + stream=False, + ) + logger.debug(rsp) + json_blocks = parse_json_code_block(rsp) + + class Data(BaseModel): + filename: str + reason: str + + data = Data.model_validate_json(json_blocks[0]) + logger.info(f"main: {data.filename}") + return Path(data.filename) diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf71a8ff9063c95aec6b8acdca637200ba0f603 --- /dev/null +++ b/metagpt/actions/invoice_ocr.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ + +""" +@Time : 2023/9/21 18:10:20 +@Author : Stitch-z +@File : invoice_ocr.py +@Describe : Actions of the invoice ocr assistant. 
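The deduplication loop in `_guess_src_workspace` above keeps only the outermost directories that contain an `__init__.py`, regardless of the order in which candidates arrive. A self-contained sketch of the same reduction:

```python
from pathlib import Path


def distinct_roots(package_dirs: list[Path]) -> set[Path]:
    """Reduce __init__.py-bearing directories to their outermost roots,
    mirroring the dedup loop in ImportRepo._guess_src_workspace."""
    distinct: set[Path] = set()
    for d in package_dirs:
        if any(d.is_relative_to(r) for r in distinct):
            continue  # already covered by an equal or shallower root
        # drop previously kept roots that live under the new, shallower one
        distinct = {r for r in distinct if not r.is_relative_to(d)}
        distinct.add(d)
    return distinct


dirs = [Path("repo/pkg/sub"), Path("repo/pkg"), Path("repo/tests")]
print(distinct_roots(dirs))  # {Path('repo/pkg'), Path('repo/tests')}
```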
+""" + +import os +import zipfile +from datetime import datetime +from pathlib import Path +from typing import Optional + +import pandas as pd +from paddleocr import PaddleOCR + +from metagpt.actions import Action +from metagpt.const import INVOICE_OCR_TABLE_PATH +from metagpt.logs import logger +from metagpt.prompts.invoice_ocr import ( + EXTRACT_OCR_MAIN_INFO_PROMPT, + REPLY_OCR_QUESTION_PROMPT, +) +from metagpt.utils.common import OutputParser +from metagpt.utils.file import File + + +class InvoiceOCR(Action): + """Action class for performing OCR on invoice files, including zip, PDF, png, and jpg files. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language for OCR output. Defaults to "ch" (Chinese). + + """ + + name: str = "InvoiceOCR" + i_context: Optional[str] = None + + @staticmethod + async def _check_file_type(file_path: Path) -> str: + """Check the file type of the given filename. + + Args: + file_path: The path of the file. + + Returns: + The file type based on FileExtensionType enum. + + Raises: + Exception: If the file format is not zip, pdf, png, or jpg. + """ + ext = file_path.suffix + if ext not in [".zip", ".pdf", ".png", ".jpg"]: + raise Exception("The invoice format is not zip, pdf, png, or jpg") + + return ext + + @staticmethod + async def _unzip(file_path: Path) -> Path: + """Unzip a file and return the path to the unzipped directory. + + Args: + file_path: The path to the zip file. + + Returns: + The path to the unzipped directory. + """ + file_directory = file_path.parent / "unzip_invoices" / datetime.now().strftime("%Y%m%d%H%M%S") + with zipfile.ZipFile(file_path, "r") as zip_ref: + for zip_info in zip_ref.infolist(): + # Use CP437 to encode the file name, and then use GBK decoding to prevent Chinese garbled code + relative_name = Path(zip_info.filename.encode("cp437").decode("gbk")) + if relative_name.suffix: + full_filename = file_directory / relative_name + await File.write(full_filename.parent, relative_name.name, zip_ref.read(zip_info.filename)) + + logger.info(f"unzip_path: {file_directory}") + return file_directory + + @staticmethod + async def _ocr(invoice_file_path: Path): + ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1) + ocr_result = ocr.ocr(str(invoice_file_path), cls=True) + for result in ocr_result[0]: + result[1] = (result[1][0], round(result[1][1], 2)) # round long confidence scores to reduce token costs + return ocr_result + + async def run(self, file_path: Path, *args, **kwargs) -> list: + """Execute the action to identify invoice files through OCR. + + Args: + file_path: The path to the input file. + + Returns: + A list of OCR results. + """ + file_ext = await self._check_file_type(file_path) + + if file_ext == ".zip": + # OCR recognizes zip batch files + unzip_path = await self._unzip(file_path) + ocr_list = [] + for root, _, files in os.walk(unzip_path): + for filename in files: + invoice_file_path = Path(root) / Path(filename) + # Identify files that match the type + if Path(filename).suffix in [".zip", ".pdf", ".png", ".jpg"]: + ocr_result = await self._ocr(str(invoice_file_path)) + ocr_list.append(ocr_result) + return ocr_list + + else: + # OCR identifies single file + ocr_result = await self._ocr(file_path) + return [ocr_result] + + +class GenerateTable(Action): + """Action class for generating tables from OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for the generated table. Defaults to "ch" (Chinese). 
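The CP437-to-GBK round-trip in `InvoiceOCR._unzip` above recovers Chinese filenames from archives whose entries were not written as UTF-8 (the ZIP format defaults to CP437 for legacy names). A standalone sketch of the same decoding (the archive name is hypothetical):

```python
import zipfile
from pathlib import Path

# The UTF-8 flag (bit 11 of flag_bits) tells us whether re-decoding is needed
# at all; InvoiceOCR._unzip applies the round-trip unconditionally.
with zipfile.ZipFile("invoices.zip") as zf:  # hypothetical archive
    for info in zf.infolist():
        if info.flag_bits & 0x800:  # name already stored as UTF-8
            name = Path(info.filename)
        else:
            name = Path(info.filename.encode("cp437").decode("gbk", errors="replace"))
        if name.suffix:  # directory entries have no suffix and are skipped
            print(name, len(zf.read(info.filename)), "bytes")
```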
+ + """ + + name: str = "GenerateTable" + i_context: Optional[str] = None + language: str = "ch" + + async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]: + """Processes OCR results, extracts invoice information, generates a table, and saves it as an Excel file. + + Args: + ocr_results: A list of OCR results obtained from invoice processing. + filename: The name of the output Excel file. + + Returns: + A dictionary containing the invoice information. + + """ + table_data = [] + pathname = INVOICE_OCR_TABLE_PATH + pathname.mkdir(parents=True, exist_ok=True) + + for ocr_result in ocr_results: + # Extract invoice OCR main information + prompt = EXTRACT_OCR_MAIN_INFO_PROMPT.format(ocr_result=ocr_result, language=self.language) + ocr_info = await self._aask(prompt=prompt) + invoice_data = OutputParser.extract_struct(ocr_info, dict) + if invoice_data: + table_data.append(invoice_data) + + # Generate Excel file + filename = f"{filename.split('.')[0]}.xlsx" + full_filename = f"{pathname}/{filename}" + df = pd.DataFrame(table_data) + df.to_excel(full_filename, index=False) + return table_data + + +class ReplyQuestion(Action): + """Action class for generating replies to questions based on OCR results. + + Args: + name: The name of the action. Defaults to an empty string. + language: The language used for generating the reply. Defaults to "ch" (Chinese). + + """ + + language: str = "ch" + + async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str: + """Reply to questions based on ocr results. + + Args: + query: The question for which a reply is generated. + ocr_result: A list of OCR results. + + Returns: + A reply result of string type. + """ + prompt = REPLY_OCR_QUESTION_PROMPT.format(query=query, ocr_result=ocr_result, language=self.language) + resp = await self._aask(prompt=prompt) + return resp diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py new file mode 100644 index 0000000000000000000000000000000000000000..393c483cc56525e834374ce1ad95049250273f55 --- /dev/null +++ b/metagpt/actions/prepare_documents.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/11/20 +@Author : mashenquan +@File : prepare_documents.py +@Desc: PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt. + RFC 135 2.2.3.5.1. 
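The Excel step in `GenerateTable.run` above is ordinary pandas: one dict of extracted invoice fields per OCR result becomes one row. A minimal sketch with illustrative field names:

```python
import pandas as pd

# Field names below are illustrative stand-ins for whatever the OCR-extraction
# prompt returns; writing .xlsx requires the openpyxl package.
table_data = [
    {"payee": "Example Co.", "city": "Shanghai", "total_cost": 1234.50, "invoicing_date": "2023-09-21"},
    {"payee": "Sample Ltd.", "city": "Beijing", "total_cost": 88.00, "invoicing_date": "2023-09-22"},
]
df = pd.DataFrame(table_data)
df.to_excel("invoice_table.xlsx", index=False)
```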
+""" +import shutil +from pathlib import Path +from typing import Dict, Optional + +from metagpt.actions import Action, UserRequirement +from metagpt.const import REQUIREMENT_FILENAME +from metagpt.logs import logger +from metagpt.schema import AIMessage +from metagpt.utils.common import any_to_str +from metagpt.utils.file_repository import FileRepository +from metagpt.utils.project_repo import ProjectRepo + + +class PrepareDocuments(Action): + """PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt.""" + + name: str = "PrepareDocuments" + i_context: Optional[str] = None + key_descriptions: Optional[Dict[str, str]] = None + send_to: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not self.key_descriptions: + self.key_descriptions = {"project_path": 'the project path if exists in "Original Requirement"'} + + @property + def config(self): + return self.context.config + + def _init_repo(self) -> ProjectRepo: + """Initialize the Git environment.""" + if not self.config.project_path: + name = self.config.project_name or FileRepository.new_filename() + path = Path(self.config.workspace.path) / name + else: + path = Path(self.config.project_path) + if path.exists() and not self.config.inc: + shutil.rmtree(path) + self.context.kwargs.project_path = path + self.context.kwargs.inc = self.config.inc + return ProjectRepo(path) + + async def run(self, with_messages, **kwargs): + """Create and initialize the workspace folder, initialize the Git environment.""" + user_requirements = [i for i in with_messages if i.cause_by == any_to_str(UserRequirement)] + if not self.config.project_path and user_requirements and self.key_descriptions: + args = await user_requirements[0].parse_resources(llm=self.llm, key_descriptions=self.key_descriptions) + for k, v in args.items(): + if not v or k in ["resources", "reason"]: + continue + self.context.kwargs.set(k, v) + logger.info(f"{k}={v}") + if self.context.kwargs.project_path: + self.config.update_via_cli( + project_path=self.context.kwargs.project_path, + project_name="", + inc=False, + reqa_file=self.context.kwargs.reqa_file or "", + max_auto_summarize_code=0, + ) + + repo = self._init_repo() + + # Write the newly added requirements from the main parameter idea to `docs/requirement.txt`. + await repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content) + # Send a Message notification to the WritePRD action, instructing it to process requirements using + # `docs/requirement.txt` and `docs/prd/`. 
+        return AIMessage( + content="", + instruct_content=AIMessage.create_instruct_value( + kvs={ + "project_path": str(repo.workdir), + "requirements_filename": str(repo.docs.workdir / REQUIREMENT_FILENAME), + "prd_filenames": [str(repo.docs.prd.workdir / i) for i in repo.docs.prd.all_files], + }, + class_name="PrepareDocumentsOutput", + ), + cause_by=self, + send_to=self.send_to, + ) diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7eb6581ed3d0a28ec3f0a4dfd1431fdc8a6598 --- /dev/null +++ b/metagpt/actions/prepare_interview.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/9/19 15:02 +@Author : DevXiaolan +@File : prepare_interview.py +""" +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +QUESTIONS = ActionNode( + key="Questions", + expected_type=list[str], + instruction="""Role: You are an interviewer of our company who is well-known in frontend or backend development; +Requirement: Provide a list of questions for the interviewer to ask the interviewee, by reading the interviewee's resume in the context. +Attention: Provide a markdown block in the format above, with at least 10 questions.""", + example=["1. What ...", "2. How ..."], +) + + +class PrepareInterview(Action): + name: str = "PrepareInterview" + + async def run(self, context): + return await QUESTIONS.fill(req=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfe0da3a2c2db64f91910ff34006ef23f3acc32 --- /dev/null +++ b/metagpt/actions/project_management.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 19:12 +@Author : alexanderwu +@File : project_management.py +@Modified By: mashenquan, 2023/11/27. + 1. Divide the context into three components: legacy code, unit test code, and console log. + 2. Move the document storage operations related to WritePRD from the save operation of WriteDesign. + 3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality. +@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236. 
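`GenerateQuestions` and `PrepareInterview` above share the same single-node pattern: declare an `ActionNode` with a key, expected type, instruction, and example, then `fill` it with the runtime context. A hedged sketch of that pattern with a hypothetical node:

```python
# Hypothetical action following the single-ActionNode pattern above; the node
# key and instruction are invented for illustration.
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode

FOLLOW_UPS = ActionNode(
    key="FollowUps",
    expected_type=list[str],
    instruction="List follow-up questions raised by the given context.",
    example=["1. What ...", "2. How ..."],
)


class GenerateFollowUps(Action):
    name: str = "GenerateFollowUps"

    async def run(self, context) -> ActionNode:
        # fill() sends the instruction plus context to the LLM and parses the
        # reply back into the node's expected_type.
        return await FOLLOW_UPS.fill(req=context, llm=self.llm)
```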
+""" + +import json +from pathlib import Path +from typing import List, Optional, Union + +from pydantic import BaseModel, Field + +from metagpt.actions.action import Action +from metagpt.actions.project_management_an import PM_NODE, REFINED_PM_NODE +from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME +from metagpt.logs import logger +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + aread, + awrite, + rectify_pathname, + save_json_to_markdown, + to_markdown_code_block, +) +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import DocsReporter + +NEW_REQ_TEMPLATE = """ +### Legacy Content +{old_task} + +### New Requirements +{context} +""" + + +@register_tool(include_functions=["run"]) +class WriteTasks(Action): + name: str = "CreateTasks" + i_context: Optional[str] = None + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + design_filename: str = "", + output_pathname: str = "", + **kwargs, + ) -> Union[AIMessage, str]: + """ + Write a project schedule given a project system design file. + + Args: + user_requirement (str, optional): A string specifying the user's requirements. Defaults to an empty string. + design_filename (str): The output file path of the document. Defaults to an empty string. + output_pathname (str, optional): The output path name of file that the project schedule should be saved to. + **kwargs: Additional keyword arguments. + + Returns: + str: Path to the generated project schedule. + + Example: + # Write a project schedule with a given system design. + >>> design_filename = "/absolute/path/to/snake_game/docs/system_design.json" + >>> output_pathname = "/absolute/path/to/snake_game/docs/project_schedule.json" + >>> user_requirement = "Write project schedule for a snake game following these requirements:..." + >>> action = WriteTasks() + >>> result = await action.run(user_requirement=user_requirement, design_filename=design_filename, output_pathname=output_pathname) + >>> print(result) + The project schedule is at /absolute/path/to/snake_game/docs/project_schedule.json + + # Write a project schedule with a user requirement. + >>> user_requirement = "Write project schedule for a snake game following these requirements: ..." + >>> output_pathname = "/absolute/path/to/snake_game/docs/project_schedule.json" + >>> action = WriteTasks() + >>> result = await action.run(user_requirement=user_requirement, output_pathname=output_pathname) + >>> print(result) + The project schedule is at /absolute/path/to/snake_game/docs/project_schedule.json + """ + if not with_messages: + return await self._execute_api( + user_requirement=user_requirement, design_filename=design_filename, output_pathname=output_pathname + ) + + self.input_args = with_messages[-1].instruct_content + self.repo = ProjectRepo(self.input_args.project_path) + changed_system_designs = self.input_args.changed_system_design_filenames + changed_tasks = [str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys())] + change_files = Documents() + # Rewrite the system designs that have undergone changes based on the git head diff under + # `docs/system_designs/`. 
+ for filename in changed_system_designs: + task_doc = await self._update_tasks(filename=filename) + change_files.docs[str(self.repo.docs.task.workdir / task_doc.filename)] = task_doc + + # Rewrite the task files that have undergone changes based on the git head diff under `docs/tasks/`. + for filename in changed_tasks: + if filename in change_files.docs: + continue + task_doc = await self._update_tasks(filename=filename) + change_files.docs[filename] = task_doc + + if not change_files.docs: + logger.info("Nothing has changed.") + # Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for + # global optimization in subsequent steps. + kvs = self.input_args.model_dump() + kvs["changed_task_filenames"] = [ + str(self.repo.docs.task.workdir / i) for i in list(self.repo.docs.task.changed_files.keys()) + ] + kvs["python_package_dependency_filename"] = str(self.repo.workdir / PACKAGE_REQUIREMENTS_FILENAME) + return AIMessage( + content="WBS is completed. " + + "\n".join( + [PACKAGE_REQUIREMENTS_FILENAME] + + list(self.repo.docs.task.changed_files.keys()) + + list(self.repo.resources.api_spec_and_task.changed_files.keys()) + ), + instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WriteTaskOutput"), + cause_by=self, + ) + + async def _update_tasks(self, filename): + root_relative_path = Path(filename).relative_to(self.repo.workdir) + system_design_doc = await Document.load(filename=filename, project_path=self.repo.workdir) + task_doc = await self.repo.docs.task.get(root_relative_path.name) + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "task"}, "meta") + if task_doc: + task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc) + await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path}) + else: + rsp = await self._run_new_tasks(context=system_design_doc.content) + task_doc = await self.repo.docs.task.save( + filename=system_design_doc.filename, + content=rsp.instruct_content.model_dump_json(), + dependencies={system_design_doc.root_relative_path}, + ) + await self._update_requirements(task_doc) + md = await self.repo.resources.api_spec_and_task.save_pdf(doc=task_doc) + await reporter.async_report(self.repo.workdir / md.root_relative_path, "path") + return task_doc + + async def _run_new_tasks(self, context: str): + node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) + return node + + async def _merge(self, system_design_doc, task_doc) -> Document: + context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content) + node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) + task_doc.content = node.instruct_content.model_dump_json() + return task_doc + + async def _update_requirements(self, doc): + m = json.loads(doc.content) + packages = set(m.get("Required packages", set())) + requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME) + if not requirement_doc: + requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="") + lines = requirement_doc.content.splitlines() + for pkg in lines: + if pkg == "": + continue + packages.add(pkg) + await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages)) + + async def _execute_api( + self, user_requirement: str = "", design_filename: str = "", output_pathname: str = "" + ) -> str: + context = 
to_markdown_code_block(user_requirement) + if design_filename: + design_filename = rectify_pathname(path=design_filename, default_filename="system_design.md") + content = await aread(filename=design_filename) + context += to_markdown_code_block(content) + + async with DocsReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "task"}, "meta") + node = await self._run_new_tasks(context) + file_content = node.instruct_content.model_dump_json() + + if not output_pathname: + output_pathname = self.config.workspace.path / "docs" / "project_schedule.json" + elif not Path(output_pathname).is_absolute(): + output_pathname = self.config.workspace.path / output_pathname + output_pathname = rectify_pathname(path=output_pathname, default_filename="project_schedule.json") + await awrite(filename=output_pathname, data=file_content) + md_output_filename = output_pathname.with_suffix(".md") + await save_json_to_markdown(content=file_content, output_filename=md_output_filename) + await reporter.async_report(md_output_filename, "path") + return f'Project Schedule filename: "{str(output_pathname)}"' diff --git a/metagpt/actions/project_management_an.py b/metagpt/actions/project_management_an.py new file mode 100644 index 0000000000000000000000000000000000000000..a953feb4cba40280c18409242bbae6c8fa8ded48 --- /dev/null +++ b/metagpt/actions/project_management_an.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/14 15:28 +@Author : alexanderwu +@File : project_management_an.py +""" +from typing import List, Optional + +from metagpt.actions.action_node import ActionNode + +REQUIRED_PACKAGES = ActionNode( + key="Required packages", + expected_type=Optional[List[str]], + instruction="Provide the required packages. The response language should correspond to the context and requirements.", + example=["flask==1.1.2", "bcrypt==3.2.0"], +) + +REQUIRED_OTHER_LANGUAGE_PACKAGES = ActionNode( + key="Required Other language third-party packages", + expected_type=List[str], + instruction="List the required packages for languages other than Python.", + example=["No third-party dependencies required"], +) + +LOGIC_ANALYSIS = ActionNode( + key="Logic Analysis", + expected_type=List[List[str]], + instruction="Provide a list of files with the classes/methods/functions to be implemented, " + "including dependency analysis and imports. " + "Ensure consistency between System Design and Logic Analysis; the files must match exactly. " + "If the file is written in Vue or React, use Tailwind CSS for styling.", + example=[ + ["game.py", "Contains Game class and ... functions"], + ["main.py", "Contains main function, from game import Game"], + ], +) + +REFINED_LOGIC_ANALYSIS = ActionNode( + key="Refined Logic Analysis", + expected_type=List[List[str]], + instruction="Review and refine the logic analysis by merging the Legacy Content and Incremental Content. " + "Provide a comprehensive list of files with classes/methods/functions to be implemented or modified incrementally. " + "Include dependency analysis, consider potential impacts on existing code, and document necessary imports.", + example=[ + ["game.py", "Contains Game class and ... 
functions"], + ["main.py", "Contains main function, from game import Game"], + ["new_feature.py", "Introduces NewFeature class and related functions"], + ["utils.py", "Modifies existing utility functions to support incremental changes"], + ], +) + +TASK_LIST = ActionNode( + key="Task list", + expected_type=List[str], + instruction="Break down the tasks into a list of filenames, prioritized by dependency order.", + example=["game.py", "main.py"], +) + +REFINED_TASK_LIST = ActionNode( + key="Refined Task list", + expected_type=List[str], + instruction="Review and refine the combined task list after the merger of Legacy Content and Incremental Content, " + "and consistent with Refined File List. Ensure that tasks are organized in a logical and prioritized order, " + "considering dependencies for a streamlined and efficient development process. ", + example=["new_feature.py", "utils", "game.py", "main.py"], +) + +FULL_API_SPEC = ActionNode( + key="Full API spec", + expected_type=str, + instruction="Describe all APIs using OpenAPI 3.0 spec that may be used by both frontend and backend. If front-end " + "and back-end communication is not required, leave it blank.", + example="openapi: 3.0.0 ...", +) + +SHARED_KNOWLEDGE = ActionNode( + key="Shared Knowledge", + expected_type=str, + instruction="Detail any shared knowledge, like common utility functions or configuration variables.", + example="`game.py` contains functions shared across the project.", +) + +REFINED_SHARED_KNOWLEDGE = ActionNode( + key="Refined Shared Knowledge", + expected_type=str, + instruction="Update and expand shared knowledge to reflect any new elements introduced. This includes common " + "utility functions, configuration variables for team collaboration. Retain content that is not related to " + "incremental development but important for consistency and clarity.", + example="`new_module.py` enhances shared utility functions for improved code reusability and collaboration.", +) + + +ANYTHING_UNCLEAR_PM = ActionNode( + key="Anything UNCLEAR", + expected_type=str, + instruction="Mention any unclear aspects in the project management context and try to clarify them.", + example="Clarification needed on how to start and initialize third-party libraries.", +) + +NODES = [ + REQUIRED_PACKAGES, + REQUIRED_OTHER_LANGUAGE_PACKAGES, + LOGIC_ANALYSIS, + TASK_LIST, + FULL_API_SPEC, + SHARED_KNOWLEDGE, + ANYTHING_UNCLEAR_PM, +] + +REFINED_NODES = [ + REQUIRED_PACKAGES, + REQUIRED_OTHER_LANGUAGE_PACKAGES, + REFINED_LOGIC_ANALYSIS, + REFINED_TASK_LIST, + FULL_API_SPEC, + REFINED_SHARED_KNOWLEDGE, + ANYTHING_UNCLEAR_PM, +] + +PM_NODE = ActionNode.from_children("PM_NODE", NODES) +REFINED_PM_NODE = ActionNode.from_children("REFINED_PM_NODE", REFINED_NODES) diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py new file mode 100644 index 0000000000000000000000000000000000000000..64f003f919dbac8abb20e7ff1f3d014f9d0473c6 --- /dev/null +++ b/metagpt/actions/rebuild_class_view.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/19 +@Author : mashenquan +@File : rebuild_class_view.py +@Desc : Reconstructs class diagram from a source code project. 
+ Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt +""" + +from pathlib import Path +from typing import Optional, Set, Tuple + +import aiofiles + +from metagpt.actions import Action +from metagpt.const import ( + AGGREGATION, + COMPOSITION, + DATA_API_DESIGN_FILE_REPO, + GENERALIZATION, + GRAPH_REPO_FILE_REPO, +) +from metagpt.logs import logger +from metagpt.repo_parser import DotClassInfo, RepoParser +from metagpt.schema import UMLClassView +from metagpt.utils.common import concat_namespace, split_namespace +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.graph_repository import GraphKeyword, GraphRepository + + +class RebuildClassView(Action): + """ + Reconstructs a graph repository about class diagram from a source code project. + + Attributes: + graph_db (Optional[GraphRepository]): The optional graph repository. + """ + + graph_db: Optional[GraphRepository] = None + + async def run(self, with_messages=None, format=None): + """ + Implementation of `Action`'s `run` method. + + Args: + with_messages (Optional[Type]): An optional argument specifying messages to react to. + format (str): The format for the prompt schema. + """ + format = format if format else self.config.prompt_schema + graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + repo_parser = RepoParser(base_directory=Path(self.i_context)) + # use pylint + class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.i_context)) + await GraphRepository.update_graph_db_with_class_views(self.graph_db, class_views) + await GraphRepository.update_graph_db_with_class_relationship_views(self.graph_db, relationship_views) + await GraphRepository.rebuild_composition_relationship(self.graph_db) + # use ast + direction, diff_path = self._diff_path(path_root=Path(self.i_context).resolve(), package_root=package_root) + symbols = repo_parser.generate_symbols() + for file_info in symbols: + # Align to the same root directory in accordance with `class_views`. + file_info.file = self._align_root(file_info.file, direction, diff_path) + await GraphRepository.update_graph_db_with_file_info(self.graph_db, file_info) + await self._create_mermaid_class_views() + await self.graph_db.save() + + async def _create_mermaid_class_views(self) -> str: + """Creates a Mermaid class diagram using data from the `graph_db` graph repository. + + This method utilizes information stored in the graph repository to generate a Mermaid class diagram. + Returns: + mermaid class diagram file name. 
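For reference, the file `_create_mermaid_class_views` assembles is an ordinary Mermaid `classDiagram` block: a header line, one block per class, then one line per relationship. A self-contained sketch of that output shape (class names illustrative):

```python
from pathlib import Path

# Illustrative data standing in for what the graph repository would supply.
classes = {"Shape": ["+area() float", "+name: str"], "Circle": ["+radius: float"]}
relationships = ["Shape <|-- Circle"]  # Circle specializes Shape

lines = ["classDiagram"]
for name, members in classes.items():
    lines.append(f"\tclass {name} {{")
    lines.extend(f"\t\t{member}" for member in members)
    lines.append("\t}")
lines.extend(f"\t{link}" for link in relationships)

Path("demo.class_diagram.mmd").write_text("\n".join(lines), encoding="utf-8")
```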
+ """ + path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO + path.mkdir(parents=True, exist_ok=True) + pathname = path / self.context.git_repo.workdir.name + filename = str(pathname.with_suffix(".class_diagram.mmd")) + async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer: + content = "classDiagram\n" + logger.debug(content) + await writer.write(content) + # class names + rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + class_distinct = set() + relationship_distinct = set() + for r in rows: + content = await self._create_mermaid_class(r.subject) + if content: + await writer.write(content) + class_distinct.add(r.subject) + for r in rows: + content, distinct = await self._create_mermaid_relationship(r.subject) + if content: + logger.debug(content) + await writer.write(content) + relationship_distinct.update(distinct) + logger.info(f"classes: {len(class_distinct)}, relationship: {len(relationship_distinct)}") + + if self.i_context: + r_filename = Path(filename).relative_to(self.context.git_repo.workdir) + await self.graph_db.insert( + subject=self.i_context, predicate="hasMermaidClassDiagramFile", object_=str(r_filename) + ) + logger.info(f"{self.i_context} hasMermaidClassDiagramFile {filename}") + return filename + + async def _create_mermaid_class(self, ns_class_name) -> str: + """Generates a Mermaid class diagram for a specific class using data from the `graph_db` graph repository. + + Args: + ns_class_name (str): The namespace-prefixed name of the class for which the Mermaid class diagram is to be created. + + Returns: + str: A Mermaid code block object in markdown representing the class diagram. + """ + fields = split_namespace(ns_class_name) + if len(fields) > 2: + # Ignore sub-class + return "" + + rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL) + if not rows: + return "" + dot_class_info = DotClassInfo.model_validate_json(rows[0].object_) + class_view = UMLClassView.load_dot_class_info(dot_class_info) + + # update uml view + await self.graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json()) + # update uml isCompositeOf + for c in dot_class_info.compositions: + await self.graph_db.insert( + subject=ns_class_name, + predicate=GraphKeyword.IS + COMPOSITION + GraphKeyword.OF, + object_=concat_namespace("?", c), + ) + + # update uml isAggregateOf + for a in dot_class_info.aggregations: + await self.graph_db.insert( + subject=ns_class_name, + predicate=GraphKeyword.IS + AGGREGATION + GraphKeyword.OF, + object_=concat_namespace("?", a), + ) + + content = class_view.get_mermaid(align=1) + logger.debug(content) + return content + + async def _create_mermaid_relationship(self, ns_class_name: str) -> Tuple[Optional[str], Optional[Set]]: + """Generates a Mermaid class relationship diagram for a specific class using data from the `graph_db` graph repository. + + Args: + ns_class_name (str): The namespace-prefixed class name for which the Mermaid relationship diagram is to be created. + + Returns: + Tuple[str, Set]: A tuple containing the relationship diagram as a string and a set of deduplication. 
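The relationship rendering in `_create_mermaid_relationship` is a lookup from relationship kind to Mermaid edge syntax, with unknown kinds falling back to a dotted link. A standalone sketch of that mapping (the triple layout is assumed from the code above, with the edge written object-first):

```python
# (kind, subject_class, object_class) triples stand in for graph-db rows.
MAPPINGS = {
    "generalization": " <|-- ",
    "composition": " *-- ",
    "aggregation": " o-- ",
}


def render_links(triples: list[tuple[str, str, str]]) -> str:
    content = ""
    distinct: set[str] = set()
    for kind, subject_class, object_class in triples:
        relationship = MAPPINGS.get(kind, " .. ")  # dotted fallback link
        link = f"{object_class}{relationship}{subject_class}"
        if link in distinct:  # light dedup, tracked by a set as in the original
            continue
        distinct.add(link)
        content += f"\t{link}\n"
    return content


print(render_links([("composition", "Game", "Board"), ("generalization", "Circle", "Shape")]))
```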
+ """ + s_fields = split_namespace(ns_class_name) + if len(s_fields) > 2: + # Ignore sub-class + return None, None + + predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]} + mappings = { + GENERALIZATION: " <|-- ", + COMPOSITION: " *-- ", + AGGREGATION: " o-- ", + } + content = "" + distinct = set() + for p, v in predicates.items(): + rows = await self.graph_db.select(subject=ns_class_name, predicate=p) + for r in rows: + o_fields = split_namespace(r.object_) + if len(o_fields) > 2: + # Ignore sub-class + continue + relationship = mappings.get(v, " .. ") + link = f"{o_fields[1]}{relationship}{s_fields[1]}" + distinct.add(link) + content += f"\t{link}\n" + + return content, distinct + + @staticmethod + def _diff_path(path_root: Path, package_root: Path) -> (str, str): + """Returns the difference between the root path and the path information represented in the package name. + + Args: + path_root (Path): The root path. + package_root (Path): The package root path. + + Returns: + Tuple[str, str]: A tuple containing the representation of the difference ("+", "-", "=") and the path detail of the differing part. + + Example: + >>> _diff_path(path_root=Path("/Users/x/github/MetaGPT"), package_root=Path("/Users/x/github/MetaGPT/metagpt")) + "-", "metagpt" + + >>> _diff_path(path_root=Path("/Users/x/github/MetaGPT/metagpt"), package_root=Path("/Users/x/github/MetaGPT/metagpt")) + "=", "." + """ + if len(str(path_root)) > len(str(package_root)): + return "+", str(path_root.relative_to(package_root)) + if len(str(path_root)) < len(str(package_root)): + return "-", str(package_root.relative_to(path_root)) + return "=", "." + + @staticmethod + def _align_root(path: str, direction: str, diff_path: str) -> str: + """Aligns the path to the same root represented by `diff_path`. + + Args: + path (str): The path to be aligned. + direction (str): The direction of alignment ('+', '-', '='). + diff_path (str): The path representing the difference. + + Returns: + str: The aligned path. + + Example: + >>> _align_root(path="metagpt/software_company.py", direction="+", diff_path="MetaGPT") + "MetaGPT/metagpt/software_company.py" + + >>> _align_root(path="metagpt/software_company.py", direction="-", diff_path="metagpt") + "software_company.py" + """ + if direction == "=": + return path + if direction == "+": + return diff_path + "/" + path + else: + return path[len(diff_path) + 1 :] diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py new file mode 100644 index 0000000000000000000000000000000000000000..627cbd151b0059bf325457a856ed07ffba25ef66 --- /dev/null +++ b/metagpt/actions/rebuild_sequence_view.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 +@Author : mashenquan +@File : rebuild_sequence_view.py +@Desc : Reconstruct sequence view information through reverse engineering. 
+ Implement RFC197, https://deepwisdom.feishu.cn/wiki/VyK0wfq56ivuvjklMKJcmHQknGt +""" +from __future__ import annotations + +import re +from datetime import datetime +from pathlib import Path +from typing import List, Optional, Set + +from pydantic import BaseModel +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.const import GRAPH_REPO_FILE_REPO +from metagpt.logs import logger +from metagpt.repo_parser import CodeBlockInfo, DotClassInfo +from metagpt.schema import UMLClassView +from metagpt.utils.common import ( + add_affix, + aread, + auto_namespace, + concat_namespace, + general_after_log, + list_files, + parse_json_code_block, + read_file_block, + split_namespace, +) +from metagpt.utils.di_graph_repository import DiGraphRepository +from metagpt.utils.graph_repository import SPO, GraphKeyword, GraphRepository + + +class ReverseUseCase(BaseModel): + """ + Represents a reverse engineered use case. + + Attributes: + description (str): A description of the reverse use case. + inputs (List[str]): List of inputs for the reverse use case. + outputs (List[str]): List of outputs for the reverse use case. + actors (List[str]): List of actors involved in the reverse use case. + steps (List[str]): List of steps for the reverse use case. + reason (str): The reason behind the reverse use case. + """ + + description: str + inputs: List[str] + outputs: List[str] + actors: List[str] + steps: List[str] + reason: str + + +class ReverseUseCaseDetails(BaseModel): + """ + Represents details of a reverse engineered use case. + + Attributes: + description (str): A description of the reverse use case details. + use_cases (List[ReverseUseCase]): List of reverse use cases. + relationship (List[str]): List of relationships associated with the reverse use case details. + """ + + description: str + use_cases: List[ReverseUseCase] + relationship: List[str] + + +class RebuildSequenceView(Action): + """ + Represents an action to reconstruct sequence view through reverse engineering. + + Attributes: + graph_db (Optional[GraphRepository]): An optional instance of GraphRepository for graph database operations. + """ + + graph_db: Optional[GraphRepository] = None + + async def run(self, with_messages=None, format=None): + """ + Implementation of `Action`'s `run` method. + + Args: + with_messages (Optional[Type]): An optional argument specifying messages to react to. + format (str): The format for the prompt schema. + """ + format = format if format else self.config.prompt_schema + graph_repo_pathname = self.context.git_repo.workdir / GRAPH_REPO_FILE_REPO / self.context.git_repo.workdir.name + self.graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json"))) + if not self.i_context: + entries = await self._search_main_entry() + else: + entries = [SPO(subject=self.i_context, predicate="", object_="")] + for entry in entries: + await self._rebuild_main_sequence_view(entry) + while await self._merge_sequence_view(entry): + pass + await self.graph_db.save() + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _rebuild_main_sequence_view(self, entry: SPO): + """ + Reconstruct the sequence diagram for the __main__ entry of the source code through reverse engineering. + + Args: + entry (SPO): The SPO (Subject, Predicate, Object) object in the graph database that is related to the + subject `__name__:__main__`. 
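Each LLM-backed rebuild step in this file is wrapped in the same tenacity policy: random exponential backoff between 1 and 20 seconds, giving up after 6 attempts. A minimal standalone illustration of that decorator on a flaky coroutine:

```python
import asyncio

from tenacity import retry, stop_after_attempt, wait_random_exponential

attempts = 0


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def flaky_rebuild() -> str:
    """Stand-in for an LLM-backed rebuild step that may fail transiently."""
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError("transient LLM failure")
    return f"succeeded on attempt {attempts}"


print(asyncio.run(flaky_rebuild()))
```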
+ """ + filename = entry.subject.split(":", 1)[0] + rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + classes = [] + prefix = filename + ":" + for r in rows: + if prefix in r.subject: + classes.append(r) + await self._rebuild_use_case(r.subject) + participants = await self._search_participants(split_namespace(entry.subject)[0]) + class_details = [] + class_views = [] + for c in classes: + detail = await self._get_class_detail(c.subject) + if not detail: + continue + class_details.append(detail) + view = await self._get_uml_class_view(c.subject) + if view: + class_views.append(view) + + actors = await self._get_participants(c.subject) + participants.update(set(actors)) + + use_case_blocks = [] + for c in classes: + use_cases = await self._get_class_use_cases(c.subject) + use_case_blocks.append(use_cases) + prompt_blocks = ["## Use Cases\n" + "\n".join(use_case_blocks)] + block = "## Participants\n" + for p in participants: + block += f"- {p}\n" + prompt_blocks.append(block) + block = "## Mermaid Class Views\n```mermaid\n" + block += "\n\n".join([c.get_mermaid() for c in class_views]) + block += "\n```\n" + prompt_blocks.append(block) + block = "## Source Code\n```python\n" + block += await self._get_source_code(filename) + block += "\n```\n" + prompt_blocks.append(block) + prompt = "\n---\n".join(prompt_blocks) + + rsp = await self.llm.aask( + msg=prompt, + system_msgs=[ + "You are a python code to Mermaid Sequence Diagram translator in function detail.", + "Translate the given markdown text to a Mermaid Sequence Diagram.", + "Return the merged Mermaid sequence diagram in a markdown code block format.", + ], + stream=False, + ) + sequence_view = rsp.removeprefix("```mermaid").removesuffix("```") + rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW) + for r in rows: + if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW: + await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_) + await self.graph_db.insert( + subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view + ) + await self.graph_db.insert( + subject=entry.subject, + predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER, + object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)), + ) + for c in classes: + await self.graph_db.insert( + subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject) + ) + await self._save_sequence_view(subject=entry.subject, content=sequence_view) + + async def _merge_sequence_view(self, entry: SPO) -> bool: + """ + Augments additional information to the provided SPO (Subject, Predicate, Object) entry in the sequence diagram. + + Args: + entry (SPO): The SPO object representing the relationship in the graph database. + + Returns: + bool: True if additional information has been augmented, otherwise False. + """ + new_participant = await self._search_new_participant(entry) + if not new_participant: + return False + + await self._merge_participant(entry, new_participant) + return True + + async def _search_main_entry(self) -> List: + """ + Asynchronously searches for the SPO object that is related to `__name__:__main__`. + + Returns: + List: A list containing information about the main entry in the sequence diagram. 
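Note that `rsp.removeprefix("```mermaid").removesuffix("```")` only strips fences sitting flush at the string ends; a trailing newline defeats `removesuffix`. A slightly more tolerant standalone variant (an assumption, not what this file does):

```python
FENCE = "`" * 3  # avoids a literal triple backtick inside this code block


def strip_mermaid_fence(rsp: str) -> str:
    """Remove a mermaid code fence, tolerating surrounding whitespace."""
    text = rsp.strip()
    text = text.removeprefix(FENCE + "mermaid").removesuffix(FENCE)
    return text.strip()


raw = FENCE + "mermaid\nsequenceDiagram\n    A->>B: hello\n" + FENCE + "\n"
print(strip_mermaid_fence(raw))
```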
+        """ + rows = await self.graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO) + tag = "__name__:__main__" + entries = [] + for r in rows: + if tag in r.subject or tag in r.object_: + entries.append(r) + return entries + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _rebuild_use_case(self, ns_class_name: str): + """ + Asynchronously reconstructs the use case for the provided namespace-prefixed class name. + + Args: + ns_class_name (str): The namespace-prefixed class name for which the use case is to be reconstructed. + """ + rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE) + if rows: + return + + detail = await self._get_class_detail(ns_class_name) + if not detail: + return + participants = set() + participants.update(set(detail.compositions)) + participants.update(set(detail.aggregations)) + class_view = await self._get_uml_class_view(ns_class_name) + source_code = await self._get_source_code(ns_class_name) + + prompt_blocks = [] + block = "## Participants\n" + for p in participants: + block += f"- {p}\n" + prompt_blocks.append(block) + block = "## Mermaid Class Views\n```mermaid\n" + block += class_view.get_mermaid() + block += "\n```\n" + prompt_blocks.append(block) + block = "## Source Code\n```python\n" + block += source_code + block += "\n```\n" + prompt_blocks.append(block) + prompt = "\n---\n".join(prompt_blocks) + + rsp = await self.llm.aask( + msg=prompt, + system_msgs=[ + "You are a python code to UML 2.0 Use Case translator.", + 'The generated UML 2.0 Use Case must include the roles or entities listed in "Participants".', + "The functional descriptions of Actors and Use Cases in the generated UML 2.0 Use Case must not " + 'conflict with the information in "Mermaid Class Views".', + 'The section under `if __name__ == "__main__":` of "Source Code" contains information about external ' + "system interactions with the internal system.", + "Return a markdown JSON object with:\n" + '- a "description" key explaining what the whole source code wants to do;\n' + '- a "use_cases" key listing all use cases; each use case in the list should include a `description` ' + "key describing what the use case does, an `inputs` key listing the input names of the use case " + "from external sources, an `outputs` key listing the output names of the use case to external sources, " + "an `actors` key listing the participant actors of the use case, a `steps` key listing the steps describing how " + "the use case works step by step, and a `reason` key explaining under what circumstances the " + "external system would execute this use case.\n" + '- a "relationship" key listing the descriptions of the relationships among these use cases.\n', + ], + stream=False, + ) + + code_blocks = parse_json_code_block(rsp) + for block in code_blocks: + detail = ReverseUseCaseDetails.model_validate_json(block) + await self.graph_db.insert( + subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json() + ) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _rebuild_sequence_view(self, ns_class_name: str): + """ + Asynchronously reconstructs the sequence diagram for the provided namespace-prefixed class name. + + Args: + ns_class_name (str): The namespace-prefixed class name for which the sequence diagram is to be reconstructed. 
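The JSON contract requested from the LLM in `_rebuild_use_case` round-trips through the `ReverseUseCaseDetails` schema defined at the top of this file. A standalone check of that round-trip (the sample payload is invented):

```python
from typing import List

from pydantic import BaseModel


# Local copies of the two models from this file, so the sketch runs standalone.
class ReverseUseCase(BaseModel):
    description: str
    inputs: List[str]
    outputs: List[str]
    actors: List[str]
    steps: List[str]
    reason: str


class ReverseUseCaseDetails(BaseModel):
    description: str
    use_cases: List[ReverseUseCase]
    relationship: List[str]


sample = """{
  "description": "CLI entry point of a demo package",
  "use_cases": [{
    "description": "Run the game loop",
    "inputs": ["keyboard events"],
    "outputs": ["rendered frames"],
    "actors": ["Player"],
    "steps": ["parse args", "start loop"],
    "reason": "invoked when the user starts the program"
  }],
  "relationship": ["Run the game loop is the only use case"]
}"""
print(ReverseUseCaseDetails.model_validate_json(sample).use_cases[0].actors)
```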
+ """ + await self._rebuild_use_case(ns_class_name) + + prompts_blocks = [] + use_case_markdown = await self._get_class_use_cases(ns_class_name) + if not use_case_markdown: # external class + await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="") + return + block = f"## Use Cases\n{use_case_markdown}" + prompts_blocks.append(block) + + participants = await self._get_participants(ns_class_name) + block = "## Participants\n" + "\n".join([f"- {s}" for s in participants]) + prompts_blocks.append(block) + + view = await self._get_uml_class_view(ns_class_name) + block = "## Mermaid Class Views\n```mermaid\n" + block += view.get_mermaid() + block += "\n```\n" + prompts_blocks.append(block) + + block = "## Source Code\n```python\n" + block += await self._get_source_code(ns_class_name) + block += "\n```\n" + prompts_blocks.append(block) + prompt = "\n---\n".join(prompts_blocks) + + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a Mermaid Sequence Diagram translator in function detail.", + "Translate the markdown text to a Mermaid Sequence Diagram.", + "Response must be concise.", + "Return a markdown mermaid code block.", + ], + stream=False, + ) + + sequence_view = rsp.removeprefix("```mermaid").removesuffix("```") + await self.graph_db.insert( + subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view + ) + + async def _get_participants(self, ns_class_name: str) -> List[str]: + """ + Asynchronously returns the participants list of the sequence diagram for the provided namespace-prefixed SPO + object. + + Args: + ns_class_name (str): The namespace-prefixed class name for which to retrieve the participants list. + + Returns: + List[str]: A list of participants in the sequence diagram. + """ + participants = set() + detail = await self._get_class_detail(ns_class_name) + if not detail: + return [] + participants.update(set(detail.compositions)) + participants.update(set(detail.aggregations)) + return list(participants) + + async def _get_class_use_cases(self, ns_class_name: str) -> str: + """ + Asynchronously assembles the context about the use case information of the namespace-prefixed SPO object. + + Args: + ns_class_name (str): The namespace-prefixed class name for which to retrieve use case information. + + Returns: + str: A string containing the assembled context about the use case information. + """ + block = "" + rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE) + for i, r in enumerate(rows): + detail = ReverseUseCaseDetails.model_validate_json(r.object_) + block += f"\n### {i + 1}. {detail.description}" + for j, use_case in enumerate(detail.use_cases): + block += f"\n#### {i + 1}.{j + 1}. {use_case.description}\n" + block += "\n##### Inputs\n" + "\n".join([f"- {s}" for s in use_case.inputs]) + block += "\n##### Outputs\n" + "\n".join([f"- {s}" for s in use_case.outputs]) + block += "\n##### Actors\n" + "\n".join([f"- {s}" for s in use_case.actors]) + block += "\n##### Steps\n" + "\n".join([f"- {s}" for s in use_case.steps]) + block += "\n#### Use Case Relationship\n" + "\n".join([f"- {s}" for s in detail.relationship]) + return block + "\n" + + async def _get_class_detail(self, ns_class_name: str) -> DotClassInfo | None: + """ + Asynchronously retrieves the dot format class details of the namespace-prefixed SPO object. + + Args: + ns_class_name (str): The namespace-prefixed class name for which to retrieve class details. 
+ + Returns: + Union[DotClassInfo, None]: A DotClassInfo object representing the dot format class details, + or None if the details are not available. + """ + rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_DETAIL) + if not rows: + return None + dot_class_info = DotClassInfo.model_validate_json(rows[0].object_) + return dot_class_info + + async def _get_uml_class_view(self, ns_class_name: str) -> UMLClassView | None: + """ + Asynchronously retrieves the UML 2.0 format class details of the namespace-prefixed SPO object. + + Args: + ns_class_name (str): The namespace-prefixed class name for which to retrieve UML class details. + + Returns: + Union[UMLClassView, None]: A UMLClassView object representing the UML 2.0 format class details, + or None if the details are not available. + """ + rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_VIEW) + if not rows: + return None + class_view = UMLClassView.model_validate_json(rows[0].object_) + return class_view + + async def _get_source_code(self, ns_class_name: str) -> str: + """ + Asynchronously retrieves the source code of the namespace-prefixed SPO object. + + Args: + ns_class_name (str): The namespace-prefixed class name for which to retrieve the source code. + + Returns: + str: A string containing the source code of the specified namespace-prefixed class. + """ + rows = await self.graph_db.select(subject=ns_class_name, predicate=GraphKeyword.HAS_PAGE_INFO) + filename = split_namespace(ns_class_name=ns_class_name)[0] + if not rows: + src_filename = RebuildSequenceView.get_full_filename(root=self.i_context, pathname=filename) + if not src_filename: + return "" + return await aread(filename=src_filename, encoding="utf-8") + code_block_info = CodeBlockInfo.model_validate_json(rows[0].object_) + return await read_file_block( + filename=filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno + ) + + @staticmethod + def get_full_filename(root: str | Path, pathname: str | Path) -> Path | None: + """ + Convert package name to the full path of the module. + + Args: + root (Union[str, Path]): The root path or string representing the package. + pathname (Union[str, Path]): The pathname or string representing the module. + + Returns: + Union[Path, None]: The full path of the module, or None if the path cannot be determined. + + Examples: + If `root`(workdir) is "/User/xxx/github/MetaGPT/metagpt", and the `pathname` is + "metagpt/management/skill_manager.py", then the returned value will be + "/User/xxx/github/MetaGPT/metagpt/management/skill_manager.py" + """ + if re.match(r"^/.+", str(pathname)): + return pathname + files = list_files(root=root) + postfix = "/" + str(pathname) + for i in files: + if str(i).endswith(postfix): + return i + return None + + @staticmethod + def parse_participant(mermaid_sequence_diagram: str) -> List[str]: + """ + Parses the provided Mermaid sequence diagram and returns the list of participants. + + Args: + mermaid_sequence_diagram (str): The Mermaid sequence diagram string to be parsed. + + Returns: + List[str]: A list of participants extracted from the sequence diagram. + """ + pattern = r"participant ([\w\.]+)" + matches = re.findall(pattern, mermaid_sequence_diagram) + matches = [re.sub(r"[\\/'\"]+", "", i) for i in matches] + return matches + + async def _search_new_participant(self, entry: SPO) -> str | None: + """ + Asynchronously retrieves a participant whose sequence diagram has not been augmented. 
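To make the `parse_participant` regex above concrete, here is the same two-step extraction run against a toy diagram: capture the token after each `participant` keyword, then strip the slash and quote characters that models sometimes emit around names:

```python
import re


def parse_participant(mermaid_sequence_diagram: str) -> list[str]:
    # Identical to the static method above.
    matches = re.findall(r"participant ([\w\.]+)", mermaid_sequence_diagram)
    return [re.sub(r"[\\/'\"]+", "", i) for i in matches]


diagram = """sequenceDiagram
    participant game.Main
    participant game.Snake
    game.Main->>game.Snake: move()"""
print(parse_participant(diagram))  # ['game.Main', 'game.Snake']
```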
+ + Args: + entry (SPO): The SPO object representing the relationship in the graph database. + + Returns: + Union[str, None]: A participant whose sequence diagram has not been augmented, or None if not found. + """ + rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW) + if not rows: + return None + sequence_view = rows[0].object_ + rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT) + merged_participants = [] + for r in rows: + name = split_namespace(r.object_)[-1] + merged_participants.append(name) + participants = self.parse_participant(sequence_view) + for p in participants: + if p in merged_participants: + continue + return p + return None + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _merge_participant(self, entry: SPO, class_name: str): + """ + Augments the sequence diagram of `class_name` to the sequence diagram of `entry`. + + Args: + entry (SPO): The SPO object representing the base sequence diagram. + class_name (str): The class name whose sequence diagram is to be augmented. + """ + rows = await self.graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS) + participants = [] + for r in rows: + name = split_namespace(r.subject)[-1] + if name == class_name: + participants.append(r) + if len(participants) == 0: # external participants + await self.graph_db.insert( + subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name) + ) + return + if len(participants) > 1: + for r in participants: + await self.graph_db.insert( + subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject) + ) + return + + participant = participants[0] + await self._rebuild_sequence_view(participant.subject) + sequence_views = await self.graph_db.select( + subject=participant.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW + ) + if not sequence_views: # external class + return + rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW) + prompt = f"```mermaid\n{sequence_views[0].object_}\n```\n---\n```mermaid\n{rows[0].object_}\n```" + + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a tool to merge sequence diagrams into one.", + "Participants with the same name are considered identical.", + "Return the merged Mermaid sequence diagram in a markdown code block format.", + ], + stream=False, + ) + + sequence_view = rsp.removeprefix("```mermaid").removesuffix("```") + rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW) + for r in rows: + await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_) + await self.graph_db.insert( + subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view + ) + await self.graph_db.insert( + subject=entry.subject, + predicate=GraphKeyword.HAS_SEQUENCE_VIEW_VER, + object_=concat_namespace(datetime.now().strftime("%Y%m%d%H%M%S%f")[:-3], add_affix(sequence_view)), + ) + await self.graph_db.insert( + subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject) + ) + await self._save_sequence_view(subject=entry.subject, content=sequence_view) + + async def _save_sequence_view(self, subject: str, content: str): + pattern = re.compile(r"[^a-zA-Z0-9]") + name = re.sub(pattern, "_", subject) + filename = 
Path(name).with_suffix(".sequence_diagram.mmd") + await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content) + + async def _search_participants(self, filename: str) -> Set: + content = await self._get_source_code(filename) + + rsp = await self.llm.aask( + msg=content, + system_msgs=[ + "You are a tool for listing all class names used in a source file.", + "Return a markdown JSON object with: " + '- a "class_names" key containing the list of class names used in the file; ' + '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.', + ], + ) + + class _Data(BaseModel): + class_names: List[str] + reasons: List + + json_blocks = parse_json_code_block(rsp) + data = _Data.model_validate_json(json_blocks[0]) + return set(data.class_names) diff --git a/metagpt/actions/requirement_analysis/.DS_Store b/metagpt/actions/requirement_analysis/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..6e14c7e5b03715a3765f3df7a79fed65a921ae37 Binary files /dev/null and b/metagpt/actions/requirement_analysis/.DS_Store differ diff --git a/metagpt/actions/requirement_analysis/__init__.py b/metagpt/actions/requirement_analysis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d196bafeeb1e83d799736a812034765f1ed82899 --- /dev/null +++ b/metagpt/actions/requirement_analysis/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : __init__.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from metagpt.actions.requirement_analysis.evaluate_action import EvaluationData, EvaluateAction + +__all__ = [EvaluationData, EvaluateAction] diff --git a/metagpt/actions/requirement_analysis/evaluate_action.py b/metagpt/actions/requirement_analysis/evaluate_action.py new file mode 100644 index 0000000000000000000000000000000000000000..376c73f2c978fd8b231cdd9a9149e77862bdc470 --- /dev/null +++ b/metagpt/actions/requirement_analysis/evaluate_action.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : evaluate_action.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from typing import Optional + +from pydantic import BaseModel +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.utils.common import CodeParser, general_after_log, to_markdown_code_block + + +class EvaluationData(BaseModel): + """Model to represent evaluation data. + + Attributes: + is_pass (bool): Indicates if the evaluation passed or failed. + conclusion (Optional[str]): Conclusion or remarks about the evaluation. + """ + + is_pass: bool + conclusion: Optional[str] = None + + +class EvaluateAction(Action): + """The base class for an evaluation action. + + This class provides methods to evaluate prompts using a specified language model. + """ + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _evaluate(self, prompt: str) -> (bool, str): + """Evaluates a given prompt. + + Args: + prompt (str): The prompt to be evaluated. + + Returns: + tuple: A tuple containing: + - bool: Indicates if the evaluation passed. 
+ - str: The JSON string containing the evaluation data. + """ + rsp = await self.llm.aask(prompt) + json_data = CodeParser.parse_code(text=rsp, lang="json") + data = EvaluationData.model_validate_json(json_data) + return data.is_pass, to_markdown_code_block(val=json_data, type_="json") + + async def _vote(self, prompt: str) -> EvaluationData: + """Evaluates a prompt multiple times and returns the consensus. + + Args: + prompt (str): The prompt to be evaluated. + + Returns: + EvaluationData: An object containing the evaluation result and a summary of evaluations. + """ + evaluations = {} + for i in range(3): + vote, evaluation = await self._evaluate(prompt) + val = evaluations.get(vote, []) + val.append(evaluation) + if len(val) > 1: + return EvaluationData(is_pass=vote, conclusion="\n".join(val)) + evaluations[vote] = val diff --git a/metagpt/actions/requirement_analysis/framework/__init__.py b/metagpt/actions/requirement_analysis/framework/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..968effd862ffa6187cf034e981186eeecb47207b --- /dev/null +++ b/metagpt/actions/requirement_analysis/framework/__init__.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : __init__.py +@Desc : The implementation of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +import json +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional, Union, List + +from pydantic import BaseModel + +from metagpt.actions.requirement_analysis.framework.evaluate_framework import EvaluateFramework +from metagpt.actions.requirement_analysis.framework.write_framework import WriteFramework +from metagpt.config2 import config +from metagpt.utils.common import awrite + + +async def save_framework( + dir_data: str, trd: Optional[str] = None, output_dir: Optional[Union[str, Path]] = None +) -> List[str]: + """ + Saves framework data to files based on input JSON data and optionally saves a TRD (technical requirements document). + + Args: + dir_data (str): JSON data in string format enclosed in triple backticks ("```json" "...data..." "```"). + trd (str, optional): Technical requirements document content to be saved. Defaults to None. + output_dir (Union[str, Path], optional): Output directory path where files will be saved. If not provided, + a default directory is created based on the current timestamp and a random UUID suffix. + + Returns: + List[str]: List of file paths where data was saved. + + Raises: + Any exceptions raised during file writing operations. + + Notes: + - JSON data should be provided in the format "```json ...data... ```". + - The function ensures that paths and filenames are correctly formatted and creates necessary directories. + + Example: + ```python + dir_data = "```json\n[{\"path\": \"/folder\", \"filename\": \"file1.txt\", \"content\": \"Some content\"}]\n```" + trd = "Technical requirements document content." 
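The `_vote` loop above is a compact two-out-of-three majority: the verdict is boolean and there are three trials, so by the pigeonhole principle some verdict must repeat by the third call, and the method always returns from inside the loop. The same control flow reduced to its skeleton, with a random stand-in for the LLM-backed `_evaluate`:

```python
import asyncio
import random


async def evaluate_once(prompt: str) -> tuple[bool, str]:
    # Stand-in for an LLM evaluation returning (verdict, evidence).
    verdict = random.choice([True, False])
    return verdict, f"this run said {verdict}"


async def vote(prompt: str) -> bool:
    evaluations: dict[bool, list[str]] = {}
    for _ in range(3):
        verdict, evidence = await evaluate_once(prompt)
        evaluations.setdefault(verdict, []).append(evidence)
        if len(evaluations[verdict]) > 1:
            # Two matching verdicts out of at most three trials: majority found.
            return verdict
    raise AssertionError("unreachable: a boolean must repeat within three trials")


print(asyncio.run(vote("evaluate this TRD")))
```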
+ output_dir = '/path/to/output/dir' + saved_files = await save_framework(dir_data, trd, output_dir) + print(saved_files) + ``` + """ + output_dir = ( + Path(output_dir) + if output_dir + else config.workspace.path / (datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:8]) + ) + output_dir.mkdir(parents=True, exist_ok=True) + + json_data = dir_data.removeprefix("```json").removesuffix("```") + items = json.loads(json_data) + + class Data(BaseModel): + path: str + filename: str + content: str + + if trd: + pathname = output_dir / "TRD.md" + await awrite(filename=pathname, data=trd) + + files = [] + for i in items: + v = Data.model_validate(i) + if v.path and v.path[0] == "/": + v.path = "." + v.path + pathname = output_dir / v.path + pathname.mkdir(parents=True, exist_ok=True) + pathname = pathname / v.filename + await awrite(filename=pathname, data=v.content) + files.append(str(pathname)) + return files + + +__all__ = ["WriteFramework", "EvaluateFramework"] diff --git a/metagpt/actions/requirement_analysis/framework/evaluate_framework.py b/metagpt/actions/requirement_analysis/framework/evaluate_framework.py new file mode 100644 index 0000000000000000000000000000000000000000..2f923965836385fc95b424625a1e0ecbe0f60d6d --- /dev/null +++ b/metagpt/actions/requirement_analysis/framework/evaluate_framework.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : evaluate_framework.py +@Desc : The implementation of Chapter 2.1.8 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" + +from metagpt.actions.requirement_analysis import EvaluateAction, EvaluationData +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class EvaluateFramework(EvaluateAction): + """EvaluateFramework deals with the following situations: + 1. Given a TRD and the software framework based on the TRD, evaluate the quality of the software framework. + """ + + async def run( + self, + *, + use_case_actors: str, + trd: str, + acknowledge: str, + legacy_output: str, + additional_technical_requirements: str, + ) -> EvaluationData: + """ + Run the evaluation of the software framework based on the provided TRD and related parameters. + + Args: + use_case_actors (str): A description of the actors involved in the use case. + trd (str): The Technical Requirements Document (TRD) that outlines the requirements for the software framework. + acknowledge (str): External acknowledgement information related to the framework. + legacy_output (str): The previous version of the software framework returned by `WriteFramework`. + additional_technical_requirements (str): Additional technical requirements that need to be considered during evaluation. + + Returns: + EvaluationData: An object containing the results of the evaluation. + + Example: + >>> evaluate_framework = EvaluateFramework() + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> trd = "## TRD\\n..." + >>> acknowledge = "## Interfaces\\n..." + >>> framework = '{"path":"balabala", "filename":"...", ...' + >>> constraint = "Using Java language, ..."
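One detail in `save_framework` deserves a note: with `pathlib`, joining an absolute right-hand path discards the left operand entirely, so an LLM-supplied `"path": "/folder"` would escape `output_dir`. Prefixing a dot, as the loop above does, keeps every write inside the workspace; a quick check (POSIX paths assumed, hypothetical directory names):

```python
from pathlib import Path

output_dir = Path("/tmp/workspace/20240613-abc123")

# Joining an absolute path replaces output_dir outright:
print(output_dir / "/folder/file1.txt")          # /folder/file1.txt
# The "." prefix used above makes it relative again:
print(output_dir / ("." + "/folder/file1.txt"))  # /tmp/workspace/20240613-abc123/folder/file1.txt
```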
+ >>> evaluation = await evaluate_framework.run( + >>> use_case_actors=use_case_actors, + >>> trd=trd, + >>> acknowledge=acknowledge, + >>> legacy_output=framework, + >>> additional_technical_requirements=constraint, + >>> ) + >>> is_pass = evaluation.is_pass + >>> print(is_pass) + True + >>> evaluation_conclusion = evaluation.conclusion + >>> print(evaluation_conclusion) + Balabala... + """ + prompt = PROMPT.format( + use_case_actors=use_case_actors, + trd=to_markdown_code_block(val=trd), + acknowledge=to_markdown_code_block(val=acknowledge), + legacy_output=to_markdown_code_block(val=legacy_output), + additional_technical_requirements=to_markdown_code_block(val=additional_technical_requirements), + ) + return await self._vote(prompt) + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## Legacy TRD +{trd} + +## Acknowledge +{acknowledge} + +## Legacy Outputs +{legacy_output} + +## Additional Technical Requirements +{additional_technical_requirements} + +--- +You are a tool that evaluates the quality of framework code based on the TRD content; +You need to refer to the content of the "Legacy TRD" section to check for any errors or omissions in the framework code found in "Legacy Outputs"; +The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +Information about the external system missing from the "Legacy TRD" can be found in the "Acknowledge" section; +Which interfaces defined in "Acknowledge" are used in the "Legacy TRD"? +Do not implement an interface from the "Acknowledge" section unless it is used in the "Legacy TRD"; you can check whether two interfaces are the same by looking at their ID or URL; +Parts not mentioned in the "Legacy TRD" will be handled by other TRDs; therefore, processes not present in the "Legacy TRD" are considered ready; +"Additional Technical Requirements" specifies the additional technical requirements that the generated software framework code must meet; +Do the parameters of the interface of the external system used in the code comply with its specifications in 'Acknowledge'? +Are any necessary configuration files missing? +Return a markdown JSON object with: +- an "issues" key containing a string list of natural text about the issues that need to be addressed, found in the "Legacy Outputs" if any exist; each issue found must provide a detailed description and include reasons; +- a "conclusion" key containing the evaluation conclusion; +- a "misalignment" key containing a natural text string list detailing the misalignment with "Legacy TRD"; +- an "is_pass" key containing a true boolean value if there are no issues in the "Legacy Outputs"; +""" diff --git a/metagpt/actions/requirement_analysis/framework/write_framework.py b/metagpt/actions/requirement_analysis/framework/write_framework.py new file mode 100644 index 0000000000000000000000000000000000000000..2aa03f4473dc02161fd3481238fc83e67003a082 --- /dev/null +++ b/metagpt/actions/requirement_analysis/framework/write_framework.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : write_framework.py +@Desc : The implementation of Chapter 2.1.8 of RFC243.
https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +import json + +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class WriteFramework(Action): + """WriteFramework deal with the following situations: + 1. Given a TRD, write out the software framework. + """ + + async def run( + self, + *, + use_case_actors: str, + trd: str, + acknowledge: str, + legacy_output: str, + evaluation_conclusion: str, + additional_technical_requirements: str, + ) -> str: + """ + Run the action to generate a software framework based on the provided TRD and related information. + + Args: + use_case_actors (str): Description of the use case actors involved. + trd (str): Technical Requirements Document detailing the requirements. + acknowledge (str): External acknowledgements or acknowledgements required. + legacy_output (str): Previous version of the software framework returned by `WriteFramework.run`. + evaluation_conclusion (str): Conclusion from the evaluation of the requirements. + additional_technical_requirements (str): Any additional technical requirements. + + Returns: + str: The generated software framework as a string. + + Example: + >>> write_framework = WriteFramework() + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> trd = "## TRD\\n..." + >>> acknowledge = "## Interfaces\\n..." + >>> legacy_output = '{"path":"balabala", "filename":"...", ...' + >>> evaluation_conclusion = "Balabala..." + >>> constraint = "Using Java language, ..." + >>> framework = await write_framework.run( + >>> use_case_actors=use_case_actors, + >>> trd=trd, + >>> acknowledge=acknowledge, + >>> legacy_output=framework, + >>> evaluation_conclusion=evaluation_conclusion, + >>> additional_technical_requirements=constraint, + >>> ) + >>> print(framework) + {"path":"balabala", "filename":"...", ... + + """ + acknowledge = await self._extract_external_interfaces(trd=trd, knowledge=acknowledge) + prompt = PROMPT.format( + use_case_actors=use_case_actors, + trd=to_markdown_code_block(val=trd), + acknowledge=to_markdown_code_block(val=acknowledge), + legacy_output=to_markdown_code_block(val=legacy_output), + evaluation_conclusion=evaluation_conclusion, + additional_technical_requirements=to_markdown_code_block(val=additional_technical_requirements), + ) + return await self._write(prompt) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write(self, prompt: str) -> str: + rsp = await self.llm.aask(prompt) + # Do not use `CodeParser` here. 
+ tags = ["```json", "```"] + bix = rsp.find(tags[0]) + eix = rsp.rfind(tags[1]) + if bix >= 0: + rsp = rsp[bix : eix + len(tags[1])] + json_data = rsp.removeprefix("```json").removesuffix("```") + json.loads(json_data) # validate + return json_data + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _extract_external_interfaces(self, trd: str, knowledge: str) -> str: + prompt = f"## TRD\n{to_markdown_code_block(val=trd)}\n\n## Knowledge\n{to_markdown_code_block(val=knowledge)}\n" + rsp = await self.llm.aask( + prompt, + system_msgs=[ + "You are a tool that removes impurities from articles; you can remove irrelevant content from articles.", + 'Identify which interfaces are used in "TRD"? Remove the relevant content of the interfaces NOT used in "TRD" from "Knowledge" and return the simplified content of "Knowledge".', + ], + ) + return rsp + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## TRD +{trd} + +## Acknowledge +{acknowledge} + +## Legacy Outputs +{legacy_output} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Additional Technical Requirements +{additional_technical_requirements} + +--- +You are a tool that generates software framework code based on TRD. +The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +The descriptions of the interfaces of the external system used in the "TRD" can be found in the "Acknowledge" section; Do not implement the interface of the external system in "Acknowledge" section until it is used in "TRD"; +"Legacy Outputs" contains the software framework code generated by you last time, which you can improve by addressing the issues raised in "Evaluation Conclusion"; +"Additional Technical Requirements" specifies the additional technical requirements that the generated software framework code must meet; +Develop the software framework based on the "TRD", the output files should include: +- The `README.md` file should include: + - The folder structure diagram of the entire project; + - Correspondence between classes, interfaces, and functions with the content in the "TRD" section; + - Prerequisites if necessary; + - Installation if necessary; + - Configuration if necessary; + - Usage if necessary; +- The `CLASS.md` file should include the class diagram in PlantUML format based on the "TRD"; +- The `SEQUENCE.md` file should include the sequence diagram in PlantUML format based on the "TRD"; +- The source code files that implement the "TRD" and "Additional Technical Requirements"; do not add comments to source code files; +- The configuration files that required by the source code files, "TRD" and "Additional Technical Requirements"; + +Return a markdown JSON object list, each object containing: +- a "path" key with a value specifying its path; +- a "filename" key with a value specifying its file name; +- a "content" key with a value containing its file content; +""" diff --git a/metagpt/actions/requirement_analysis/requirement/__init__.py b/metagpt/actions/requirement_analysis/requirement/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/actions/requirement_analysis/requirement/pic2txt.py b/metagpt/actions/requirement_analysis/requirement/pic2txt.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f236dacb1650bdf221688b27b1248e223695ce --- 
/dev/null +++ b/metagpt/actions/requirement_analysis/requirement/pic2txt.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/27 +@Author : mashenquan +@File : pic2txt.py +""" +import json +from pathlib import Path +from typing import List + +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import encode_image, general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class Pic2Txt(Action): + """Pic2Txt deal with the following situations: + Given some pictures depicting user requirements alongside contextual description, write out the intact textual user requirements. + """ + + async def run( + self, + *, + image_paths: List[str], + textual_user_requirement: str = "", + legacy_output: str = "", + evaluation_conclusion: str = "", + additional_technical_requirements: str = "", + ) -> str: + """ + Given some pictures depicting user requirements alongside contextual description, write out the intact textual user requirements + + Args: + image_paths (List[str]): A list of file paths to the input image(s) depicting user requirements. + textual_user_requirement (str, optional): Textual user requirement that alongside the given images, if any. + legacy_output (str, optional): The intact textual user requirements generated by you last time, if any. + evaluation_conclusion (str, optional): Conclusion or evaluation based on the processed requirements. + additional_technical_requirements (str, optional): Any supplementary technical details relevant to the process. + + Returns: + str: Textual representation of user requirements extracted from the provided image(s). + + Raises: + ValueError: If image_paths list is empty. + OSError: If there is an issue accessing or reading the image files. + + Example: + >>> images = ["requirements/pic/1.png", "requirements/pic/2.png", "requirements/pic/3.png"] + >>> textual_user_requirements = "User requirement paragraph 1 ..., ![](1.png). paragraph 2...![](2.png)..." + >>> action = Pic2Txt() + >>> intact_textual_user_requirements = await action.run(image_paths=images, textual_user_requirement=textual_user_requirements) + >>> print(intact_textual_user_requirements) + "User requirement paragraph 1 ..., ![...](1.png) This picture describes... paragraph 2...![...](2.png)..." 
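Before `Pic2Txt` can merge anything, each image is read from disk, base64-encoded, and sent to a vision-capable model; the resulting captions are keyed by bare filename so they can be matched to the `![](1.png)` references in the textual fragments. A sketch of that loop with the encoder re-implemented inline; the real `encode_image` is imported from `metagpt.utils.common`, and `aask` here is a stand-in for `self.llm.aask`:

```python
import base64
import json
from pathlib import Path


def encode_image(path: Path) -> str:
    # Minimal stand-in for the imported helper: raw file bytes as base64 text.
    return base64.b64encode(path.read_bytes()).decode("utf-8")


async def describe_images(image_paths: list[str], aask) -> str:
    descriptions = {}
    for i in image_paths:
        filename = Path(i)
        rsp = await aask(
            "Generate a paragraph of text based on the content of the image...",
            images=encode_image(filename),
        )
        descriptions[filename.name] = rsp  # keyed by "1.png", not the full path
    return json.dumps(descriptions)
```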
+ + """ + descriptions = {} + for i in image_paths: + filename = Path(i) + base64_image = encode_image(filename) + rsp = await self._pic2txt( + "Generate a paragraph of text based on the content of the image, the language of the text is consistent with the language in the image.", + base64_image=base64_image, + ) + descriptions[filename.name] = rsp + + prompt = PROMPT.format( + textual_user_requirement=textual_user_requirement, + acknowledge=to_markdown_code_block(val=json.dumps(descriptions), type_="json"), + legacy_output=to_markdown_code_block(val=legacy_output), + evaluation_conclusion=evaluation_conclusion, + additional_technical_requirements=to_markdown_code_block(val=additional_technical_requirements), + ) + return await self._write(prompt) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write(self, prompt: str) -> str: + rsp = await self.llm.aask(prompt) + return rsp + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _pic2txt(self, prompt: str, base64_image: str) -> str: + rsp = await self.llm.aask(prompt, images=base64_image) + return rsp + + +PROMPT = """ +## Textual User Requirements +{textual_user_requirement} + +## Acknowledge +{acknowledge} + +## Legacy Outputs +{legacy_output} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Additional Technical Requirements +{additional_technical_requirements} + +--- +You are a tool that generates an intact textual user requirements given a few of textual fragments of user requirements and some fragments of UI pictures. +The content of "Textual User Requirements" provides a few of textual fragments of user requirements; +The content of "Acknowledge" provides the descriptions of pictures used in "Textual User Requirements"; +"Legacy Outputs" contains the intact textual user requirements generated by you last time, which you can improve by addressing the issues raised in "Evaluation Conclusion"; +"Additional Technical Requirements" specifies the additional technical requirements that the generated textual user requirements must meet; +You need to merge the text content of the corresponding image in the "Acknowledge" into the "Textual User Requirements" to generate a complete, natural and coherent description of the user requirements; +Return the intact textual user requirements according to the given fragments of the user requirement of "Textual User Requirements" and the UI pictures; +""" diff --git a/metagpt/actions/requirement_analysis/trd/__init__.py b/metagpt/actions/requirement_analysis/trd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4603532c423b5ccb7ac02e4c21f46b5305998fcf --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : __init__.py +@Desc : The implementation of RFC243. 
https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" + + +from metagpt.actions.requirement_analysis.trd.detect_interaction import DetectInteraction +from metagpt.actions.requirement_analysis.trd.evaluate_trd import EvaluateTRD +from metagpt.actions.requirement_analysis.trd.write_trd import WriteTRD +from metagpt.actions.requirement_analysis.trd.compress_external_interfaces import CompressExternalInterfaces + +__all__ = [CompressExternalInterfaces, DetectInteraction, WriteTRD, EvaluateTRD] diff --git a/metagpt/actions/requirement_analysis/trd/compress_external_interfaces.py b/metagpt/actions/requirement_analysis/trd/compress_external_interfaces.py new file mode 100644 index 0000000000000000000000000000000000000000..abaf6fc307cfbc4c964a46be50060173032085b7 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/compress_external_interfaces.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : compress_external_interfaces.py +@Desc : The implementation of Chapter 2.1.5 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log + + +@register_tool(include_functions=["run"]) +class CompressExternalInterfaces(Action): + """CompressExternalInterfaces deal with the following situations: + 1. Given a natural text of acknowledgement, it extracts and compresses the information about external system interfaces. + """ + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def run( + self, + *, + acknowledge: str, + ) -> str: + """ + Extracts and compresses information about external system interfaces from a given acknowledgement text. + + Args: + acknowledge (str): A natural text of acknowledgement containing details about external system interfaces. + + Returns: + str: A compressed version of the information about external system interfaces. + + Example: + >>> compress_acknowledge = CompressExternalInterfaces() + >>> acknowledge = "## Interfaces\\n..." + >>> available_external_interfaces = await compress_acknowledge.run(acknowledge=acknowledge) + >>> print(available_external_interfaces) + ```json\n[\n{\n"id": 1,\n"inputs": {... + """ + return await self.llm.aask( + msg=acknowledge, + system_msgs=[ + "Extracts and compresses the information about external system interfaces.", + "Return a markdown JSON list of objects, each object containing:\n" + '- an "id" key containing the interface id;\n' + '- an "inputs" key containing a dict of input parameters that consist of name and description pairs;\n' + '- an "outputs" key containing a dict of returns that consist of name and description pairs;\n', + ], + ) diff --git a/metagpt/actions/requirement_analysis/trd/detect_interaction.py b/metagpt/actions/requirement_analysis/trd/detect_interaction.py new file mode 100644 index 0000000000000000000000000000000000000000..b7719319417bcd2993f262224df72abac27cd143 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/detect_interaction.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : detect_interaction.py +@Desc : The implementation of Chapter 2.1.6 of RFC243. 
https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class DetectInteraction(Action): + """DetectInteraction deal with the following situations: + 1. Given a natural text of user requirements, it identifies the interaction events and the participants of those interactions from the original text. + """ + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def run( + self, + *, + user_requirements: str, + use_case_actors: str, + legacy_interaction_events: str, + evaluation_conclusion: str, + ) -> str: + """ + Identifies interaction events and participants from the user requirements. + + Args: + user_requirements (str): A natural language text detailing the user's requirements. + use_case_actors (str): A description of the actors involved in the use case. + legacy_interaction_events (str): The previous version of the interaction events identified by you. + evaluation_conclusion (str): The external evaluation conclusions regarding the interactions events identified by you. + + Returns: + str: A string summarizing the identified interaction events and their participants. + + Example: + >>> detect_interaction = DetectInteraction() + >>> user_requirements = "User requirements 1. ..." + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> previous_version_interaction_events = "['interaction ...', ...]" + >>> evaluation_conclusion = "Issues: ..." + >>> interaction_events = await detect_interaction.run( + >>> user_requirements=user_requirements, + >>> use_case_actors=use_case_actors, + >>> legacy_interaction_events=previous_version_interaction_events, + >>> evaluation_conclusion=evaluation_conclusion, + >>> ) + >>> print(interaction_events) + "['interaction ...', ...]" + """ + msg = PROMPT.format( + use_case_actors=use_case_actors, + original_user_requirements=to_markdown_code_block(val=user_requirements), + previous_version_of_interaction_events=legacy_interaction_events, + the_evaluation_conclusion_of_previous_version_of_trd=evaluation_conclusion, + ) + return await self.llm.aask(msg=msg) + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## User Requirements +{original_user_requirements} + +## Legacy Interaction Events +{previous_version_of_interaction_events} + +## Evaluation Conclusion +{the_evaluation_conclusion_of_previous_version_of_trd} + +--- +You are a tool for capturing interaction events. +"Actor, System, External System" provides the possible participants of the interaction event; +"Legacy Interaction Events" is the contents of the interaction events that you output earlier; +Some descriptions in the "Evaluation Conclusion" relate to the content of "User Requirements", and these descriptions in the "Evaluation Conclusion" address some issues regarding the content of "Legacy Interaction Events"; +You need to capture the interaction events occurring in the description within the content of "User Requirements" word-for-word, including: +1. Who is interacting with whom. An interaction event has a maximum of 2 participants. 
If there are multiple participants, it indicates that multiple events are combined into one event and should be further split; +2. When an interaction event occurs, who is the initiator? What data did the initiator enter? +3. What data does the interaction event ultimately return according to the "User Requirements"? + +You can check the data flow described in the "User Requirements" to see if there are any missing interaction events; +Return a markdown JSON object list, each object of the list containing: +- a "name" key containing the name of the interaction event; +- a "participants" key containing a string list of the names of the two participants; +- an "initiator" key containing the name of the participant who initiates the interaction; +- an "input" key containing a natural text description of the input data; +""" diff --git a/metagpt/actions/requirement_analysis/trd/evaluate_trd.py b/metagpt/actions/requirement_analysis/trd/evaluate_trd.py new file mode 100644 index 0000000000000000000000000000000000000000..5c256ed075bf29957287f2ae3ac469eb5f0d80cb --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/evaluate_trd.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : evaluate_trd.py +@Desc : The implementation of Chapter 2.1.6~2.1.7 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" + +from metagpt.actions.requirement_analysis import EvaluateAction, EvaluationData +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class EvaluateTRD(EvaluateAction): + """EvaluateTRD deals with the following situations: + 1. Given a TRD, it evaluates the quality and returns a conclusion. + """ + + async def run( + self, + *, + user_requirements: str, + use_case_actors: str, + trd: str, + interaction_events: str, + legacy_user_requirements_interaction_events: str = "", + ) -> EvaluationData: + """ + Evaluates the given TRD based on user requirements, use case actors, interaction events, and optionally external legacy interaction events. + + Args: + user_requirements (str): The requirements provided by the user. + use_case_actors (str): The actors involved in the use case. + trd (str): The TRD (Technical Requirements Document) to be evaluated. + interaction_events (str): The interaction events related to the user requirements and the TRD. + legacy_user_requirements_interaction_events (str, optional): External legacy interaction events tied to the user requirements. Defaults to an empty string. + + Returns: + EvaluationData: The conclusion of the TRD evaluation. + + Example: + >>> evaluate_trd = EvaluateTRD() + >>> user_requirements = "User requirements 1. ..." + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> trd = "## TRD\\n..." + >>> interaction_events = "['interaction ...', ...]" + >>> evaluation_conclusion = "Issues: ..." + >>> legacy_user_requirements_interaction_events = ["user requirements 1. ...", ...]
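The object shape `DetectInteraction` asks for (name, participants, initiator, input) is straightforward to validate downstream; a hedged sketch of such a consumer, where the model class is hypothetical rather than part of the module above:

```python
from typing import List

from pydantic import BaseModel, TypeAdapter, field_validator


class InteractionEvent(BaseModel):
    """Hypothetical mirror of the JSON objects DetectInteraction requests."""

    name: str
    participants: List[str]
    initiator: str
    input: str

    @field_validator("participants")
    @classmethod
    def at_most_two(cls, v: List[str]) -> List[str]:
        if len(v) > 2:
            raise ValueError("an interaction event has a maximum of 2 participants")
        return v


raw = (
    '[{"name": "submit score", "participants": ["snake game", "game center"],'
    ' "initiator": "snake game", "input": "the final score"}]'
)
events = TypeAdapter(List[InteractionEvent]).validate_json(raw)
print(events[0].initiator)  # snake game
```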
+ >>> evaluation = await evaluate_trd.run( + >>> user_requirements=user_requirements, + >>> use_case_actors=use_case_actors, + >>> trd=trd, + >>> interaction_events=interaction_events, + >>> legacy_user_requirements_interaction_events=str(legacy_user_requirements_interaction_events), + >>> ) + >>> is_pass = evaluation.is_pass + >>> print(is_pass) + True + >>> evaluation_conclusion = evaluation.conclusion + >>> print(evaluation_conclusion) + ## Conclusion\n balabalabala... + + """ + prompt = PROMPT.format( + use_case_actors=use_case_actors, + user_requirements=to_markdown_code_block(val=user_requirements), + trd=to_markdown_code_block(val=trd), + legacy_user_requirements_interaction_events=legacy_user_requirements_interaction_events, + interaction_events=interaction_events, + ) + return await self._vote(prompt) + + +PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## User Requirements +{user_requirements} + +## TRD Design +{trd} + +## External Interaction Events +{legacy_user_requirements_interaction_events} + +## Interaction Events +{legacy_user_requirements_interaction_events} +{interaction_events} + +--- +You are a tool to evaluate the TRD design. +"Actor, System, External System" provides all the possible participants in interaction events; +"User Requirements" provides the original requirements description; any parts not mentioned in this description will be handled by other modules, so do not fabricate requirements; +"External Interaction Events" is provided by an external module for your use; its content is also included in the "Interaction Events" section; the content in "External Interaction Events" can be assumed to be problem-free; +"External Interaction Events" provides some identified interaction events and the interacting participants based on part of the content of the "User Requirements"; +"Interaction Events" provides some identified interaction events and the interacting participants based on the content of the "User Requirements"; +"TRD Design" provides a comprehensive design of the implementation steps for the original requirements, incorporating the interaction events from "Interaction Events" and adding additional steps to connect the complete upstream and downstream data flows; +In order to integrate the full upstream and downstream data flow, the "TRD Design" allows for the inclusion of steps that do not appear in the original requirements description, but do not conflict with those explicitly described in the "User Requirements"; +Which interactions from "Interaction Events" correspond to which steps in "TRD Design"? Please provide reasons. +Which aspects of "TRD Design" and "Interaction Events" do not align with the descriptions in "User Requirements"? Please provide detailed descriptions and reasons. +If the descriptions in "User Requirements" are divided into multiple steps in "TRD Design" and "Interaction Events," it can be considered compliant with the descriptions in "User Requirements" as long as it does not conflict with them; +There is a possibility of missing details in the descriptions of "User Requirements".
Any additional steps in "TRD Design" and "Interaction Events" are considered compliant with "User Requirements" as long as they do not conflict with the descriptions provided in "User Requirements"; +If there are interaction events with external systems in "TRD Design", you must explicitly specify the ID of the external interface to use for the interaction events, the input and output parameters of the used external interface must explictly match the input and output of the interaction event; +Does the sequence of steps in "Interaction Events" cause performance or cost issues? Please provide detailed descriptions and reasons; +If each step of "TRD Design" has input data, its input data is provided either by the output of the previous steps or by participants of "Actor, System, External System", and there should be no passive data; +Return a markdown JSON object with: +- an "issues" key containing a string list of natural text about the issues that need to be addressed, found in the "TRD Design" if any exist, each issue found must provide a detailed description and include reasons; +- a "conclusion" key containing the evaluation conclusion; +- a "correspondence_between" key containing the judgement detail of the natural text string list about the correspondence between "Interaction Events" and "TRD Design" steps; +- a "misalignment" key containing the judgement detail of the natural text string list about the misalignment with "User Requirements"; +- a "is_pass" key containing a true boolean value if there is not any issue in the "TRD Design"; +""" diff --git a/metagpt/actions/requirement_analysis/trd/write_trd.py b/metagpt/actions/requirement_analysis/trd/write_trd.py new file mode 100644 index 0000000000000000000000000000000000000000..bad93ea766e1157330b5fc752c6bc13412b8ec90 --- /dev/null +++ b/metagpt/actions/requirement_analysis/trd/write_trd.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/6/13 +@Author : mashenquan +@File : write_trd.py +@Desc : The implementation of Chapter 2.1.6~2.1.7 of RFC243. https://deepwisdom.feishu.cn/wiki/QobGwPkImijoyukBUKHcrYetnBb +""" +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import general_after_log, to_markdown_code_block + + +@register_tool(include_functions=["run"]) +class WriteTRD(Action): + """WriteTRD deal with the following situations: + 1. Given some new user requirements, write out a new TRD(Technical Requirements Document). + 2. Given some incremental user requirements, update the legacy TRD. + """ + + async def run( + self, + *, + user_requirements: str = "", + use_case_actors: str, + available_external_interfaces: str, + evaluation_conclusion: str = "", + interaction_events: str, + previous_version_trd: str = "", + legacy_user_requirements: str = "", + legacy_user_requirements_trd: str = "", + legacy_user_requirements_interaction_events: str = "", + ) -> str: + """ + Handles the writing or updating of a Technical Requirements Document (TRD) based on user requirements. + + Args: + user_requirements (str): The new/incremental user requirements. + use_case_actors (str): Description of the actors involved in the use case. + available_external_interfaces (str): List of available external interfaces. + evaluation_conclusion (str, optional): The conclusion of the evaluation of the TRD written by you. Defaults to an empty string. 
+ interaction_events (str): The interaction events related to the user requirements that you are handling. + previous_version_trd (str, optional): The previous version of the TRD written by you, for updating. + legacy_user_requirements (str, optional): Existing user requirements handled by an external object for your use. Defaults to an empty string. + legacy_user_requirements_trd (str, optional): The TRD associated with the existing user requirements handled by an external object for your use. Defaults to an empty string. + legacy_user_requirements_interaction_events (str, optional): Interaction events related to the existing user requirements handled by an external object for your use. Defaults to an empty string. + + Returns: + str: The newly created or updated TRD written by you. + + Example: + >>> # Given new user requirements, write out a new TRD. + >>> user_requirements = "Write a 'snake game' TRD." + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> available_external_interfaces = "The available external interfaces returned by `CompressExternalInterfaces.run` are ..." + >>> previous_version_trd = "TRD ..." # The last version of the TRD written out, if there is one. + >>> evaluation_conclusion = "Conclusion ..." # The conclusion returned by `EvaluateTRD.run`, if there is one. + >>> interaction_events = "Interaction ..." # The interaction events returned by `DetectInteraction.run`. + >>> write_trd = WriteTRD() + >>> new_version_trd = await write_trd.run( + >>> user_requirements=user_requirements, + >>> use_case_actors=use_case_actors, + >>> available_external_interfaces=available_external_interfaces, + >>> evaluation_conclusion=evaluation_conclusion, + >>> interaction_events=interaction_events, + >>> previous_version_trd=previous_version_trd, + >>> ) + >>> print(new_version_trd) + ## Technical Requirements Document\n ... + + >>> # Given incremental requirements, update the legacy TRD. + >>> legacy_user_requirements = ["User requirements 1. ...", "User requirements 2. ...", ...] + >>> legacy_user_requirements_trd = "## Technical Requirements Document\\n ..." # The TRD before integrating more user requirements. + >>> legacy_user_requirements_interaction_events = ["The interaction events list of user requirements 1 ...", "The interaction events list of user requirements 2 ...", ...] + >>> use_case_actors = "- Actor: game player;\\n- System: snake game; \\n- External System: game center;" + >>> available_external_interfaces = "The available external interfaces returned by `CompressExternalInterfaces.run` are ..." + >>> increment_requirements = "The incremental user requirements are ..." + >>> evaluation_conclusion = "Conclusion ..." # The conclusion returned by `EvaluateTRD.run`, if there is one. + >>> previous_version_trd = "TRD ..." # The last version of the TRD written out, if there is one.
+ >>> write_trd = WriteTRD() + >>> new_version_trd = await write_trd.run( + >>> user_requirements=increment_requirements, + >>> use_case_actors=use_case_actors, + >>> available_external_interfaces=available_external_interfaces, + >>> evaluation_conclusion=evaluation_conclusion, + >>> interaction_events=interaction_events, + >>> previous_version_trd=previous_version_trd, + >>> legacy_user_requirements=str(legacy_user_requirements), + >>> legacy_user_requirements_trd=legacy_user_requirements_trd, + >>> legacy_user_requirements_interaction_events=str(legacy_user_requirements_interaction_events), + >>> ) + >>> print(new_version_trd) + ## Technical Requirements Document\n ... + """ + if legacy_user_requirements: + return await self._write_incremental_trd( + use_case_actors=use_case_actors, + legacy_user_requirements=legacy_user_requirements, + available_external_interfaces=available_external_interfaces, + legacy_user_requirements_trd=legacy_user_requirements_trd, + legacy_user_requirements_interaction_events=legacy_user_requirements_interaction_events, + incremental_user_requirements=user_requirements, + previous_version_trd=previous_version_trd, + evaluation_conclusion=evaluation_conclusion, + incremental_user_requirements_interaction_events=interaction_events, + ) + return await self._write_new_trd( + use_case_actors=use_case_actors, + original_user_requirement=user_requirements, + available_external_interfaces=available_external_interfaces, + legacy_trd=previous_version_trd, + evaluation_conclusion=evaluation_conclusion, + interaction_events=interaction_events, + ) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write_new_trd( + self, + *, + use_case_actors: str, + original_user_requirement: str, + available_external_interfaces: str, + legacy_trd: str, + evaluation_conclusion: str, + interaction_events: str, + ) -> str: + prompt = NEW_PROMPT.format( + use_case_actors=use_case_actors, + original_user_requirement=to_markdown_code_block(val=original_user_requirement), + available_external_interfaces=available_external_interfaces, + legacy_trd=to_markdown_code_block(val=legacy_trd), + evaluation_conclusion=evaluation_conclusion, + interaction_events=interaction_events, + ) + return await self.llm.aask(prompt) + + @retry( + wait=wait_random_exponential(min=1, max=20), + stop=stop_after_attempt(6), + after=general_after_log(logger), + ) + async def _write_incremental_trd( + self, + *, + use_case_actors: str, + legacy_user_requirements: str, + available_external_interfaces: str, + legacy_user_requirements_trd: str, + legacy_user_requirements_interaction_events: str, + incremental_user_requirements: str, + previous_version_trd: str, + evaluation_conclusion: str, + incremental_user_requirements_interaction_events: str, + ): + prompt = INCREMENTAL_PROMPT.format( + use_case_actors=use_case_actors, + legacy_user_requirements=to_markdown_code_block(val=legacy_user_requirements), + available_external_interfaces=available_external_interfaces, + legacy_user_requirements_trd=to_markdown_code_block(val=legacy_user_requirements_trd), + legacy_user_requirements_interaction_events=legacy_user_requirements_interaction_events, + incremental_user_requirements=to_markdown_code_block(val=incremental_user_requirements), + previous_version_trd=to_markdown_code_block(val=previous_version_trd), + evaluation_conclusion=evaluation_conclusion, + 
incremental_user_requirements_interaction_events=incremental_user_requirements_interaction_events, + ) + return await self.llm.aask(prompt) + + +NEW_PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## User Requirements +{original_user_requirement} + +## Available External Interfaces +{available_external_interfaces} + +## Legacy TRD +{legacy_trd} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Interaction Events +{interaction_events} + +--- +You are a TRD generator. +The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +The content of "Available External Interfaces" provides the candidate steps, along with the inputs and outputs of each step; +"User Requirements" provides the original requirements description, any parts not mentioned in this description will be handled by other modules, so do not fabricate requirements; +"Legacy TRD" provides the old version of the TRD based on the "User Requirements" and can serve as a reference for the new TRD; +"Evaluation Conclusion" provides a summary of the evaluation of the old TRD in the "Legacy TRD" and can serve as a reference for the new TRD; +"Interaction Events" provides some identified interaction events and the interacting participants based on the content of the "User Requirements"; +1. What inputs and outputs are described in the "User Requirements"? +2. How many steps are needed to achieve the inputs and outputs described in the "User Requirements"? Which actors from the "Actor, System, External System" section are involved in each step? What are the inputs and outputs of each step? Where is this output used, for example, as input for which interface or where it is required in the requirements, etc.? +3. Output a complete Technical Requirements Document (TRD): + 3.1. In the description, use the actor and system names defined in the "Actor, System, External System" section to describe the interactors; + 3.2. The content should include the original text of the requirements from "User Requirements"; + 3.3. In the TRD, each step can involve a maximum of two participants. If there are more than two participants, the step needs to be further split; + 3.4. In the TRD, each step must include detailed descriptions, inputs, outputs, participants, initiator, and the rationale for the step's existence. The rationale should reference the original text to justify it, such as specifying which interface requires the output of this step as parameters or where in the requirements this step is mandated, etc.; + 3.5. In the TRD, if you need to call interfaces of external systems, you must explicitly specify the interface IDs of the external systems you want to call; +""" + +INCREMENTAL_PROMPT = """ +## Actor, System, External System +{use_case_actors} + +## Legacy User Requirements +{legacy_user_requirements} + +## Available External Interfaces +{available_external_interfaces} + +## The TRD of Legacy User Requirements +{legacy_user_requirements_trd} + + +## The Interaction Events of Legacy User Requirements +{legacy_user_requirements_interaction_events} + +## Incremental Requirements +{incremental_user_requirements} + +## Legacy TRD +{previous_version_trd} + +## Evaluation Conclusion +{evaluation_conclusion} + +## Interaction Events +{incremental_user_requirements_interaction_events} + +--- +You are a TRD generator. 
+The content of "Actor, System, External System" provides an explanation of actors and systems that appear in UML Use Case diagram; +The content of "Available External Interfaces" provides the candidate steps, along with the inputs and outputs of each step; +"Legacy User Requirements" provides the original requirements description handled by other modules for your use; +"The TRD of Legacy User Requirements" is the TRD generated by other modules based on the "Legacy User Requirements" for your use; +"The Interaction Events of Legacy User Requirements" is the interaction events list generated by other modules based on the "Legacy User Requirements" for your use; +"Incremental Requirements" provides the original requirements description that you need to address, any parts not mentioned in this description will be handled by other modules, so do not fabricate requirements; +The requirements in "Legacy User Requirements" combined with the "Incremental Requirements" form a complete set of requirements, therefore, you need to add the TRD portion of the "Incremental Requirements" to "The TRD of Legacy User Requirements", the added content must not conflict with the original content of "The TRD of Legacy User Requirements"; +"Legacy TRD" provides the old version of the TRD you previously wrote based on the "Incremental Requirements" and can serve as a reference for the new TRD; +"Evaluation Conclusion" provides a summary of the evaluation of the old TRD you generated in the "Legacy TRD", and the identified issues can serve as a reference for the new TRD you create; +"Interaction Events" provides some identified interaction events and the interacting participants based on the content of the "Incremental Requirements"; +1. What inputs and outputs are described in the "Incremental Requirements"? +2. How many steps are needed to achieve the inputs and outputs described in the "Incremental Requirements"? Which actors from the "Actor, System, External System" section are involved in each step? What are the inputs and outputs of each step? Where is this output used, for example, as input for which interface or where it is required in the requirements, etc.? +3. Output a complete Technical Requirements Document (TRD): + 3.1. In the description, use the actor and system names defined in the "Actor, System, External System" section to describe the interactors; + 3.2. The content should include the original text of the requirements from "User Requirements"; + 3.3. In the TRD, each step can involve a maximum of two participants. If there are more than two participants, the step needs to be further split; + 3.4. In the TRD, each step must include detailed descriptions, inputs, outputs, participants, initiator, and the rationale for the step's existence. The rationale should reference the original text to justify it, such as specifying which interface requires the output of this step as parameters or where in the requirements this step is mandated, etc. 
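Both `_write_new_trd` and `_write_incremental_trd` above share the same tenacity retry policy. A minimal standalone sketch of that pattern, where the `ask_llm` coroutine is a hypothetical stand-in for `self.llm.aask`:

```python
from tenacity import retry, stop_after_attempt, wait_random_exponential

@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def ask_llm(prompt: str) -> str:
    # Hypothetical stand-in for self.llm.aask(prompt): any exception raised here
    # triggers a randomized exponential backoff (1-20s), up to 6 attempts in total.
    raise ConnectionError("transient failure")

# asyncio.run(ask_llm("...")) re-raises the last error once all attempts are exhausted.
```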
+ """ diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py new file mode 100644 index 0000000000000000000000000000000000000000..2665844856dc79a9d3ee74f86c7a2715347353ff --- /dev/null +++ b/metagpt/actions/research.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import asyncio +from datetime import datetime +from typing import Any, Callable, Coroutine, Optional, Union + +from pydantic import TypeAdapter, model_validator + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.tools.search_engine import SearchEngine +from metagpt.tools.web_browser_engine import WebBrowserEngine +from metagpt.utils.common import OutputParser +from metagpt.utils.parse_html import WebPage +from metagpt.utils.text import generate_prompt_chunk, reduce_message_length + +LANG_PROMPT = "Please respond in {language}." + +RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \ +written, critically acclaimed, objective and structured reports on the given text.""" + +RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is:\n#TOPIC#\n{topic}" + +SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic for Google search. \ +Your response must be in JSON format, for example: ["keyword1", "keyword2"].""" + +SUMMARIZE_SEARCH_PROMPT = """### Requirements +1. The keywords related to your research topic and the search results are shown in the "Search Result Information" section. +2. Provide up to {decomposition_nums} queries related to your research topic base on the search results. +3. Please respond in the following JSON format: ["query1", "query2", "query3", ...]. + +### Search Result Information +{search_results} +""" + +COLLECT_AND_RANKURLS_PROMPT = """### Topic +{topic} +### Query +{query} + +### The online search results +{results} + +### Requirements +Please remove irrelevant search results that are not related to the query or topic. +If the query is time-sensitive or specifies a certain time frame, please also remove search results that are outdated or outside the specified time frame. Notice that the current time is {time_stamp}. +Then, sort the remaining search results based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. +Provide the ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words. +""" + +WEB_BROWSE_AND_SUMMARIZE_PROMPT = """### Requirements +1. Utilize the text in the "Reference Information" section to respond to the question "{query}". +2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \ +a comprehensive summary of the text. +3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant." +4. Include all relevant factual information, numbers, statistics, etc., if available. + +### Reference Information +{content} +""" + + +CONDUCT_RESEARCH_PROMPT = """### Reference Information +{content} + +### Requirements +Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \ +above. The report must meet the following requirements: + +- Focus on directly addressing the chosen topic. +- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available. 
+- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable. +- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines. +- Include all source URLs in APA format at the end of the report. +""" + + +class CollectLinks(Action): + """Action class to collect links from a search engine.""" + + name: str = "CollectLinks" + i_context: Optional[str] = None + desc: str = "Collect links from a search engine." + search_func: Optional[Any] = None + search_engine: Optional[SearchEngine] = None + rank_func: Optional[Callable[[list[str]], None]] = None + + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.search_engine is None: + self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy) + return self + + async def run( + self, + topic: str, + decomposition_nums: int = 4, + url_per_query: int = 4, + system_text: str | None = None, + ) -> dict[str, list[str]]: + """Run the action to collect links. + + Args: + topic: The research topic. + decomposition_nums: The number of search questions to generate. + url_per_query: The number of URLs to collect per search question. + system_text: The system text. + + Returns: + A dictionary containing the search questions as keys and the collected URLs as values. + """ + system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic) + keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text]) + try: + keywords = OutputParser.extract_struct(keywords, list) + keywords = TypeAdapter(list[str]).validate_python(keywords) + except Exception as e: + logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}") + keywords = [topic] + results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords)) + + def gen_msg(): + while True: + search_results = "\n".join( + f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results) + ) + prompt = SUMMARIZE_SEARCH_PROMPT.format( + decomposition_nums=decomposition_nums, search_results=search_results + ) + yield prompt + remove = max(results, key=len) + remove.pop() + if len(remove) == 0: + break + + model_name = self.config.llm.model + prompt = reduce_message_length(gen_msg(), model_name, system_text, self.config.llm.max_token) + logger.debug(prompt) + queries = await self._aask(prompt, [system_text]) + try: + queries = OutputParser.extract_struct(queries, list) + queries = TypeAdapter(list[str]).validate_python(queries) + except Exception as e: + logger.exception(f"fail to break down the research question due to {e}") + queries = keywords + ret = {} + for query in queries: + ret[query] = await self._search_and_rank_urls(topic, query, url_per_query) + return ret + + async def _search_and_rank_urls( + self, topic: str, query: str, num_results: int = 4, max_num_results: int = None + ) -> list[str]: + """Search and rank URLs based on a query. + + Args: + topic: The research topic. + query: The search query. + num_results: The number of URLs to collect. + max_num_results: The max number of URLs to collect. + + Returns: + A list of ranked URLs. 
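The `gen_msg` generator above works together with `reduce_message_length`: each iteration re-renders the summary prompt after trimming one entry from the longest per-keyword result list, until the prompt fits the model's context window. A self-contained sketch of the same idea, where `render` and `count_tokens` are hypothetical stand-ins for prompt assembly and the model tokenizer:

```python
def shrink_to_fit(keywords, results, render, count_tokens, max_tokens):
    # keywords: list[str]; results: list[list] of search hits, one list per keyword.
    while True:
        prompt = render(keywords, results)
        if count_tokens(prompt) <= max_tokens:
            return prompt
        longest = max(results, key=len)
        if not longest:
            raise ValueError("prompt cannot fit within the token budget")
        longest.pop()  # drop the last hit from the longest result list and retry
```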
+ """ + max_results = max_num_results or max(num_results * 2, 6) + results = await self._search_urls(query, max_results=max_results) + if len(results) == 0: + return [] + _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results)) + time_stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results, time_stamp=time_stamp) + logger.debug(prompt) + indices = await self._aask(prompt) + try: + indices = OutputParser.extract_struct(indices, list) + assert all(isinstance(i, int) for i in indices) + except Exception as e: + logger.exception(f"fail to rank results for {e}") + indices = list(range(max_results)) + results = [results[i] for i in indices] + if self.rank_func: + results = self.rank_func(results) + return [i["link"] for i in results[:num_results]] + + async def _search_urls(self, query: str, max_results: int) -> list[dict[str, str]]: + """Use search_engine to get urls. + + Returns: + e.g. [{"title": "...", "link": "...", "snippet", "..."}] + """ + + return await self.search_engine.run(query, max_results=max_results, as_string=False) + + +class WebBrowseAndSummarize(Action): + """Action class to explore the web and provide summaries of articles and webpages.""" + + name: str = "WebBrowseAndSummarize" + i_context: Optional[str] = None + desc: str = "Explore the web and provide summaries of articles and webpages." + browse_func: Union[Callable[[list[str]], None], None] = None + web_browser_engine: Optional[WebBrowserEngine] = None + + @model_validator(mode="after") + def validate_engine_and_run_func(self): + if self.web_browser_engine is None: + self.web_browser_engine = WebBrowserEngine.from_browser_config( + self.config.browser, + browse_func=self.browse_func, + proxy=self.config.proxy, + ) + return self + + async def run( + self, + url: str, + *urls: str, + query: str, + system_text: str = RESEARCH_BASE_SYSTEM, + use_concurrent_summarization: bool = False, + per_page_timeout: Optional[float] = None, + ) -> dict[str, str]: + """Run the action to browse the web and provide summaries. + + Args: + url: The main URL to browse. + urls: Additional URLs to browse. + query: The research question. + system_text: The system text. + use_concurrent_summarization: Whether to concurrently summarize the content of the webpage by LLM. + per_page_timeout: The maximum time for fetching a single page in seconds. + + Returns: + A dictionary containing the URLs as keys and their summaries as values. 
+ """ + contents = await self._fetch_web_contents(url, *urls, per_page_timeout=per_page_timeout) + + all_urls = [url] + list(urls) + summarize_tasks = [self._summarize_content(content, query, system_text) for content in contents] + summaries = await self._execute_summarize_tasks(summarize_tasks, use_concurrent_summarization) + result = {url: summary for url, summary in zip(all_urls, summaries) if summary} + + return result + + async def _fetch_web_contents( + self, url: str, *urls: str, per_page_timeout: Optional[float] = None + ) -> list[WebPage]: + """Fetch web contents from given URLs.""" + + contents = await self.web_browser_engine.run(url, *urls, per_page_timeout=per_page_timeout) + + return [contents] if not urls else contents + + async def _summarize_content(self, page: WebPage, query: str, system_text: str) -> str: + """Summarize web content.""" + try: + prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}") + + content = page.inner_text + + if self._is_content_invalid(content): + logger.warning(f"Invalid content detected for URL {page.url}: {content[:10]}...") + return None + + chunk_summaries = [] + for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096): + logger.debug(prompt) + summary = await self._aask(prompt, [system_text]) + if summary == "Not relevant.": + continue + chunk_summaries.append(summary) + + if not chunk_summaries: + return None + + if len(chunk_summaries) == 1: + return chunk_summaries[0] + + content = "\n".join(chunk_summaries) + prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content) + summary = await self._aask(prompt, [system_text]) + return summary + except Exception as e: + logger.error(f"Error summarizing content: {e}") + return None + + def _is_content_invalid(self, content: str) -> bool: + """Check if the content is invalid based on specific starting phrases.""" + + invalid_starts = ["Fail to load page", "Access Denied"] + + return any(content.strip().startswith(phrase) for phrase in invalid_starts) + + async def _execute_summarize_tasks(self, tasks: list[Coroutine[Any, Any, str]], use_concurrent: bool) -> list[str]: + """Execute summarize tasks either concurrently or sequentially.""" + + if use_concurrent: + return await asyncio.gather(*tasks) + + return [await task for task in tasks] + + +class ConductResearch(Action): + """Action class to conduct research and generate a research report.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def run( + self, + topic: str, + content: str, + system_text: str = RESEARCH_BASE_SYSTEM, + ) -> str: + """Run the action to conduct research and generate a research report. + + Args: + topic: The research topic. + content: The content for research. + system_text: The system text. + + Returns: + The generated research report. + """ + prompt = CONDUCT_RESEARCH_PROMPT.format(topic=topic, content=content) + logger.debug(prompt) + self.llm.auto_max_tokens = True + return await self._aask(prompt, [system_text]) + + +def get_research_system_text(topic: str, language: str): + """Get the system text for conducting research. + + Args: + topic: The research topic. + language: The language for the system text. + + Returns: + The system text for conducting research. 
+ """ + return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language))) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py new file mode 100644 index 0000000000000000000000000000000000000000..b2c33c19b22b70d8b0cfb31ce0370fd931a7fcaf --- /dev/null +++ b/metagpt/actions/run_code.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:46 +@Author : alexanderwu +@File : run_code.py +@Modified By: mashenquan, 2023/11/27. + 1. Mark the location of Console logs in the PROMPT_TEMPLATE with markdown code-block formatting to enhance + the understanding for the LLM. + 2. Fix bug: Add the "install dependency" operation. + 3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into + RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError. + 4. According to section 2.2.3.5.7 of RFC 135, change the method of transferring file content + (code files, unit test files, log files) from using the message to using the file name. + 5. Merged the `Config` class of send18:dev branch to take over the set/get operations of the Environment + class. +""" +import subprocess +from pathlib import Path +from typing import Tuple + +from pydantic import Field + +from metagpt.actions.action import Action +from metagpt.logs import logger +from metagpt.schema import RunCodeContext, RunCodeResult +from metagpt.utils.exceptions import handle_exception + +PROMPT_TEMPLATE = """ +Role: You are a senior development and qa engineer, your role is summarize the code running result. +If the running result does not include an error, you should explicitly approve the result. +On the other hand, if the running result indicates some error, you should point out which part, the development code or the test code, produces the error, +and give specific instructions on fixing the errors. Here is the code info: +{context} +Now you should begin your analysis +--- +## instruction: +Please summarize the cause of the errors and give correction instruction +## File To Rewrite: +Determine the ONE file to rewrite in order to fix the error, for example, xyz.py, or test_xyz.py +## Status: +Determine if all of the code works fine, if so write PASS, else FAIL, +WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION +## Send To: +Please write NoOne if there are no errors, Engineer if the errors are due to problematic development codes, else QaEngineer, +WRITE ONLY ONE WORD, NoOne OR Engineer OR QaEngineer, IN THIS SECTION. +--- +You should fill in necessary instruction, status, send to, and finally return all content between the --- segment line. 
+""" + +TEMPLATE_CONTEXT = """ +## Development Code File Name +{code_file_name} +## Development Code +```python +{code} +``` +## Test File Name +{test_file_name} +## Test Code +```python +{test_code} +``` +## Running Command +{command} +## Running Output +standard output: +```text +{outs} +``` +standard errors: +```text +{errs} +``` +""" + + +class RunCode(Action): + name: str = "RunCode" + i_context: RunCodeContext = Field(default_factory=RunCodeContext) + + @classmethod + async def run_text(cls, code) -> Tuple[str, str]: + try: + # We will document_store the result in this dictionary + namespace = {} + exec(code, namespace) + except Exception as e: + return "", str(e) + return namespace.get("result", ""), "" + + async def run_script(self, working_directory, additional_python_paths=[], command=[]) -> Tuple[str, str]: + working_directory = str(working_directory) + additional_python_paths = [str(path) for path in additional_python_paths] + + # Copy the current environment variables + env = self.context.new_environ() + + # Modify the PYTHONPATH environment variable + additional_python_paths = [working_directory] + additional_python_paths + additional_python_paths = ":".join(additional_python_paths) + env["PYTHONPATH"] = additional_python_paths + ":" + env.get("PYTHONPATH", "") + RunCode._install_dependencies(working_directory=working_directory, env=env) + + # Start the subprocess + process = subprocess.Popen( + command, cwd=working_directory, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + logger.info(" ".join(command)) + + try: + # Wait for the process to complete, with a timeout + stdout, stderr = process.communicate(timeout=10) + except subprocess.TimeoutExpired: + logger.info("The command did not complete within the given timeout.") + process.kill() # Kill the process if it times out + stdout, stderr = process.communicate() + return stdout.decode("utf-8"), stderr.decode("utf-8") + + async def run(self, *args, **kwargs) -> RunCodeResult: + logger.info(f"Running {' '.join(self.i_context.command)}") + if self.i_context.mode == "script": + outs, errs = await self.run_script( + command=self.i_context.command, + working_directory=self.i_context.working_directory, + additional_python_paths=self.i_context.additional_python_paths, + ) + elif self.i_context.mode == "text": + outs, errs = await self.run_text(code=self.i_context.code) + + logger.info(f"{outs=}") + logger.info(f"{errs=}") + + context = TEMPLATE_CONTEXT.format( + code=self.i_context.code, + code_file_name=self.i_context.code_filename, + test_code=self.i_context.test_code, + test_file_name=self.i_context.test_filename, + command=" ".join(self.i_context.command), + outs=outs[:500], # outs might be long but they are not important, truncate them to avoid token overflow + errs=errs[:10000], # truncate errors to avoid token overflow + ) + + prompt = PROMPT_TEMPLATE.format(context=context) + rsp = await self._aask(prompt) + return RunCodeResult(summary=rsp, stdout=outs, stderr=errs) + + @staticmethod + @handle_exception(exception_type=subprocess.CalledProcessError) + def _install_via_subprocess(cmd, check, cwd, env): + return subprocess.run(cmd, check=check, cwd=cwd, env=env) + + @staticmethod + def _install_requirements(working_directory, env): + file_path = Path(working_directory) / "requirements.txt" + if not file_path.exists(): + return + if file_path.stat().st_size == 0: + return + install_command = ["python", "-m", "pip", "install", "-r", "requirements.txt"] + logger.info(" ".join(install_command)) + 
RunCode._install_via_subprocess(install_command, check=True, cwd=working_directory, env=env) + + @staticmethod + def _install_pytest(working_directory, env): + install_pytest_command = ["python", "-m", "pip", "install", "pytest"] + logger.info(" ".join(install_pytest_command)) + RunCode._install_via_subprocess(install_pytest_command, check=True, cwd=working_directory, env=env) + + @staticmethod + def _install_dependencies(working_directory, env): + RunCode._install_requirements(working_directory, env) + RunCode._install_pytest(working_directory, env) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py new file mode 100644 index 0000000000000000000000000000000000000000..7eed7381b90775f4fa777a2c150caac66f4f283d --- /dev/null +++ b/metagpt/actions/search_and_summarize.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/23 17:26 +@Author : alexanderwu +@File : search_google.py +""" +from typing import Optional + +import pydantic +from pydantic import model_validator + +from metagpt.actions import Action +from metagpt.logs import logger +from metagpt.schema import Message +from metagpt.tools.search_engine import SearchEngine + +SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements +1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. +- The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. +2. If there are citable links in the context, annotate them in the main text in the format [main text](citation link). If there are none in the context, do not write links. +3. The reply should be graceful, clear, non-repetitive, smoothly written, and of moderate length, in {LANG}. + +### Dialogue History (For example) +A: MLOps competitors + +### Current Question (For example) +A: MLOps competitors + +### Current Reply (For example) +1. Alteryx Designer: etc. if any +2. Matlab: ditto +3. IBM SPSS Statistics +4. RapidMiner Studio +5. DataRobot AI Platform +6. Databricks Lakehouse Platform +7. Amazon SageMaker +8. Dataiku +""" + +SEARCH_AND_SUMMARIZE_SYSTEM_EN_US = SEARCH_AND_SUMMARIZE_SYSTEM.format(LANG="en-us") + +SEARCH_AND_SUMMARIZE_PROMPT = """ +### Reference Information +{CONTEXT} + +### Dialogue History +{QUERY_HISTORY} +{QUERY} + +### Current Question +{QUERY} + +### Current Reply: Based on the information, please write the reply to the Question + + +""" + +SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements +1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation. +- The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage. +2. If there are citable links in the context, annotate them in the main text in the format [main text](citation link). If there are none in the context, do not write links. +3. The reply should be graceful, clear, non-repetitive, smoothly written, and of moderate length, in Simplified Chinese. + +# Example +## Reference Information +... + +## Dialogue History +user: Which facial cleanser is good for oily skin? +Salesperson: Hello, for oily skin, it is suggested to choose a product that can deeply cleanse, control oil, and is gentle and skin-friendly. 
According to customer feedback and market reputation, the following facial cleansers are recommended:... +user: Do you have any by L'Oreal? +> Salesperson: ... + +## Ideal Answer +Yes, I've selected the following for you: +1. L'Oreal Men's Facial Cleanser: Oil control, anti-acne, balance of water and oil, pore purification, effectively against blackheads, deep exfoliation, refuse oil shine. Dense foam, not tight after washing. +2. L'Oreal Age Perfect Hydrating Cleanser: Added with sodium cocoyl glycinate and Centella Asiatica, two effective ingredients, it can deeply cleanse, tighten the skin, gentle and not tight. +""" + +SEARCH_AND_SUMMARIZE_SALES_PROMPT = """ +## Reference Information +{CONTEXT} + +## Dialogue History +{QUERY_HISTORY} +{QUERY} +> {ROLE}: + +""" + +SEARCH_FOOD = """ +# User Search Request +What are some delicious foods in Xiamen? + +# Requirements +You are a member of a professional butler team and will provide helpful suggestions: +1. Please summarize the user's search request based on the context and avoid including unrelated text. +2. Use [main text](reference link) in markdown format to **naturally annotate** 3-5 textual elements (such as product words or similar text sections) within the main text for easy navigation. +3. The response should be elegant, clear, **without any repetition of text**, smoothly written, and of moderate length. +""" + + +class SearchAndSummarize(Action): + name: str = "" + content: Optional[str] = None + search_engine: SearchEngine = None + result: str = "" + + @model_validator(mode="after") + def validate_search_engine(self): + if self.search_engine is None: + try: + config = self.config + search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy) + except pydantic.ValidationError: + search_engine = None + + self.search_engine = search_engine + return self + + async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str: + if self.search_engine is None: + logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature") + return "" + + query = context[-1].content + # logger.debug(query) + rsp = await self.search_engine.run(query) + self.result = rsp + if not rsp: + logger.error("empty rsp...") + return "" + # logger.info(rsp) + + system_prompt = [system_text] + + prompt = SEARCH_AND_SUMMARIZE_PROMPT.format( + ROLE=self.prefix, + CONTEXT=rsp, + QUERY_HISTORY="\n".join([str(i) for i in context[:-1]]), + QUERY=str(context[-1]), + ) + result = await self._aask(prompt, system_prompt) + logger.debug(prompt) + logger.debug(result) + return result diff --git a/metagpt/actions/search_enhanced_qa.py b/metagpt/actions/search_enhanced_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..1427f9b1952acdc64938ad9d4fc5164a010c1286 --- /dev/null +++ b/metagpt/actions/search_enhanced_qa.py @@ -0,0 +1,292 @@ +"""Enhancing question-answering capabilities through search engine augmentation.""" + +from __future__ import annotations + +import json + +from pydantic import Field, PrivateAttr, model_validator + +from metagpt.actions import Action +from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize +from metagpt.logs import logger +from metagpt.tools.tool_registry import register_tool +from metagpt.tools.web_browser_engine import WebBrowserEngine +from metagpt.utils.common import CodeParser +from metagpt.utils.parse_html import WebPage +from metagpt.utils.report import ThoughtReporter + +REWRITE_QUERY_PROMPT = """ +Role: You are 
a highly efficient assistant that provides a better search query for a web search engine to answer the given question.
+
+I will provide you with a question. Your task is to provide a better search query for a web search engine.
+
+## Context
+### Question
+{q}
+
+## Format Example
+```json
+{{
+    "query": "the better search query for the web search engine."
+}}
+```
+
+## Instructions
+- Understand the question given by the user.
+- Provide a better search query for a web search engine to answer the given question; your answer must be written in the same language as the question.
+- When rewriting, if you are unsure of the specific time, do not include the time.
+
+## Constraint
+Format: Just print the result in JSON format like **Format Example**.
+
+## Action
+Follow the **Instructions**, generate output and make sure it follows the **Constraint**.
+"""
+
+SEARCH_ENHANCED_QA_SYSTEM_PROMPT = """
+You are a large language AI assistant built by MGX. You are given a user question; please write a clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context.
+
+Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit it to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic if the given contexts do not provide sufficient information.
+
+Do not include [citation:x] in your answer, where x is a number. Other than code and specific names and citations, your answer must be written in the same language as the question.
+
+Here is the set of contexts:
+
+{context}
+
+Remember, don't blindly repeat the contexts verbatim. And here is the user question:
+"""
+
+
+@register_tool(include_functions=["run"])
+class SearchEnhancedQA(Action):
+    """Question answering and info searching through a search engine."""
+
+    name: str = "SearchEnhancedQA"
+    desc: str = "Integrating search engine results to answer the question."
+
+    collect_links_action: CollectLinks = Field(
+        default_factory=CollectLinks, description="Action to collect relevant links from a search engine."
+    )
+    web_browse_and_summarize_action: WebBrowseAndSummarize = Field(
+        default=None,
+        description="Action to explore the web and provide summaries of articles and webpages.",
+    )
+    per_page_timeout: float = Field(
+        default=20, description="The maximum time for fetching a single page, in seconds. Defaults to 20s."
+    )
+    java_script_enabled: bool = Field(
+        default=False, description="Whether or not to enable JavaScript in the web browser context. Defaults to False."
+    )
+    user_agent: str = Field(
+        default="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36 Edg/116.0.1938.81",
+        description="Specific user agent to use in the browser.",
+    )
+    extra_http_headers: dict = Field(
+        default={"sec-ch-ua": 'Chromium";v="125", "Not.A/Brand";v="24'},
+        description="An object containing additional HTTP headers to be sent with every request.",
+    )
+    max_chars_per_webpage_summary: int = Field(
+        default=4000, description="Maximum summary length for each web page's content."
+    )
+    max_search_results: int = Field(
+        default=10,
+        description="Maximum number of search results (links) to collect using the collect_links_action.
This controls the number of potential sources for answering the question.", + ) + + _reporter: ThoughtReporter = PrivateAttr(ThoughtReporter()) + + @model_validator(mode="after") + def initialize(self): + if self.web_browse_and_summarize_action is None: + web_browser_engine = WebBrowserEngine.from_browser_config( + self.config.browser, + proxy=self.config.proxy, + java_script_enabled=self.java_script_enabled, + extra_http_headers=self.extra_http_headers, + user_agent=self.user_agent, + ) + + self.web_browse_and_summarize_action = WebBrowseAndSummarize(web_browser_engine=web_browser_engine) + + return self + + async def run(self, query: str, rewrite_query: bool = True) -> str: + """Answer a query by leveraging web search results. + + Args: + query (str): The original user query. + rewrite_query (bool): Whether to rewrite the query for better web search results. Defaults to True. + + Returns: + str: A detailed answer based on web search results. + + Raises: + ValueError: If the query is invalid. + """ + async with self._reporter: + await self._reporter.async_report({"type": "search", "stage": "init"}) + self._validate_query(query) + + processed_query = await self._process_query(query, rewrite_query) + context = await self._build_context(processed_query) + + return await self._generate_answer(processed_query, context) + + def _validate_query(self, query: str) -> None: + """Validate the input query. + + Args: + query (str): The query to validate. + + Raises: + ValueError: If the query is invalid. + """ + + if not query.strip(): + raise ValueError("Query cannot be empty or contain only whitespace.") + + async def _process_query(self, query: str, should_rewrite: bool) -> str: + """Process the query, optionally rewriting it.""" + + if should_rewrite: + return await self._rewrite_query(query) + + return query + + async def _rewrite_query(self, query: str) -> str: + """Write a better search query for web search engine. + + If the rewrite process fails, the original query is returned. + + Args: + query (str): The original search query. + + Returns: + str: The rewritten query if successful, otherwise the original query. + """ + + prompt = REWRITE_QUERY_PROMPT.format(q=query) + + try: + resp = await self._aask(prompt) + rewritten_query = self._extract_rewritten_query(resp) + + logger.info(f"Query rewritten: '{query}' -> '{rewritten_query}'") + return rewritten_query + except Exception as e: + logger.warning(f"Query rewrite failed. Returning original query. Error: {e}") + return query + + def _extract_rewritten_query(self, response: str) -> str: + """Extract the rewritten query from the LLM's JSON response.""" + + resp_json = json.loads(CodeParser.parse_code(response, lang="json")) + return resp_json["query"] + + async def _build_context(self, query: str) -> str: + """Construct a context string from web search citations. + + Args: + query (str): The search query. + + Returns: + str: Formatted context with numbered citations. + """ + + citations = await self._search_citations(query) + context = "\n\n".join([f"[[citation:{i+1}]] {c}" for i, c in enumerate(citations)]) + + return context + + async def _search_citations(self, query: str) -> list[str]: + """Perform web search and summarize relevant content. + + Args: + query (str): The search query. + + Returns: + list[str]: Summaries of relevant web content. 
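`_build_context` above numbers each summary so that the `[[citation:x]]` references in the system prompt line up with the snippets. The formatting rule is small enough to show in isolation:

```python
def build_context(citations: list[str]) -> str:
    # Mirrors _build_context above: one numbered block per source summary.
    return "\n\n".join(f"[[citation:{i + 1}]] {text}" for i, text in enumerate(citations))

# build_context(["first source", "second source"])
# -> "[[citation:1]] first source\n\n[[citation:2]] second source"
```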
+ """ + + relevant_urls = await self._collect_relevant_links(query) + await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls}) + if not relevant_urls: + logger.warning(f"No relevant URLs found for query: {query}") + return [] + + logger.info(f"The Relevant links are: {relevant_urls}") + + web_summaries = await self._summarize_web_content(relevant_urls) + if not web_summaries: + logger.warning(f"No summaries generated for query: {query}") + return [] + + citations = list(web_summaries.values()) + + return citations + + async def _collect_relevant_links(self, query: str) -> list[str]: + """Search and rank URLs relevant to the query. + + Args: + query (str): The search query. + + Returns: + list[str]: Ranked list of relevant URLs. + """ + + return await self.collect_links_action._search_and_rank_urls( + topic=query, query=query, max_num_results=self.max_search_results + ) + + async def _summarize_web_content(self, urls: list[str]) -> dict[str, str]: + """Fetch and summarize content from given URLs. + + Args: + urls (list[str]): List of URLs to summarize. + + Returns: + dict[str, str]: Mapping of URLs to their summaries. + """ + + contents = await self._fetch_web_contents(urls) + + summaries = {} + await self._reporter.async_report( + {"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]} + ) + for content in contents: + url = content.url + inner_text = content.inner_text.replace("\n", "") + if self.web_browse_and_summarize_action._is_content_invalid(inner_text): + logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...") + continue + + summary = inner_text[: self.max_chars_per_webpage_summary] + summaries[url] = summary + + return summaries + + async def _fetch_web_contents(self, urls: list[str]) -> list[WebPage]: + return await self.web_browse_and_summarize_action._fetch_web_contents( + *urls, per_page_timeout=self.per_page_timeout + ) + + async def _generate_answer(self, query: str, context: str) -> str: + """Generate an answer using the query and context. + + Args: + query (str): The user's question. + context (str): Relevant information from web search. + + Returns: + str: Generated answer based on the context. 
+ """ + + system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context) + + async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "search", "stage": "answer"}) + rsp = await self._aask(query, [system_prompt]) + return rsp diff --git a/metagpt/actions/skill_action.py b/metagpt/actions/skill_action.py new file mode 100644 index 0000000000000000000000000000000000000000..078ab008a5042fca6500a54fec8e0caebce0262d --- /dev/null +++ b/metagpt/actions/skill_action.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/28 +@Author : mashenquan +@File : skill_action.py +@Desc : Call learned skill +""" +from __future__ import annotations + +import ast +import importlib +import traceback +from copy import deepcopy +from typing import Dict, Optional + +from metagpt.actions import Action +from metagpt.learn.skill_loader import Skill +from metagpt.logs import logger +from metagpt.schema import Message + + +# TOTEST +class ArgumentsParingAction(Action): + skill: Skill + ask: str + rsp: Optional[Message] = None + args: Optional[Dict] = None + + @property + def prompt(self): + prompt = f"{self.skill.name} function parameters description:\n" + for k, v in self.skill.arguments.items(): + prompt += f"parameter `{k}`: {v}\n" + prompt += "\n---\n" + prompt += "Examples:\n" + for e in self.skill.examples: + prompt += f"If want you to do `{e.ask}`, return `{e.answer}` brief and clear.\n" + prompt += "\n---\n" + prompt += ( + f"\nRefer to the `{self.skill.name}` function description, and fill in the function parameters according " + 'to the example "I want you to do xx" in the Examples section.' + f"\nNow I want you to do `{self.ask}`, return function parameters in Examples format above, brief and " + "clear." 
+ ) + return prompt + + async def run(self, with_message=None, **kwargs) -> Message: + prompt = self.prompt + rsp = await self.llm.aask( + msg=prompt, + system_msgs=["You are a function parser.", "You can convert spoken words into function parameters."], + stream=False, + ) + logger.debug(f"SKILL:{prompt}\n, RESULT:{rsp}") + self.args = ArgumentsParingAction.parse_arguments(skill_name=self.skill.name, txt=rsp) + self.rsp = Message(content=rsp, role="assistant", instruct_content=self.args, cause_by=self) + return self.rsp + + @staticmethod + def parse_arguments(skill_name, txt) -> dict: + prefix = skill_name + "(" + if prefix not in txt: + logger.error(f"{skill_name} not in {txt}") + return None + if ")" not in txt: + logger.error(f"')' not in {txt}") + return None + begin_ix = txt.find(prefix) + end_ix = txt.rfind(")") + args_txt = txt[begin_ix + len(prefix) : end_ix] + logger.info(args_txt) + fake_expression = f"dict({args_txt})" + parsed_expression = ast.parse(fake_expression, mode="eval") + args = {} + for keyword in parsed_expression.body.keywords: + key = keyword.arg + value = ast.literal_eval(keyword.value) + args[key] = value + return args + + +class SkillAction(Action): + skill: Skill + args: Dict + rsp: Optional[Message] = None + + async def run(self, with_message=None, **kwargs) -> Message: + """Run action""" + options = deepcopy(kwargs) + if self.args: + for k in self.args.keys(): + if k in options: + options.pop(k) + try: + rsp = await self.find_and_call_function(self.skill.name, args=self.args, **options) + self.rsp = Message(content=rsp, role="assistant", cause_by=self) + except Exception as e: + logger.exception(f"{e}, traceback:{traceback.format_exc()}") + self.rsp = Message(content=f"Error: {e}", role="assistant", cause_by=self) + return self.rsp + + @staticmethod + async def find_and_call_function(function_name, args, **kwargs) -> str: + try: + module = importlib.import_module("metagpt.learn") + function = getattr(module, function_name) + # Invoke function and return result + result = await function(**args, **kwargs) + return result + except (ModuleNotFoundError, AttributeError): + logger.error(f"{function_name} not found") + raise ValueError(f"{function_name} not found") diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py new file mode 100644 index 0000000000000000000000000000000000000000..e3556caa7bdaa5c33fbcacf8a92c9c0d5f8dc98a --- /dev/null +++ b/metagpt/actions/summarize_code.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Author : alexanderwu +@File : summarize_code.py +@Modified By: mashenquan, 2023/12/5. Archive the summarization content of issue discovery for use in WriteCode. +""" +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions.action import Action +from metagpt.logs import logger +from metagpt.schema import CodeSummarizeContext +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo + +PROMPT_TEMPLATE = """ +NOTICE +Role: You are a professional software engineer, and your main task is to review the code. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. +ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. 
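`parse_arguments` above avoids `eval` by wrapping the extracted argument text in `dict(...)`, parsing it with `ast`, and materializing only literal values, so arbitrary code in the model reply cannot execute. The core trick reduced to a standalone function (the sample call string is invented for illustration):

```python
import ast

def parse_call_kwargs(text: str, func_name: str) -> dict:
    # Extract the "k=v, ..." span between the function's parentheses.
    args_txt = text[text.find(func_name + "(") + len(func_name) + 1 : text.rfind(")")]
    tree = ast.parse(f"dict({args_txt})", mode="eval")
    # ast.literal_eval only accepts literal nodes, so no code is executed.
    return {kw.arg: ast.literal_eval(kw.value) for kw in tree.body.keywords}

print(parse_call_kwargs("text_to_speech(text='hello', lang='en-US')", "text_to_speech"))
# -> {'text': 'hello', 'lang': 'en-US'}
```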
The output format must carefully follow the "Format example".
+
+-----
+# System Design
+```text
+{system_design}
+```
+-----
+# Task
+```text
+{task}
+```
+-----
+{code_blocks}
+
+## Code Review All: Please read all historical files and find possible bugs in the files, such as unimplemented functions, calling errors, unreferenced code, etc.
+
+## Call flow: mermaid code; based on the implemented functions, use mermaid to draw a complete call chain.
+
+## Summary: Summary based on the implementation of historical files.
+
+## TODOs: Python dict[str, str], write down the list of files that need to be modified and the reasons. We will modify them later.
+
+"""
+
+FORMAT_EXAMPLE = """
+
+## Code Review All
+
+### a.py
+- It fulfills less of xxx requirements...
+- Field yyy is not given...
+- ...
+
+### b.py
+...
+
+### c.py
+...
+
+## Call flow
+```mermaid
+flowchart TB
+    c1-->a2
+    subgraph one
+    a1-->a2
+    end
+    subgraph two
+    b1-->b2
+    end
+    subgraph three
+    c1-->c2
+    end
+```
+
+## Summary
+- a.py: ...
+- b.py: ...
+- c.py: ...
+- ...
+
+## TODOs
+{
+    "a.py": "implement requirement xxx...",
+}
+
+"""
+
+
+class SummarizeCode(Action):
+    name: str = "SummarizeCode"
+    i_context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
+    repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
+    input_args: Optional[BaseModel] = Field(default=None, exclude=True)
+
+    @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
+    async def summarize_code(self, prompt):
+        code_rsp = await self._aask(prompt)
+        return code_rsp
+
+    async def run(self):
+        design_pathname = Path(self.i_context.design_filename)
+        design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
+        task_pathname = Path(self.i_context.task_filename)
+        task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
+        code_blocks = []
+        for filename in self.i_context.codes_filenames:
+            code_doc = await self.repo.srcs.get(filename)
+            code_block = f"```{get_markdown_code_block_type(filename)}\n{code_doc.content}\n```\n---\n"
+            code_blocks.append(code_block)
+        format_example = FORMAT_EXAMPLE
+        prompt = PROMPT_TEMPLATE.format(
+            system_design=design_doc.content,
+            task=task_doc.content,
+            code_blocks="\n".join(code_blocks),
+            format_example=format_example,
+        )
+        logger.info("Summarize code...")
+        rsp = await self.summarize_code(prompt)
+        return rsp
diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fec327838778768ac62635b707705e2155e7d57
--- /dev/null
+++ b/metagpt/actions/talk_action.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/8/28
+@Author  : mashenquan
+@File    : talk_action.py
+@Desc    : Act as if it's a talk
+"""
+from typing import Optional
+
+from metagpt.actions import Action
+from metagpt.logs import logger
+from metagpt.schema import Message
+
+
+class TalkAction(Action):
+    i_context: str
+    history_summary: str = ""
+    knowledge: str = ""
+    rsp: Optional[Message] = None
+
+    @property
+    def agent_description(self):
+        return self.context.kwargs.agent_description
+
+    @property
+    def language(self):
+        return self.context.kwargs.language or self.config.language
+
+    @property
+    def prompt(self):
+        prompt = ""
+        if self.agent_description:
+            prompt = (
+                f"You are {self.agent_description}. Your responses should align with the role-play agreement, "
+                f"maintaining the character's persona and habits.
When faced with unrelated questions, playfully "
+                f"decline to answer without revealing your AI nature to preserve the character's image.\n\n"
+            )
+        prompt += f"Knowledge:\n{self.knowledge}\n\n" if self.knowledge else ""
+        prompt += f"{self.history_summary}\n\n"
+        prompt += (
+            "If the information is insufficient, you can search in the historical conversation or knowledge above.\n"
+        )
+        language = self.language
+        prompt += (
+            f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.\n "
+            f"{self.i_context}"
+        )
+        logger.debug(f"PROMPT: {prompt}")
+        return prompt
+
+    @property
+    def prompt_gpt4(self):
+        kvs = {
+            "{role}": self.agent_description or "",
+            "{history}": self.history_summary or "",
+            "{knowledge}": self.knowledge or "",
+            "{language}": self.language,
+            "{ask}": self.i_context,
+        }
+        prompt = TalkActionPrompt.FORMATION_LOOSE
+        for k, v in kvs.items():
+            prompt = prompt.replace(k, v)
+        logger.info(f"PROMPT: {prompt}")
+        return prompt
+
+    # async def run_old(self, *args, **kwargs) -> ActionOutput:
+    #     prompt = self.prompt
+    #     rsp = await self.llm.aask(msg=prompt, system_msgs=[])
+    #     logger.debug(f"PROMPT:{prompt}\nRESULT:{rsp}\n")
+    #     self._rsp = ActionOutput(content=rsp)
+    #     return self._rsp
+
+    @property
+    def aask_args(self):
+        language = self.language
+        system_msgs = [
+            f"You are {self.agent_description}.",
+            "Your responses should align with the role-play agreement, "
+            "maintaining the character's persona and habits. When faced with unrelated questions, playfully "
+            "decline to answer without revealing your AI nature to preserve the character's image.",
+            "If the information is insufficient, you can search in the context or knowledge.",
+            f"Answer the following questions strictly in {language}, and the answers must follow the Markdown format.",
+        ]
+        format_msgs = []
+        if self.knowledge:
+            format_msgs.append({"role": "assistant", "content": self.knowledge})
+        if self.history_summary:
+            format_msgs.append({"role": "assistant", "content": self.history_summary})
+        return self.i_context, format_msgs, system_msgs
+
+    async def run(self, with_message=None, **kwargs) -> Message:
+        msg, format_msgs, system_msgs = self.aask_args
+        rsp = await self.llm.aask(msg=msg, format_msgs=format_msgs, system_msgs=system_msgs, stream=False)
+        self.rsp = Message(content=rsp, role="assistant", cause_by=self)
+        return self.rsp
+
+
+class TalkActionPrompt:
+    FORMATION = """Formation: "Capacity and role" defines the role you are currently playing;
+  "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
+  "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge that may help with your responses;
+  "Statement" defines the work detail you need to complete at this stage;
+  "[ASK_BEGIN]" and "[ASK_END]" tags enclose the questions;
+  "Constraint" defines the conditions that your responses must comply with.
+  "Personality" defines your language style.
+  "Insight" provides a deeper understanding of the characters' inner traits.
+  "Initial" defines the initial setup of a character.
+
+Capacity and role: {role}
+Statement: Your responses should align with the role-play agreement, maintaining the
+ character's persona and habits. When faced with unrelated questions, playfully decline to answer without revealing
+ your AI nature to preserve the character's image.
+
+[HISTORY_BEGIN]
+
+{history}
+
+[HISTORY_END]
+
+[KNOWLEDGE_BEGIN]
+
+{knowledge}
+
+[KNOWLEDGE_END]
+
+Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
+Statement: Unless you are a language professional, answer the following questions strictly in {language},
+and the answers must follow the Markdown format. Strictly exclude any tags like "[HISTORY_BEGIN]",
+"[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" from your responses.
+
+
+{ask}
+"""
+
+    FORMATION_LOOSE = """Formation: "Capacity and role" defines the role you are currently playing;
+  "[HISTORY_BEGIN]" and "[HISTORY_END]" tags enclose the historical conversation;
+  "[KNOWLEDGE_BEGIN]" and "[KNOWLEDGE_END]" tags enclose the knowledge that may help with your responses;
+  "Statement" defines the work detail you need to complete at this stage;
+  "Constraint" defines the conditions that your responses must comply with.
+  "Personality" defines your language style.
+  "Insight" provides a deeper understanding of the characters' inner traits.
+  "Initial" defines the initial setup of a character.
+
+Capacity and role: {role}
+Statement: Your responses should maintain the character's persona and habits. When faced with unrelated questions,
+playfully decline to answer without revealing your AI nature to preserve the character's image.
+
+[HISTORY_BEGIN]
+
+{history}
+
+[HISTORY_END]
+
+[KNOWLEDGE_BEGIN]
+
+{knowledge}
+
+[KNOWLEDGE_END]
+
+Statement: If the information is insufficient, you can search in the historical conversation or knowledge.
+Statement: Unless you are a language professional, answer the following questions strictly in {language},
+and the answers must follow the Markdown format. Strictly exclude any tags like "[HISTORY_BEGIN]",
+"[HISTORY_END]", "[KNOWLEDGE_BEGIN]", "[KNOWLEDGE_END]" from your responses.
+
+
+{ask}
+"""
diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py
new file mode 100644
index 0000000000000000000000000000000000000000..da25fe621c426a6382de82c684d43e63fb35f8c5
--- /dev/null
+++ b/metagpt/actions/write_code.py
@@ -0,0 +1,228 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/5/11 17:45
+@Author  : alexanderwu
+@File    : write_code.py
+@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `cause_by`
+        value of the `Message` object.
+@Modified By: mashenquan, 2023-11-27.
+        1. Mark the location of Design, Tasks, Legacy Code and Debug logs in the PROMPT_TEMPLATE with markdown
+        code-block formatting to enhance the understanding for the LLM.
+        2. Following the think-act principle, solidify the task parameters when creating the WriteCode object, rather
+        than passing them in when calling the run function.
+        3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into
+        RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError.
+""" + +import json +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions.action import Action +from metagpt.actions.project_management_an import REFINED_TASK_LIST, TASK_LIST +from metagpt.actions.write_code_plan_and_change_an import REFINED_TEMPLATE +from metagpt.logs import logger +from metagpt.schema import CodingContext, Document, RunCodeResult +from metagpt.utils.common import CodeParser, get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import EditorReporter + +PROMPT_TEMPLATE = """ +NOTICE +Role: You are a professional engineer; the main goal is to write google-style, elegant, modular, easy to read and maintain code +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. +ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example". + +# Context +## Design +{design} + +## Task +{task} + +## Legacy Code +{code} + +## Debug logs +```text +{logs} + +{summary_log} +``` + +## Bug Feedback logs +```text +{feedback} +``` + +# Format example +## Code: {demo_filename}.py +```python +## {demo_filename}.py +... +``` +## Code: {demo_filename}.js +```javascript +// {demo_filename}.js +... +``` + +# Instruction: Based on the context, follow "Format example", write code. + +## Code: {filename}. Write code with triple quoto, based on the following attentions and context. +1. Only One file: do your best to implement THIS ONLY ONE FILE. +2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets. +3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import. +4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DONT CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design. +5. CAREFULLY CHECK THAT YOU DONT MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE. +6. Before using a external variable/module, make sure you import it first. +7. Write out EVERY CODE DETAIL, DON'T LEAVE TODO. 
+ +""" + + +class WriteCode(Action): + name: str = "WriteCode" + i_context: Document = Field(default_factory=Document) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) + async def write_code(self, prompt) -> str: + code_rsp = await self._aask(prompt) + code = CodeParser.parse_code(text=code_rsp) + return code + + async def run(self, *args, **kwargs) -> CodingContext: + bug_feedback = None + if self.input_args and hasattr(self.input_args, "issue_filename"): + bug_feedback = await Document.load(self.input_args.issue_filename) + coding_context = CodingContext.loads(self.i_context.content) + if not coding_context.code_plan_and_change_doc: + coding_context.code_plan_and_change_doc = await self.repo.docs.code_plan_and_change.get( + filename=coding_context.task_doc.filename + ) + test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json") + requirement_doc = await Document.load(self.input_args.requirements_filename) + summary_doc = None + if coding_context.design_doc and coding_context.design_doc.filename: + summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename) + logs = "" + if test_doc: + test_detail = RunCodeResult.loads(test_doc.content) + logs = test_detail.stderr + + if self.config.inc or bug_feedback: + code_context = await self.get_codes( + coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo, use_inc=True + ) + else: + code_context = await self.get_codes( + coding_context.task_doc, exclude=self.i_context.filename, project_repo=self.repo + ) + + if self.config.inc: + prompt = REFINED_TEMPLATE.format( + user_requirement=requirement_doc.content if requirement_doc else "", + code_plan_and_change=coding_context.code_plan_and_change_doc.content + if coding_context.code_plan_and_change_doc + else "", + design=coding_context.design_doc.content if coding_context.design_doc else "", + task=coding_context.task_doc.content if coding_context.task_doc else "", + code=code_context, + logs=logs, + feedback=bug_feedback.content if bug_feedback else "", + filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, + summary_log=summary_doc.content if summary_doc else "", + ) + else: + prompt = PROMPT_TEMPLATE.format( + design=coding_context.design_doc.content if coding_context.design_doc else "", + task=coding_context.task_doc.content if coding_context.task_doc else "", + code=code_context, + logs=logs, + feedback=bug_feedback.content if bug_feedback else "", + filename=self.i_context.filename, + demo_filename=Path(self.i_context.filename).stem, + summary_log=summary_doc.content if summary_doc else "", + ) + logger.info(f"Writing {coding_context.filename}..") + async with EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report({"type": "code", "filename": coding_context.filename}, "meta") + code = await self.write_code(prompt) + if not coding_context.code_doc: + # avoid root_path pydantic ValidationError if use WriteCode alone + coding_context.code_doc = Document( + filename=coding_context.filename, root_path=str(self.repo.src_relative_path) + ) + coding_context.code_doc.content = code + await reporter.async_report(coding_context.code_doc, "document") + return coding_context + + @staticmethod + async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, 
+    @staticmethod
+    async def get_codes(task_doc: Document, exclude: str, project_repo: ProjectRepo, use_inc: bool = False) -> str:
+        """
+        Get the code snippets needed to generate the excluded file in various scenarios.
+
+        Args:
+            task_doc (Document): Document object of the task file.
+            exclude (str): The filename to be generated, and therefore excluded from the code snippets.
+            project_repo (ProjectRepo): ProjectRepo object of the project.
+            use_inc (bool): Whether the scenario involves incremental development. Defaults to False.
+
+        Returns:
+            str: Code snippets used as context for generating the excluded file.
+        """
+        if not task_doc:
+            return ""
+        if not task_doc.content:
+            task_doc = await project_repo.docs.task.get(filename=task_doc.filename)
+        m = json.loads(task_doc.content)
+        code_filenames = m.get(TASK_LIST.key, []) if not use_inc else m.get(REFINED_TASK_LIST.key, [])
+        codes = []
+        src_file_repo = project_repo.srcs
+        # Incremental development scenario
+        if use_inc:
+            for filename in src_file_repo.all_files:
+                code_block_type = get_markdown_code_block_type(filename)
+                # Exclude the current file from all code snippets
+                if filename == exclude:
+                    # Skip main.py to keep the prompt focused; for any other file,
+                    # prepend its old version as the file to rewrite
+                    if filename != "main.py":
+                        # Use the old code
+                        doc = await src_file_repo.get(filename=filename)
+                    else:
+                        continue
+                    codes.insert(
+                        0, f"### The name of file to rewrite: `{filename}`\n```{code_block_type}\n{doc.content}```\n"
+                    )
+                    logger.info(f"Prepare to rewrite `{filename}`")
+                # The other code snippets are taken from the src workspace
+                else:
+                    doc = await src_file_repo.get(filename=filename)
+                    # If the file does not exist in the src workspace, skip it
+                    if not doc:
+                        continue
+                    codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n")
+
+        # Normal scenario
+        else:
+            for filename in code_filenames:
+                # Exclude the current file to get the code snippets for generating the current file
+                if filename == exclude:
+                    continue
+                doc = await src_file_repo.get(filename=filename)
+                if not doc:
+                    continue
+                code_block_type = get_markdown_code_block_type(filename)
+                codes.append(f"### File Name: `{filename}`\n```{code_block_type}\n{doc.content}```\n\n")
+
+        return "\n".join(codes)
diff --git a/metagpt/actions/write_code_an_draft.py b/metagpt/actions/write_code_an_draft.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6622284d2ed488e345cc72ae12b97572e5b551c
--- /dev/null
+++ b/metagpt/actions/write_code_an_draft.py
@@ -0,0 +1,589 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Author : alexanderwu
+@File   : write_code_an_draft.py
+"""
+import asyncio
+from typing import List, Literal
+
+from metagpt.actions import Action
+from metagpt.actions.action_node import ActionNode
+
+REVIEW = ActionNode(
+    key="Review",
+    expected_type=List[str],
+    instruction="Act as an experienced reviewer and critically assess the given output. Provide specific and"
+    " constructive feedback, highlighting areas for improvement and suggesting changes.",
+    example=[
+        "The logic in the function `calculate_total` seems flawed. Shouldn't it consider the discount rate as well?",
+        "The TODO function is not implemented yet? Should we implement it before commit?",
+    ],
+)
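+
+# The three nodes below drive the review flow: REVIEW gathers concrete feedback,
+# REVIEW_RESULT forces a binary LGTM/LBTM verdict, and NEXT_STEPS turns that
+# verdict into actionable follow-ups.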
+REVIEW_RESULT = ActionNode(
+    key="ReviewResult",
+    expected_type=Literal["LGTM", "LBTM"],
+    instruction="LGTM/LBTM. If the code is fully implemented, give LGTM; otherwise give LBTM.",
+    example="LBTM",
+)
+
+NEXT_STEPS = ActionNode(
+    key="NextSteps",
+    expected_type=str,
+    instruction="Based on the code review outcome, suggest actionable steps. This can include code changes, "
+    "refactoring suggestions, or any follow-up tasks.",
+    example="""1. Refactor the `process_data` method to improve readability and efficiency.
+2. Cover edge cases in the `validate_user` function.
+3. Implement the TODO in the `calculate_total` function.
+4. Fix the `handle_events` method to update the game state only if a move is successful.
+   ```python
+   def handle_events(self):
+       for event in pygame.event.get():
+           if event.type == pygame.QUIT:
+               return False
+           if event.type == pygame.KEYDOWN:
+               moved = False
+               if event.key == pygame.K_UP:
+                   moved = self.game.move('UP')
+               elif event.key == pygame.K_DOWN:
+                   moved = self.game.move('DOWN')
+               elif event.key == pygame.K_LEFT:
+                   moved = self.game.move('LEFT')
+               elif event.key == pygame.K_RIGHT:
+                   moved = self.game.move('RIGHT')
+               if moved:
+                   # Update the game state only if a move was successful
+                   self.render()
+       return True
+   ```
+""",
+)
+
+WRITE_DRAFT = ActionNode(
+    key="WriteDraft",
+    expected_type=str,
+    instruction="Could you write draft code for the move function in order to implement it?",
+    example="Draft: ...",
+)
+
+
+WRITE_FUNCTION = ActionNode(
+    key="WriteFunction",
+    expected_type=str,
+    instruction="Write code for the function that is not implemented.",
+    example="""
+```Code
+...
+```
+""",
+)
+
+
+REWRITE_CODE = ActionNode(
+    key="RewriteCode",
+    expected_type=str,
+    instruction="""rewrite code based on the Review and Actions""",
+    example="""
+```python
+## example.py
+def calculate_total(price, quantity):
+    total = price * quantity
+    return total
+```
+""",
+)
+
+
+CODE_REVIEW_CONTEXT = """
+# System
+Role: You are a professional software engineer, and your main task is to review and revise the code. You need to ensure that the code conforms to google-style standards, is elegantly designed and modularized, and is easy to read and maintain.
+Language: Please use the same language as the user requirement, but the title and code should still be in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese.
+
+# Context
+## System Design
+{"Implementation approach": "我们将使用HTML、CSS和JavaScript来实现这个单机的响应式2048游戏。为了确保游戏性能流畅和响应式设计,我们会选择使用Vue.js框架,因为它易于上手且适合构建交互式界面。我们还将使用localStorage来记录玩家的最高分。", "File list": ["index.html", "styles.css", "main.js", "game.js", "storage.js"], "Data structures and interfaces": "classDiagram\
+    class Game {\
+        -board Array\
+        -score Number\
+        -bestScore Number\
+        +constructor()\
+        +startGame()\
+        +move(direction: String)\
+        +getBoard() Array\
+        +getScore() Number\
+        +getBestScore() Number\
+        +setBestScore(score: Number)\
+    }\
+    class Storage {\
+        +getBestScore() Number\
+        +setBestScore(score: Number)\
+    }\
+    class Main {\
+        +init()\
+        +bindEvents()\
+    }\
+    Game --> Storage : uses\
+    Main --> Game : uses", "Program call flow": "sequenceDiagram\
+    participant M as Main\
+    participant G as Game\
+    participant S as Storage\
+    M->>G: init()\
+    G->>S: getBestScore()\
+    S-->>G: return bestScore\
+    M->>G: bindEvents()\
+    M->>G: startGame()\
+    loop Game Loop\
+        M->>G: move(direction)\
+        G->>S: setBestScore(score)\
+        S-->>G: return\
+    end", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
+
+## Tasks
+{"Required packages": ["无需第三方包"], "Required Other language third-party packages": ["vue.js"], "Logic Analysis": [["index.html", "作为游戏的入口文件和主要的HTML结构"], ["styles.css", "包含所有的CSS样式,确保游戏界面美观"], ["main.js", "包含Main类,负责初始化游戏和绑定事件"], ["game.js", "包含Game类,负责游戏逻辑,如开始游戏、移动方块等"], ["storage.js", "包含Storage类,用于获取和设置玩家的最高分"]], "Task list": ["index.html", "styles.css", "storage.js", "game.js", "main.js"], "Full API spec": "", "Shared Knowledge": "\'game.js\' 包含游戏逻辑相关的函数,被 \'main.js\' 调用。", "Anything UNCLEAR": "目前项目要求明确,没有不清楚的地方。"}
+
+## Code Files
+----- index.html
+<!DOCTYPE html>
+<html lang="zh">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>2048游戏</title>
+    <link rel="stylesheet" href="styles.css">
+</head>
+<body>
+<div id="app">
+    <h1>2048</h1>
+    <div class="scores-container">
+        <div class="score-container">
+            <div class="score-header">分数</div>
+            <div>{{ score }}</div>
+        </div>
+        <div class="best-container">
+            <div class="best-header">最高分</div>
+            <div>{{ bestScore }}</div>
+        </div>
+    </div>
+    <div class="game-container">
+        <div class="grid-row" v-for="(row, rowIndex) in board" :key="rowIndex">
+            <div class="grid-cell" :class="\'number-cell-\' + cell" v-for="(cell, cellIndex) in row" :key="cellIndex">
+                {{ cell !== 0 ? cell : \'\' }}
+            </div>
+        </div>
+    </div>
+    <button @click="startGame">新游戏</button>
+</div>
+
+<script src="https://cdn.jsdelivr.net/npm/vue@2"></script>
+<script src="storage.js"></script>
+<script src="game.js"></script>
+<script src="main.js"></script>
+</body>
+</html>
+ + + + + + + + +----- styles.css +/* styles.css */ +body, html { + margin: 0; + padding: 0; + font-family: \'Arial\', sans-serif; +} + +#app { + text-align: center; + font-size: 18px; + color: #776e65; +} + +h1 { + color: #776e65; + font-size: 72px; + font-weight: bold; + margin: 20px 0; +} + +.scores-container { + display: flex; + justify-content: center; + margin-bottom: 20px; +} + +.score-container, .best-container { + background: #bbada0; + padding: 10px; + border-radius: 5px; + margin: 0 10px; + min-width: 100px; + text-align: center; +} + +.score-header, .best-header { + color: #eee4da; + font-size: 18px; + margin-bottom: 5px; +} + +.game-container { + max-width: 500px; + margin: 0 auto 20px; + background: #bbada0; + padding: 15px; + border-radius: 10px; + position: relative; +} + +.grid-row { + display: flex; +} + +.grid-cell { + background: #cdc1b4; + width: 100px; + height: 100px; + margin: 5px; + display: flex; + justify-content: center; + align-items: center; + font-size: 35px; + font-weight: bold; + color: #776e65; + border-radius: 3px; +} + +/* Dynamic classes for different number cells */ +.number-cell-2 { + background: #eee4da; +} + +.number-cell-4 { + background: #ede0c8; +} + +.number-cell-8 { + background: #f2b179; + color: #f9f6f2; +} + +.number-cell-16 { + background: #f59563; + color: #f9f6f2; +} + +.number-cell-32 { + background: #f67c5f; + color: #f9f6f2; +} + +.number-cell-64 { + background: #f65e3b; + color: #f9f6f2; +} + +.number-cell-128 { + background: #edcf72; + color: #f9f6f2; +} + +.number-cell-256 { + background: #edcc61; + color: #f9f6f2; +} + +.number-cell-512 { + background: #edc850; + color: #f9f6f2; +} + +.number-cell-1024 { + background: #edc53f; + color: #f9f6f2; +} + +.number-cell-2048 { + background: #edc22e; + color: #f9f6f2; +} + +/* Larger numbers need smaller font sizes */ +.number-cell-1024, .number-cell-2048 { + font-size: 30px; +} + +button { + background-color: #8f7a66; + color: #f9f6f2; + border: none; + border-radius: 3px; + padding: 10px 20px; + font-size: 18px; + cursor: pointer; + outline: none; +} + +button:hover { + background-color: #9f8b76; +} + +----- storage.js +## storage.js +class Storage { + // 获取最高分 + getBestScore() { + // 尝试从localStorage中获取最高分,如果不存在则默认为0 + const bestScore = localStorage.getItem(\'bestScore\'); + return bestScore ? Number(bestScore) : 0; + } + + // 设置最高分 + setBestScore(score) { + // 将最高分设置到localStorage中 + localStorage.setItem(\'bestScore\', score.toString()); + } +} + + + +## Code to be Reviewed: game.js +```Code +## game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 
2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SMALLEST_CONTEXT = """ +## Code to be Reviewed: game.js +```Code +// game.js +class Game { + constructor() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.bestScore = 0; + } + + createEmptyBoard() { + const board = []; + for (let i = 0; i < 4; i++) { + board[i] = [0, 0, 0, 0]; + } + return board; + } + + startGame() { + this.board = this.createEmptyBoard(); + this.score = 0; + this.addRandomTile(); + this.addRandomTile(); + } + + addRandomTile() { + let emptyCells = []; + for (let r = 0; r < 4; r++) { + for (let c = 0; c < 4; c++) { + if (this.board[r][c] === 0) { + emptyCells.push({ r, c }); + } + } + } + if (emptyCells.length > 0) { + let randomCell = emptyCells[Math.floor(Math.random() * emptyCells.length)]; + this.board[randomCell.r][randomCell.c] = Math.random() < 0.9 ? 2 : 4; + } + } + + move(direction) { + // This function will handle the logic for moving tiles + // in the specified direction and merging them + // It will also update the score and add a new random tile if the move is successful + // The actual implementation of this function is complex and would require + // a significant amount of code to handle all the cases for moving and merging tiles + // For the purposes of this example, we will not implement the full logic + // Instead, we will just call addRandomTile to simulate a move + this.addRandomTile(); + } + + getBoard() { + return this.board; + } + + getScore() { + return this.score; + } + + getBestScore() { + return this.bestScore; + } + + setBestScore(score) { + this.bestScore = score; + } +} + +``` +""" + + +CODE_REVIEW_SAMPLE = """ +## Code Review: game.js +1. The code partially implements the requirements. The `Game` class is missing the full implementation of the `move` method, which is crucial for the game\'s functionality. +2. The code logic is not completely correct. The `move` method is not implemented, which means the game cannot process player moves. +3. The existing code follows the "Data structures and interfaces" in terms of class structure but lacks full method implementations. +4. Not all functions are implemented. The `move` method is incomplete and does not handle the logic for moving and merging tiles. +5. All necessary pre-dependencies seem to be imported since the code does not indicate the need for additional imports. +6. The methods from other files (such as `Storage`) are not being used in the provided code snippet, but the class structure suggests that they will be used correctly. + +## Actions +1. Implement the `move` method to handle tile movements and merging. This is a complex task that requires careful consideration of the game\'s rules and logic. 
Here is a simplified version of how one might begin to implement the `move` method:
+   ```javascript
+   move(direction) {
+       // Simplified logic for moving tiles up
+       if (direction === \'up\') {
+           for (let col = 0; col < 4; col++) {
+               let tiles = this.board.map(row => row[col]).filter(val => val !== 0);
+               let merged = [];
+               for (let i = 0; i < tiles.length; i++) {
+                   if (tiles[i] === tiles[i + 1]) {
+                       tiles[i] *= 2;
+                       this.score += tiles[i];
+                       tiles[i + 1] = 0;
+                       merged.push(i);
+                   }
+               }
+               tiles = tiles.filter(val => val !== 0);
+               while (tiles.length < 4) {
+                   tiles.push(0);
+               }
+               for (let row = 0; row < 4; row++) {
+                   this.board[row][col] = tiles[row];
+               }
+           }
+       }
+       // Additional logic needed for \'down\', \'left\', \'right\'
+       // ...
+       this.addRandomTile();
+   }
+   ```
+2. Integrate the `Storage` class methods to handle the best score. This means updating the `startGame` and `setBestScore` methods to use `Storage` for retrieving and setting the best score:
+   ```javascript
+   startGame() {
+       this.board = this.createEmptyBoard();
+       this.score = 0;
+       this.bestScore = new Storage().getBestScore(); // Retrieve the best score from storage
+       this.addRandomTile();
+       this.addRandomTile();
+   }
+
+   setBestScore(score) {
+       if (score > this.bestScore) {
+           this.bestScore = score;
+           new Storage().setBestScore(score); // Set the new best score in storage
+       }
+   }
+   ```
+
+## Code Review Result
+LBTM
+"""
+
+
+WRITE_CODE_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, REVIEW_RESULT, NEXT_STEPS])
+WRITE_MOVE_NODE = ActionNode.from_children("WRITE_MOVE_NODE", [WRITE_DRAFT, WRITE_FUNCTION])
+
+
+CR_FOR_MOVE_FUNCTION_BY_3 = """
+The move function implementation provided appears to be well-structured and follows a clear logic for moving and merging tiles in the specified direction. However, there are a few potential improvements that could be made to enhance the code:
+
+1. Encapsulation: The logic for moving and merging tiles could be encapsulated into smaller, reusable functions to improve readability and maintainability.
+
+2. Magic Numbers: There are some magic numbers (e.g., 4, 3) used in the loops that could be replaced with named constants for improved readability and easier maintenance.
+
+3. Comments: Adding comments to explain the logic and purpose of each section of the code can improve understanding for future developers who may need to work on or maintain the code.
+
+4. Error Handling: It's important to consider error handling for unexpected input or edge cases to ensure the function behaves as expected in all scenarios.
+
+Overall, the code could benefit from refactoring to improve readability, maintainability, and extensibility. If you would like, I can provide a refactored version of the move function that addresses these considerations.
+"""
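+
+
+# WriteCodeAN below fills WRITE_MOVE_NODE (a draft plus the finished function)
+# against CODE_REVIEW_SMALLEST_CONTEXT; running this module invokes main() as a
+# quick end-to-end smoke test.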
+""" + + +class WriteCodeAN(Action): + """Write a code review for the context.""" + + async def run(self, context): + self.llm.system_prompt = "You are an outstanding engineer and can implement any code" + return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json") + + +async def main(): + await WriteCodeAN().run(CODE_REVIEW_SMALLEST_CONTEXT) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/actions/write_code_plan_and_change_an.py b/metagpt/actions/write_code_plan_and_change_an.py new file mode 100644 index 0000000000000000000000000000000000000000..989df52f22d427c140340a34466cb113e8b4f86c --- /dev/null +++ b/metagpt/actions/write_code_plan_and_change_an.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/26 +@Author : mannaandpoem +@File : write_code_plan_and_change_an.py +""" +from typing import List, Optional + +from pydantic import BaseModel, Field + +from metagpt.actions.action import Action +from metagpt.actions.action_node import ActionNode +from metagpt.logs import logger +from metagpt.schema import CodePlanAndChangeContext, Document +from metagpt.utils.common import get_markdown_code_block_type +from metagpt.utils.project_repo import ProjectRepo + +DEVELOPMENT_PLAN = ActionNode( + key="Development Plan", + expected_type=List[str], + instruction="Develop a comprehensive and step-by-step incremental development plan, providing the detail " + "changes to be implemented at each step based on the order of 'Task List'", + example=[ + "Enhance the functionality of `calculator.py` by extending it to incorporate methods for subtraction, ...", + "Update the existing codebase in main.py to incorporate new API endpoints for subtraction, ...", + ], +) + +INCREMENTAL_CHANGE = ActionNode( + key="Incremental Change", + expected_type=List[str], + instruction="Write Incremental Change by making a code draft that how to implement incremental development " + "including detailed steps based on the context. Note: Track incremental changes using the marks `+` and `-` to " + "indicate additions and deletions, and ensure compliance with the output format of `git diff`", + example=[ + '''```diff +--- Old/calculator.py ++++ New/calculator.py + +class Calculator: + self.result = number1 + number2 + return self.result + +- def sub(self, number1, number2) -> float: ++ def subtract(self, number1: float, number2: float) -> float: ++ """ ++ Subtracts the second number from the first and returns the result. ++ ++ Args: ++ number1 (float): The number to be subtracted from. ++ number2 (float): The number to subtract. ++ ++ Returns: ++ float: The difference of number1 and number2. ++ """ ++ self.result = number1 - number2 ++ return self.result ++ + def multiply(self, number1: float, number2: float) -> float: +- pass ++ """ ++ Multiplies two numbers and returns the result. ++ ++ Args: ++ number1 (float): The first number to multiply. ++ number2 (float): The second number to multiply. ++ ++ Returns: ++ float: The product of number1 and number2. ++ """ ++ self.result = number1 * number2 ++ return self.result ++ + def divide(self, number1: float, number2: float) -> float: +- pass ++ """ ++ ValueError: If the second number is zero. 
++        """
++        if number2 == 0:
++            raise ValueError('Cannot divide by zero')
++        self.result = number1 / number2
++        return self.result
++
+-    def reset_result(self):
++    def clear(self):
++        if self.result != 0.0:
++            print("Result is not zero, clearing...")
++        else:
++            print("Result is already zero, no need to clear.")
++
+        self.result = 0.0
+```''',
+        """```diff
+--- Old/main.py
++++ New/main.py
+
+def add_numbers():
+    result = calculator.add_numbers(num1, num2)
+    return jsonify({'result': result}), 200
+
+-# TODO: Implement subtraction, multiplication, and division operations
++@app.route('/subtract_numbers', methods=['POST'])
++def subtract_numbers():
++    data = request.get_json()
++    num1 = data.get('num1', 0)
++    num2 = data.get('num2', 0)
++    result = calculator.subtract_numbers(num1, num2)
++    return jsonify({'result': result}), 200
++
++@app.route('/divide_numbers', methods=['POST'])
++def divide_numbers():
++    data = request.get_json()
++    num1 = data.get('num1', 0)
++    num2 = data.get('num2', 0)
++    try:
++        result = calculator.divide_numbers(num1, num2)
++    except ValueError as e:
++        return jsonify({'error': str(e)}), 400
++    return jsonify({'result': result}), 200
++
+ if __name__ == '__main__':
+     app.run()
+```""",
+    ],
+)
+
+CODE_PLAN_AND_CHANGE_CONTEXT = """
+## User New Requirements
+{requirement}
+
+## Issue
+{issue}
+
+## PRD
+{prd}
+
+## Design
+{design}
+
+## Task
+{task}
+
+## Legacy Code
+{code}
+"""
+
+REFINED_TEMPLATE = """
+NOTICE
+Role: You are a professional engineer; the main goal is to complete the incremental development by combining the legacy code with the Code Plan And Change, ensuring the integration of new features.
+
+# Context
+## User New Requirements
+{user_requirement}
+
+## Code Plan And Change
+{code_plan_and_change}
+
+## Design
+{design}
+
+## Task
+{task}
+
+## Legacy Code
+{code}
+
+
+## Debug logs
+```text
+{logs}
+
+{summary_log}
+```
+
+## Bug Feedback logs
+```text
+{feedback}
+```
+
+# Format example
+## Code: {demo_filename}.py
+```python
+## {demo_filename}.py
+...
+```
+## Code: {demo_filename}.js
+```javascript
+// {demo_filename}.js
+...
+```
+
+# Instruction: Based on the context, follow "Format example", write or rewrite code.
+## Write/Rewrite Code: Only write one file {filename}; write or rewrite complete code using triple quotes, based on the following attentions and context.
+1. Only One file: do your best to implement THIS ONLY ONE FILE.
+2. COMPLETE CODE: Your code will be part of the entire project, so please implement complete, reliable, reusable code snippets.
+3. Set default value: If there is any setting, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE. AVOID circular import.
+4. Follow design: YOU MUST FOLLOW "Data structures and interfaces". DON'T CHANGE ANY DESIGN. Do not use public member functions that do not exist in your design.
+5. Follow Code Plan And Change: If there is any "Incremental Change" that is marked by the git diff format with '+' and '-' symbols, or Legacy Code files contain "{filename} to be rewritten", you must merge it into the code file according to the "Development Plan".
+6. CAREFULLY CHECK THAT YOU DON'T MISS ANY NECESSARY CLASS/FUNCTION IN THIS FILE.
+7. Before using an external variable/module, make sure you import it first.
+8. Write out EVERY CODE DETAIL, DON'T LEAVE TODO.
+9. Attention: Retain details that are not related to incremental development but are important for maintaining the consistency and clarity of the old code.
+"""
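+
+# DEVELOPMENT_PLAN and INCREMENTAL_CHANGE are filled together in a single JSON
+# reply (see WRITE_CODE_PLAN_AND_CHANGE_NODE below), so the plan and the
+# git-diff style changes always come from the same LLM pass.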
+""" + +CODE_PLAN_AND_CHANGE = [DEVELOPMENT_PLAN, INCREMENTAL_CHANGE] + +WRITE_CODE_PLAN_AND_CHANGE_NODE = ActionNode.from_children("WriteCodePlanAndChange", CODE_PLAN_AND_CHANGE) + + +class WriteCodePlanAndChange(Action): + name: str = "WriteCodePlanAndChange" + i_context: CodePlanAndChangeContext = Field(default_factory=CodePlanAndChangeContext) + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run(self, *args, **kwargs): + self.llm.system_prompt = "You are a professional software engineer, your primary responsibility is to " + "meticulously craft comprehensive incremental development plan and deliver detailed incremental change" + prd_doc = await Document.load(filename=self.i_context.prd_filename) + design_doc = await Document.load(filename=self.i_context.design_filename) + task_doc = await Document.load(filename=self.i_context.task_filename) + context = CODE_PLAN_AND_CHANGE_CONTEXT.format( + requirement=f"```text\n{self.i_context.requirement}\n```", + issue=f"```text\n{self.i_context.issue}\n```", + prd=prd_doc.content, + design=design_doc.content, + task=task_doc.content, + code=await self.get_old_codes(), + ) + logger.info("Writing code plan and change..") + return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json") + + async def get_old_codes(self) -> str: + old_codes = await self.repo.srcs.get_all() + codes = [ + f"### File Name: `{code.filename}`\n```{get_markdown_code_block_type(code.filename)}\n{code.content}```\n" + for code in old_codes + ] + return "\n".join(codes) diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py new file mode 100644 index 0000000000000000000000000000000000000000..209e4e8ac40a97f296c38ae8402c43a828b13cb8 --- /dev/null +++ b/metagpt/actions/write_code_review.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:45 +@Author : alexanderwu +@File : write_code_review.py +@Modified By: mashenquan, 2023/11/27. Following the think-act principle, solidify the task parameters when creating the + WriteCode object, rather than passing them in when calling the run function. +""" +import asyncio +import os +from pathlib import Path +from typing import Optional + +from pydantic import BaseModel, Field +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from metagpt.actions import WriteCode +from metagpt.actions.action import Action +from metagpt.logs import logger +from metagpt.schema import CodingContext, Document +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import CodeParser, aread, awrite +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import EditorReporter + +PROMPT_TEMPLATE = """ +# System +Role: You are a professional software engineer, and your main task is to review and revise the code. You need to ensure that the code conforms to the google-style standards, is elegantly designed and modularized, easy to read and maintain. +Language: Please use the same language as the user requirement, but the title and code should be still in English. For example, if the user speaks Chinese, the specific text of your answer should also be in Chinese. +ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenced "Format example". 
+
+# Context
+{context}
+
+-----
+
+## Code to be Reviewed: {filename}
+```Code
+{code}
+```
+"""
+
+EXAMPLE_AND_INSTRUCTION = """
+
+{format_example}
+
+
+# Instruction: Based on the actual code, follow one of the "Code Review Format example".
+- Note the code filename should be `{filename}`. Return only the ONE file `{filename}` under review.
+
+## Code Review: Ordered List. Based on the "Code to be Reviewed", provide key, clear, concise, and specific answers. If any answer is no, explain how to fix it step by step.
+1. Is the code implemented as per the requirements? If not, how to achieve it? Analyse it step by step.
+2. Is the code logic completely correct? If there are errors, please indicate how to correct them.
+3. Does the existing code follow the "Data structures and interfaces"?
+4. Are all functions implemented? If there is no implementation, please indicate how to achieve it step by step.
+5. Have all necessary pre-dependencies been imported? If not, indicate which ones need to be imported.
+6. Are methods from other files being reused correctly?
+
+## Actions: Ordered List. Things that should be done after the code review, such as implementing class A and function B
+
+## Code Review Result: str. If the code doesn't have bugs, we don't need to rewrite it, so answer LGTM and stop. ONLY ANSWER LGTM/LBTM.
+LGTM/LBTM
+
+"""
+
+FORMAT_EXAMPLE = """
+-----
+
+# Code Review Format example 1
+## Code Review: {filename}
+1. No, we should fix the logic of class A due to ...
+2. ...
+3. ...
+4. No, function B is not implemented, ...
+5. ...
+6. ...
+
+## Actions
+1. Fix the `handle_events` method to update the game state only if a move is successful.
+   ```python
+   def handle_events(self):
+       for event in pygame.event.get():
+           if event.type == pygame.QUIT:
+               return False
+           if event.type == pygame.KEYDOWN:
+               moved = False
+               if event.key == pygame.K_UP:
+                   moved = self.game.move('UP')
+               elif event.key == pygame.K_DOWN:
+                   moved = self.game.move('DOWN')
+               elif event.key == pygame.K_LEFT:
+                   moved = self.game.move('LEFT')
+               elif event.key == pygame.K_RIGHT:
+                   moved = self.game.move('RIGHT')
+               if moved:
+                   # Update the game state only if a move was successful
+                   self.render()
+       return True
+   ```
+2. Implement function B
+
+## Code Review Result
+LBTM
+
+-----
+
+# Code Review Format example 2
+## Code Review: {filename}
+1. Yes.
+2. Yes.
+3. Yes.
+4. Yes.
+5. Yes.
+6. Yes.
+
+## Actions
+pass
+
+## Code Review Result
+LGTM
+
+-----
+"""
+
+REWRITE_CODE_TEMPLATE = """
+# Instruction: rewrite the `{filename}` based on the Code Review and Actions
+## Rewrite Code: CodeBlock. If it still has some bugs, rewrite {filename} using a Markdown code block, with the filename docstring preceding the code block. Do your utmost to optimize THIS SINGLE FILE. Return the complete code; do not return unfinished code.
+```python
+## {filename}
+...
+```
+or
+```javascript
+// {filename}
+...
+```
+"""
+
+
+class WriteCodeReview(Action):
+    name: str = "WriteCodeReview"
+    i_context: CodingContext = Field(default_factory=CodingContext)
+    repo: Optional[ProjectRepo] = Field(default=None, exclude=True)
+    input_args: Optional[BaseModel] = Field(default=None, exclude=True)
+
+    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
+    async def write_code_review_and_rewrite(self, context_prompt: str, cr_prompt: str, doc: Document):
+        filename = doc.filename
+        cr_rsp = await self._aask(context_prompt + cr_prompt)
+        result = CodeParser.parse_block("Code Review Result", cr_rsp)
+        if "LGTM" in result:
+            return result, None
+
+        # If LBTM, rewrite the code
+        async with EditorReporter(enable_llm_stream=True) as reporter:
+            await reporter.async_report(
+                {"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta"
+            )
+            rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
+            code_rsp = await self._aask(rewrite_prompt)
+            code = CodeParser.parse_code(text=code_rsp)
+            doc.content = code
+            await reporter.async_report(doc, "document")
+        return result, code
+
+    async def run(self, *args, **kwargs) -> CodingContext:
+        iterative_code = self.i_context.code_doc.content
+        k = self.context.config.code_validate_k_times or 1
+
+        for i in range(k):
+            format_example = FORMAT_EXAMPLE.format(filename=self.i_context.code_doc.filename)
+            task_content = self.i_context.task_doc.content if self.i_context.task_doc else ""
+            code_context = await WriteCode.get_codes(
+                self.i_context.task_doc,
+                exclude=self.i_context.filename,
+                project_repo=self.repo,
+                use_inc=self.config.inc,
+            )
+
+            ctx_list = [
+                "## System Design\n" + str(self.i_context.design_doc) + "\n",
+                "## Task\n" + task_content + "\n",
+                "## Code Files\n" + code_context + "\n",
+            ]
+            if self.config.inc:
+                requirement_doc = await Document.load(filename=self.input_args.requirements_filename)
+                insert_ctx_list = [
+                    "## User New Requirements\n" + str(requirement_doc) + "\n",
+                    "## Code Plan And Change\n" + str(self.i_context.code_plan_and_change_doc) + "\n",
+                ]
+                ctx_list = insert_ctx_list + ctx_list
+
+            context_prompt = PROMPT_TEMPLATE.format(
+                context="\n".join(ctx_list),
+                code=iterative_code,
+                filename=self.i_context.code_doc.filename,
+            )
+            cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
+                format_example=format_example,
+                filename=self.i_context.code_doc.filename,
+            )
+            len1 = len(iterative_code) if iterative_code else 0
+            len2 = len(self.i_context.code_doc.content) if self.i_context.code_doc.content else 0
+            logger.info(
+                f"Code review and rewrite {self.i_context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
+                f"len(self.i_context.code_doc.content)={len2}"
+            )
+            result, rewritten_code = await self.write_code_review_and_rewrite(
+                context_prompt, cr_prompt, self.i_context.code_doc
+            )
+            if "LBTM" in result:
+                iterative_code = rewritten_code
+            elif "LGTM" in result:
+                self.i_context.code_doc.content = iterative_code
+                return self.i_context
+        # If rewritten_code is None (the original code was already fine), return the original code
+        self.i_context.code_doc.content = iterative_code
+        return self.i_context
+
+
+@register_tool(include_functions=["run"])
+class ValidateAndRewriteCode(Action):
+    """According to the design and task documents, validate the code to ensure it is complete and correct."""
+
+    name: str = "ValidateAndRewriteCode"
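+
+    # run() is the registered tool entry point: it reads the code from disk,
+    # reviews it up to code_validate_k_times, and writes the (possibly
+    # rewritten) result back to code_path.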
+    async def run(
+        self,
+        code_path: str,
+        system_design_input: str = "",
+        project_schedule_input: str = "",
+        code_validate_k_times: int = 2,
+    ) -> str:
+        """Validates the provided code against the accompanying system design and project schedule documentation, and returns the complete and correct code.
+
+        Read the code from code_path, and write the final code back to code_path.
+        If both system_design_input and project_schedule_input are absent, it will return and do nothing.
+
+        Args:
+            code_path (str): The file path of the code snippet to be validated. This should be a string containing the path to the source code file.
+            system_design_input (str): Content or file path of the design document associated with the code. This should describe the system architecture used in the code. It helps provide context for the validation process.
+            project_schedule_input (str): Content or file path of the task document describing what the code is intended to accomplish. This should outline the functional requirements or objectives of the code.
+            code_validate_k_times (int, optional): The number of iterations for validating and potentially rewriting the code. Defaults to 2.
+
+        Returns:
+            str: The potentially corrected or approved code after validation.
+
+        Example Usage:
+            # Example of how to call the run method with a code snippet and documentation
+            await ValidateAndRewriteCode().run(
+                code_path="/tmp/game.js",
+                system_design_input="/tmp/system_design.json",
+                project_schedule_input="/tmp/project_task_list.json"
+            )
+        """
+        if not system_design_input and not project_schedule_input:
+            logger.info(
+                "Both `system_design_input` and `project_schedule_input` are absent; ValidateAndRewriteCode will do nothing."
+            )
+            return ""
+
+        code, design_doc, task_doc = await asyncio.gather(
+            aread(code_path), self._try_aread(system_design_input), self._try_aread(project_schedule_input)
+        )
+        code_doc = self._create_code_doc(code_path=code_path, code=code)
+        review_action = WriteCodeReview(i_context=CodingContext(filename=code_doc.filename))
+
+        context = "\n".join(
+            [
+                "## System Design\n" + design_doc + "\n",
+                "## Task\n" + task_doc + "\n",
+            ]
+        )
+
+        for i in range(code_validate_k_times):
+            context_prompt = PROMPT_TEMPLATE.format(context=context, code=code, filename=code_path)
+            cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
+                format_example=FORMAT_EXAMPLE.format(filename=code_path),
+            )
+            logger.info(f"Code review round {i + 1}/{code_validate_k_times}: {code_path}")
+            result, rewritten_code = await review_action.write_code_review_and_rewrite(
+                context_prompt, cr_prompt, doc=code_doc
+            )
+
+            if "LBTM" in result:
+                code = rewritten_code
+            elif "LGTM" in result:
+                break
+
+        await awrite(filename=code_path, data=code)
+
+        return (
+            f"The review and rewriting of the code in the file '{os.path.basename(code_path)}' has been completed.\n"
+            + code
+        )
+
+    @staticmethod
+    async def _try_aread(text_or_path: str) -> str:
+        """Read and return the file content if text_or_path is an existing path; otherwise return it unchanged."""
+
+        if os.path.exists(text_or_path):
+            return await aread(text_or_path)
+
+        return text_or_path
+
+    @staticmethod
+    def _create_code_doc(code_path: str, code: str) -> Document:
+        """Create a Document to represent the code doc."""
+
+        path = Path(code_path)
+
+        return Document(root_path=str(path.parent), filename=path.name, content=code)
diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc4cafb8198f54b8f1b5cec182bf1f8efb1df11
--- /dev/null
+++ b/metagpt/actions/write_docstring.py
@@ -0,0 +1,218 @@
+"""Code Docstring Generator.
+
+This script provides a tool to automatically generate docstrings for Python code. It uses the specified style to create
+docstrings for the given code and system text.
+
+Usage:
+    python3 -m metagpt.actions.write_docstring <filename> [--overwrite] [--style=<style>]
+
+Arguments:
+    filename           The path to the Python file for which you want to generate docstrings.
+
+Options:
+    --overwrite        If specified, overwrite the original file with the code containing docstrings.
+    --style=<style>    Specify the style of the generated docstrings.
+                       Valid values: 'google', 'numpy', or 'sphinx'.
+                       Default: 'google'
+
+Example:
+    python3 -m metagpt.actions.write_docstring ./metagpt/software_company.py --overwrite False --style=numpy
+
+This script uses the 'fire' library to create a command-line interface. It generates docstrings for the given Python
+code using the specified docstring style and adds them to the code.
+"""
+from __future__ import annotations
+
+import ast
+from pathlib import Path
+from typing import Literal, Optional
+
+from metagpt.actions.action import Action
+from metagpt.utils.common import OutputParser, aread, awrite
+from metagpt.utils.pycst import merge_docstring
+
+PYTHON_DOCSTRING_SYSTEM = """### Requirements
+1. Add docstrings to the given code following the {style} style.
+2. Replace the function body with an Ellipsis object(...) to reduce output.
+3. If the types are already annotated, there is no need to include them in the docstring.
+4. Extract only the docstrings of the module, classes, and functions from the given Python code, avoiding any other text.
+
+### Input Example
+```python
+def function_with_pep484_type_annotations(param1: int) -> bool:
+    return isinstance(param1, int)
+
+class ExampleError(Exception):
+    def __init__(self, msg: str):
+        self.msg = msg
+```
+
+### Output Example
+```python
+{example}
+```
+"""
+
+# https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html
+
+PYTHON_DOCSTRING_EXAMPLE_GOOGLE = '''
+def function_with_pep484_type_annotations(param1: int) -> bool:
+    """Example function with PEP 484 type annotations.
+
+    Extended description of function.
+
+    Args:
+        param1: The first parameter.
+
+    Returns:
+        The return value. True for success, False otherwise.
+    """
+    ...
+
+class ExampleError(Exception):
+    """Exceptions are documented in the same way as classes.
+
+    The __init__ method was documented in the class level docstring.
+
+    Args:
+        msg: Human readable string describing the exception.
+
+    Attributes:
+        msg: Human readable string describing the exception.
+    """
+    ...
+'''
+
+PYTHON_DOCSTRING_EXAMPLE_NUMPY = '''
+def function_with_pep484_type_annotations(param1: int) -> bool:
+    """
+    Example function with PEP 484 type annotations.
+
+    Extended description of function.
+
+    Parameters
+    ----------
+    param1
+        The first parameter.
+
+    Returns
+    -------
+    bool
+        The return value. True for success, False otherwise.
+    """
+    ...
+
+class ExampleError(Exception):
+    """
+    Exceptions are documented in the same way as classes.
+
+    The __init__ method was documented in the class level docstring.
+
+    Parameters
+    ----------
+    msg
+        Human readable string describing the exception.
+
+    Attributes
+    ----------
+    msg
+        Human readable string describing the exception.
+    """
+    ...
+'''
+
+PYTHON_DOCSTRING_EXAMPLE_SPHINX = '''
+def function_with_pep484_type_annotations(param1: int) -> bool:
+    """Example function with PEP 484 type annotations.
+
+    Extended description of function.
+
+    :param param1: The first parameter.
+    :type param1: int
+
+    :return: The return value. True for success, False otherwise.
+    :rtype: bool
+    """
+    ...
+
+class ExampleError(Exception):
+    """Exceptions are documented in the same way as classes.
+
+    The __init__ method was documented in the class level docstring.
+
+    :param msg: Human-readable string describing the exception.
+    :type msg: str
+    """
+    ...
+'''
+
+_python_docstring_style = {
+    "google": PYTHON_DOCSTRING_EXAMPLE_GOOGLE.strip(),
+    "numpy": PYTHON_DOCSTRING_EXAMPLE_NUMPY.strip(),
+    "sphinx": PYTHON_DOCSTRING_EXAMPLE_SPHINX.strip(),
+}
+
+
+class WriteDocstring(Action):
+    """This class is used to write docstrings for code.
+
+    Attributes:
+        desc: A string describing the action.
+    """
+
+    desc: str = "Write docstring for code."
+    i_context: Optional[str] = None
+
+    async def run(
+        self,
+        code: str,
+        system_text: str = PYTHON_DOCSTRING_SYSTEM,
+        style: Literal["google", "numpy", "sphinx"] = "google",
+    ) -> str:
+        """Writes docstrings for the given code and system text in the specified style.
+
+        Args:
+            code: A string of Python code.
+            system_text: A string of system text.
+            style: A string specifying the style of the docstring. Can be 'google', 'numpy', or 'sphinx'.
+
+        Returns:
+            The Python code with docstrings added.
+        """
+        system_text = system_text.format(style=style, example=_python_docstring_style[style])
+        simplified_code = _simplify_python_code(code)
+        documented_code = await self._aask(f"```python\n{simplified_code}\n```", [system_text])
+        documented_code = OutputParser.parse_python_code(documented_code)
+        return merge_docstring(code, documented_code)
+
+    @staticmethod
+    async def write_docstring(
+        filename: str | Path, overwrite: bool = False, style: Literal["google", "numpy", "sphinx"] = "google"
+    ) -> str:
+        data = await aread(str(filename))
+        code = await WriteDocstring().run(data, style=style)
+        if overwrite:
+            await awrite(filename, code)
+        return code
+
+
+def _simplify_python_code(code: str) -> str:
+    """Simplifies the given Python code by removing expressions and the last if statement.
+
+    Args:
+        code: A string of Python code.
+
+    Returns:
+        The simplified Python code.
+    """
+    code_tree = ast.parse(code)
+    code_tree.body = [i for i in code_tree.body if not isinstance(i, ast.Expr)]
+    if code_tree.body and isinstance(code_tree.body[-1], ast.If):
+        code_tree.body.pop()
+    return ast.unparse(code_tree)
+
+
+if __name__ == "__main__":
+    import fire
+
+    fire.Fire(WriteDocstring.write_docstring)
diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a04520d6e6107e38bf4bb30537cffa9049c2e4a
--- /dev/null
+++ b/metagpt/actions/write_prd.py
@@ -0,0 +1,325 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time   : 2023/5/11 17:45
+@Author : alexanderwu
+@File   : write_prd.py
+@Modified By: mashenquan, 2023/11/27.
+    1. According to Section 2.2.3.1 of RFC 135, replace file data in the message with the file name.
+    2. According to the design in Section 2.2.3.5.2 of RFC 135, add incremental iteration functionality.
+    3. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
+@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
+@Modified By: mashenquan, 2024/5/31. Implement Chapter 3 of RFC 236.
+"""
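+
+# WritePRD routes an incoming requirement to one of three paths: a bugfix
+# hand-off to the engineer, a brand-new PRD, or an update of the existing PRDs
+# the requirement relates to.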
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import List, Optional, Union + +from pydantic import BaseModel, Field + +from metagpt.actions import Action, ActionOutput +from metagpt.actions.action_node import ActionNode +from metagpt.actions.fix_bug import FixBug +from metagpt.actions.write_prd_an import ( + COMPETITIVE_QUADRANT_CHART, + PROJECT_NAME, + REFINED_PRD_NODE, + WP_IS_RELATIVE_NODE, + WP_ISSUE_TYPE_NODE, + WRITE_PRD_NODE, +) +from metagpt.const import ( + BUGFIX_FILENAME, + COMPETITIVE_ANALYSIS_FILE_REPO, + REQUIREMENT_FILENAME, +) +from metagpt.logs import logger +from metagpt.schema import AIMessage, Document, Documents, Message +from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import ( + CodeParser, + aread, + awrite, + rectify_pathname, + save_json_to_markdown, + to_markdown_code_block, +) +from metagpt.utils.file_repository import FileRepository +from metagpt.utils.mermaid import mermaid_to_file +from metagpt.utils.project_repo import ProjectRepo +from metagpt.utils.report import DocsReporter, GalleryReporter + +CONTEXT_TEMPLATE = """ +### Project Name +{project_name} + +### Original Requirements +{requirements} + +### Search Information +- +""" + +NEW_REQ_TEMPLATE = """ +### Legacy Content +{old_prd} + +### New Requirements +{requirements} +""" + + +@register_tool(include_functions=["run"]) +class WritePRD(Action): + """WritePRD deal with the following situations: + 1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated. + 2. New requirement: If the requirement is a new requirement, the PRD document will be generated. + 3. Requirement update: If the requirement is an update, the PRD document will be updated. + """ + + repo: Optional[ProjectRepo] = Field(default=None, exclude=True) + input_args: Optional[BaseModel] = Field(default=None, exclude=True) + + async def run( + self, + with_messages: List[Message] = None, + *, + user_requirement: str = "", + output_pathname: str = "", + legacy_prd_filename: str = "", + extra_info: str = "", + **kwargs, + ) -> Union[AIMessage, str]: + """ + Write a Product Requirement Document. + + Args: + user_requirement (str): A string detailing the user's requirements. + output_pathname (str, optional): The output file path of the document. Defaults to "". + legacy_prd_filename (str, optional): The file path of the legacy Product Requirement Document to use as a reference. Defaults to "". + extra_info (str, optional): Additional information to include in the document. Defaults to "". + **kwargs: Additional keyword arguments. + + Returns: + str: The file path of the generated Product Requirement Document. + + Example: + # Write a new PRD (Product Requirement Document) + >>> user_requirement = "Write a snake game" + >>> output_pathname = "snake_game/docs/prd.json" + >>> extra_info = "YOUR EXTRA INFO, if any" + >>> write_prd = WritePRD() + >>> result = await write_prd.run(user_requirement=user_requirement, output_pathname=output_pathname, extra_info=extra_info) + >>> print(result) + PRD filename: "/absolute/path/to/snake_game/docs/prd.json" + + # Rewrite an existing PRD (Product Requirement Document) and save to a new path. 
+            >>> user_requirement = "Write PRD for a snake game, include new features such as a web UI"
+            >>> legacy_prd_filename = "/absolute/path/to/snake_game/docs/prd.json"
+            >>> output_pathname = "/absolute/path/to/snake_game/docs/prd_new.json"
+            >>> extra_info = "YOUR EXTRA INFO, if any"
+            >>> write_prd = WritePRD()
+            >>> result = await write_prd.run(user_requirement=user_requirement, legacy_prd_filename=legacy_prd_filename, extra_info=extra_info)
+            >>> print(result)
+            PRD filename: "/absolute/path/to/snake_game/docs/prd_new.json"
+        """
+        if not with_messages:
+            return await self._execute_api(
+                user_requirement=user_requirement,
+                output_pathname=output_pathname,
+                legacy_prd_filename=legacy_prd_filename,
+                extra_info=extra_info,
+            )
+
+        self.input_args = with_messages[-1].instruct_content
+        if not self.input_args:
+            self.repo = ProjectRepo(self.context.kwargs.project_path)
+            await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[-1].content)
+            self.input_args = AIMessage.create_instruct_value(
+                kvs={
+                    "project_path": self.context.kwargs.project_path,
+                    "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME),
+                    "prd_filenames": [str(self.repo.docs.prd.workdir / i) for i in self.repo.docs.prd.all_files],
+                },
+                class_name="PrepareDocumentsOutput",
+            )
+        else:
+            self.repo = ProjectRepo(self.input_args.project_path)
+        req = await Document.load(filename=self.input_args.requirements_filename)
+        docs: list[Document] = [
+            await Document.load(filename=i, project_path=self.repo.workdir) for i in self.input_args.prd_filenames
+        ]
+
+        if not req:
+            raise FileNotFoundError("No requirement document found.")
+
+        if await self._is_bugfix(req.content):
+            logger.info(f"Bugfix detected: {req.content}")
+            return await self._handle_bugfix(req)
+        # Remove the bugfix file from the last round to avoid conflicts
+        await self.repo.docs.delete(filename=BUGFIX_FILENAME)
+
+        # If the requirement relates to existing documents, update them; otherwise create a new one
+        if related_docs := await self.get_related_docs(req, docs):
+            logger.info(f"Requirement update detected: {req.content}")
+            await self._handle_requirement_update(req=req, related_docs=related_docs)
+        else:
+            logger.info(f"New requirement detected: {req.content}")
+            await self._handle_new_requirement(req)
+
+        kvs = self.input_args.model_dump()
+        kvs["changed_prd_filenames"] = [
+            str(self.repo.docs.prd.workdir / i) for i in list(self.repo.docs.prd.changed_files.keys())
+        ]
+        kvs["project_path"] = str(self.repo.workdir)
+        kvs["requirements_filename"] = str(self.repo.docs.workdir / REQUIREMENT_FILENAME)
+        self.context.kwargs.project_path = str(self.repo.workdir)
+        return AIMessage(
+            content="PRD is completed. "
+            + "\n".join(
+                list(self.repo.docs.prd.changed_files.keys())
+                + list(self.repo.resources.prd.changed_files.keys())
+                + list(self.repo.resources.competitive_analysis.changed_files.keys())
+            ),
+            instruct_content=AIMessage.create_instruct_value(kvs=kvs, class_name="WritePRDOutput"),
+            cause_by=self,
+        )
+
+    async def _handle_bugfix(self, req: Document) -> AIMessage:
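+        # Persist the issue for the engineer and blank the requirement document
+        # so the bugfix is not re-interpreted as a new requirement next round.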
+        await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
+        await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
+        return AIMessage(
+            content=f"A new issue is received: {BUGFIX_FILENAME}",
+            cause_by=FixBug,
+            instruct_content=AIMessage.create_instruct_value(
+                {
+                    "project_path": str(self.repo.workdir),
+                    "issue_filename": str(self.repo.docs.workdir / BUGFIX_FILENAME),
+                    "requirements_filename": str(self.repo.docs.workdir / REQUIREMENT_FILENAME),
+                },
+                class_name="IssueDetail",
+            ),
+            send_to="Alex",  # the name of the Engineer role
+        )
+
+    async def _new_prd(self, requirement: str) -> ActionNode:
+        project_name = self.project_name
+        context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name)
+        exclude = [PROJECT_NAME.key] if project_name else []
+        node = await WRITE_PRD_NODE.fill(req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema)
+        return node
+
+    async def _handle_new_requirement(self, req: Document) -> ActionOutput:
+        """Handle a brand-new requirement by generating a fresh PRD."""
+        async with DocsReporter(enable_llm_stream=True) as reporter:
+            await reporter.async_report({"type": "prd"}, "meta")
+            node = await self._new_prd(req.content)
+            await self._rename_workspace(node)
+            new_prd_doc = await self.repo.docs.prd.save(
+                filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
+            )
+            await self._save_competitive_analysis(new_prd_doc)
+            md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
+            await reporter.async_report(self.repo.workdir / md.root_relative_path, "path")
+        return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
+
+    async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
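+        # Merge the new requirement into every related legacy PRD in place.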
+        for doc in related_docs:
+            await self._update_prd(req=req, prd_doc=doc)
+        return Documents.from_iterable(documents=related_docs).to_action_output()
+
+    async def _is_bugfix(self, context: str) -> bool:
+        if not self.repo.code_files_exists():
+            return False
+        node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm)
+        return node.get("issue_type") == "BUG"
+
+    async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
+        """Get the PRD documents related to the new requirement."""
+        # TODO: use asyncio.gather to check the documents concurrently
+        return [i for i in docs if await self._is_related(req, i)]
+
+    async def _is_related(self, req: Document, old_prd: Document) -> bool:
+        context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
+        node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm)
+        return node.get("is_relative") == "YES"
+
+    async def _merge(self, req: Document, related_doc: Document) -> Document:
+        if not self.project_name:
+            self.project_name = Path(self.project_path).name
+        prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
+        node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema)
+        related_doc.content = node.instruct_content.model_dump_json()
+        await self._rename_workspace(node)
+        return related_doc
+
+    async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
+        async with DocsReporter(enable_llm_stream=True) as reporter:
+            await reporter.async_report({"type": "prd"}, "meta")
+            new_prd_doc: Document = await self._merge(req=req, related_doc=prd_doc)
+            await self.repo.docs.prd.save_doc(doc=new_prd_doc)
+            await self._save_competitive_analysis(new_prd_doc)
+            md = await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
+            await reporter.async_report(self.repo.workdir / md.root_relative_path, "path")
+        return new_prd_doc
+
+    async def _save_competitive_analysis(self, prd_doc: Document, output_filename: Path = None):
+        m = json.loads(prd_doc.content)
+        quadrant_chart = m.get(COMPETITIVE_QUADRANT_CHART.key)
+        if not quadrant_chart:
+            return
+        pathname = output_filename or self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
+        pathname.parent.mkdir(parents=True, exist_ok=True)
+        await mermaid_to_file(self.config.mermaid.engine, quadrant_chart, pathname)
+        image_path = pathname.parent / f"{pathname.name}.svg"
+        if image_path.exists():
+            await GalleryReporter().async_report(image_path, "path")
+
+    async def _rename_workspace(self, prd):
+        if not self.project_name:
+            if isinstance(prd, (ActionOutput, ActionNode)):
+                ws_name = prd.instruct_content.model_dump()["Project Name"]
+            else:
+                ws_name = CodeParser.parse_str(block="Project Name", text=prd)
+            if ws_name:
+                self.project_name = ws_name
+        if self.repo:
+            self.repo.git_repo.rename_root(self.project_name)
+
+    async def _execute_api(
+        self, user_requirement: str, output_pathname: str, legacy_prd_filename: str, extra_info: str
+    ) -> str:
+        content = "#### User Requirements\n{user_requirement}\n#### Extra Info\n{extra_info}\n".format(
+            user_requirement=to_markdown_code_block(val=user_requirement),
+            extra_info=to_markdown_code_block(val=extra_info),
+        )
+        async with DocsReporter(enable_llm_stream=True) as reporter:
+            await reporter.async_report({"type": "prd"}, "meta")
+            req = Document(content=content)
+            if not legacy_prd_filename:
+                node = await self._new_prd(requirement=req.content)
+                new_prd = Document(content=node.instruct_content.model_dump_json())
+            else:
+                content = await aread(filename=legacy_prd_filename)
+                old_prd = Document(content=content)
+                new_prd = await self._merge(req=req, related_doc=old_prd)
+
+            if not output_pathname:
+                output_pathname = self.config.workspace.path / "docs" / "prd.json"
+            elif not Path(output_pathname).is_absolute():
+                output_pathname = self.config.workspace.path / output_pathname
+            output_pathname = rectify_pathname(path=output_pathname, default_filename="prd.json")
+            await awrite(filename=output_pathname, data=new_prd.content)
+            competitive_analysis_filename = output_pathname.parent / f"{output_pathname.stem}-competitive-analysis"
+            await self._save_competitive_analysis(prd_doc=new_prd, output_filename=Path(competitive_analysis_filename))
+            md_output_filename = output_pathname.with_suffix(".md")
+            await save_json_to_markdown(content=new_prd.content, output_filename=md_output_filename)
+            await reporter.async_report(md_output_filename, "path")
+        return f'PRD filename: "{str(output_pathname)}". The product requirement document (PRD) has been completed.'
diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e16bcfa377219b2f23a353ae80517d20dc6044
--- /dev/null
+++ b/metagpt/actions/write_prd_an.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time   : 2023/12/14 11:40
+@Author : alexanderwu
+@File   : write_prd_an.py
+"""
+from typing import List, Union
+
+from metagpt.actions.action_node import ActionNode
+
+LANGUAGE = ActionNode(
+    key="Language",
+    expected_type=str,
+    instruction="Provide the language used in the project, typically matching the user's requirement language.",
+    example="en_us",
+)
+
+PROGRAMMING_LANGUAGE = ActionNode(
+    key="Programming Language",
+    expected_type=str,
+    instruction="Mainstream programming language. If not specified in the requirements, use Vite, React, MUI, Tailwind CSS.",
+    example="Vite, React, MUI, Tailwind CSS",
+)
+
+ORIGINAL_REQUIREMENTS = ActionNode(
+    key="Original Requirements",
+    expected_type=str,
+    instruction="Place the original user's requirements here.",
+    example="Create a 2048 game",
+)
+
+REFINED_REQUIREMENTS = ActionNode(
+    key="Refined Requirements",
+    expected_type=str,
+    instruction="Place the user's new requirements here.",
+    example="Create a 2048 game with a new feature that ...",
+)
+
+PROJECT_NAME = ActionNode(
+    key="Project Name",
+    expected_type=str,
+    instruction='According to the content of "Original Requirements," name the project using snake_case style, '
+    "like 'game_2048' or 'simple_crm'.",
+    example="game_2048",
+)
+
+PRODUCT_GOALS = ActionNode(
+    key="Product Goals",
+    expected_type=List[str],
+    instruction="Provide up to three clear, orthogonal product goals.",
+    example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"],
+)
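+
+# The "Refined" node variants below are used on incremental runs, when an
+# existing PRD is merged with new requirements (see WritePRD._merge, which
+# fills REFINED_PRD_NODE).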
diff --git a/metagpt/actions/write_prd_an.py b/metagpt/actions/write_prd_an.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e16bcfa377219b2f23a353ae80517d20dc6044
--- /dev/null
+++ b/metagpt/actions/write_prd_an.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/12/14 11:40
+@Author  : alexanderwu
+@File    : write_prd_an.py
+"""
+from typing import List, Union
+
+from metagpt.actions.action_node import ActionNode
+
+LANGUAGE = ActionNode(
+    key="Language",
+    expected_type=str,
+    instruction="Provide the language used in the project, typically matching the user's requirement language.",
+    example="en_us",
+)
+
+PROGRAMMING_LANGUAGE = ActionNode(
+    key="Programming Language",
+    expected_type=str,
+    instruction="Mainstream programming language. If not specified in the requirements, use Vite, React, MUI, Tailwind CSS.",
+    example="Vite, React, MUI, Tailwind CSS",
+)
+
+ORIGINAL_REQUIREMENTS = ActionNode(
+    key="Original Requirements",
+    expected_type=str,
+    instruction="Place the original user's requirements here.",
+    example="Create a 2048 game",
+)
+
+REFINED_REQUIREMENTS = ActionNode(
+    key="Refined Requirements",
+    expected_type=str,
+    instruction="Place the user's new requirements here.",
+    example="Create a 2048 game with a new feature that ...",
+)
+
+PROJECT_NAME = ActionNode(
+    key="Project Name",
+    expected_type=str,
+    instruction='According to the content of "Original Requirements," name the project using snake case style, '
+    "like 'game_2048' or 'simple_crm'.",
+    example="game_2048",
+)
+
+PRODUCT_GOALS = ActionNode(
+    key="Product Goals",
+    expected_type=List[str],
+    instruction="Provide up to three clear, orthogonal product goals.",
+    example=["Create an engaging user experience", "Improve accessibility, be responsive", "More beautiful UI"],
+)
+
+REFINED_PRODUCT_GOALS = ActionNode(
+    key="Refined Product Goals",
+    expected_type=List[str],
+    instruction="Update and expand the original product goals to reflect the evolving needs due to incremental "
+    "development. Ensure that the refined goals align with the current project direction and contribute to its success.",
+    example=[
+        "Enhance user engagement through new features",
+        "Optimize performance for scalability",
+        "Integrate innovative UI enhancements",
+    ],
+)
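Note: each of these leaf nodes is a typed slot that the LLM fills, and the composite nodes assembled at the bottom of this file stitch them into one structured request. A minimal usage sketch, assuming a configured `llm` instance:

    from metagpt.actions.write_prd_an import WRITE_PRD_NODE

    async def draft_prd(llm, requirement: str):
        # fill() prompts the LLM once and parses the reply into the child keys
        node = await WRITE_PRD_NODE.fill(req=requirement, llm=llm)
        prd = node.instruct_content.model_dump()
        return prd["Project Name"], prd["Product Goals"]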
+
+USER_STORIES = ActionNode(
+    key="User Stories",
+    expected_type=List[str],
+    instruction="Provide 3 to 5 scenario-based user stories.",
+    example=[
+        "As a player, I want to be able to choose difficulty levels",
+        "As a player, I want to see my score after each game",
+        "As a player, I want to see a restart button when I lose",
+        "As a player, I want to see a beautiful UI that makes me feel good",
+        "As a player, I want to play the game on my mobile phone",
+    ],
+)
+
+REFINED_USER_STORIES = ActionNode(
+    key="Refined User Stories",
+    expected_type=List[str],
+    instruction="Update and expand the original scenario-based user stories to reflect the evolving needs due to "
+    "incremental development. Ensure that the refined user stories capture incremental features and improvements.",
+    example=[
+        "As a player, I want to choose difficulty levels to challenge my skills",
+        "As a player, I want a visually appealing score display after each game for a better gaming experience",
+        "As a player, I want a convenient restart button displayed when I lose to quickly start a new game",
+        "As a player, I want an enhanced and aesthetically pleasing UI to elevate the overall gaming experience",
+        "As a player, I want the ability to play the game seamlessly on my mobile phone for on-the-go entertainment",
+    ],
+)
+
+COMPETITIVE_ANALYSIS = ActionNode(
+    key="Competitive Analysis",
+    expected_type=List[str],
+    instruction="Provide 5 to 7 competitive products.",
+    example=[
+        "2048 Game A: Simple interface, lacks responsive features",
+        "play2048.co: Beautiful and responsive UI with my best score shown",
+        "2048game.com: Responsive UI with my best score shown, but many ads",
+    ],
+)
+
+COMPETITIVE_QUADRANT_CHART = ActionNode(
+    key="Competitive Quadrant Chart",
+    expected_type=str,
+    instruction="Use mermaid quadrantChart syntax. Distribute scores evenly between 0 and 1.",
+    example="""quadrantChart
+    title "Reach and engagement of campaigns"
+    x-axis "Low Reach" --> "High Reach"
+    y-axis "Low Engagement" --> "High Engagement"
+    quadrant-1 "We should expand"
+    quadrant-2 "Need to promote"
+    quadrant-3 "Re-evaluate"
+    quadrant-4 "May be improved"
+    "Campaign A": [0.3, 0.6]
+    "Campaign B": [0.45, 0.23]
+    "Campaign C": [0.57, 0.69]
+    "Campaign D": [0.78, 0.34]
+    "Campaign E": [0.40, 0.34]
+    "Campaign F": [0.35, 0.78]
+    "Our Target Product": [0.5, 0.6]""",
+)
+
+REQUIREMENT_ANALYSIS = ActionNode(
+    key="Requirement Analysis",
+    expected_type=str,
+    instruction="Provide a detailed analysis of the requirements.",
+    example="",
+)
+
+REFINED_REQUIREMENT_ANALYSIS = ActionNode(
+    key="Refined Requirement Analysis",
+    expected_type=Union[List[str], str],
+    instruction="Review and refine the existing requirement analysis into a string list to align with the evolving "
+    "needs of the project due to incremental development. Ensure the analysis comprehensively covers the new features "
+    "and enhancements required for the refined project scope.",
+    example=["Require add ...", "Require modify ..."],
+)
+
+REQUIREMENT_POOL = ActionNode(
+    key="Requirement Pool",
+    expected_type=List[List[str]],
+    instruction="List the top 5 requirements with their priority (P0, P1, P2).",
+    example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
+)
+
+REFINED_REQUIREMENT_POOL = ActionNode(
+    key="Refined Requirement Pool",
+    expected_type=List[List[str]],
+    instruction="List the top 5 to 7 requirements with their priority (P0, P1, P2). "
+    "Cover both legacy content and incremental content. Retain content unrelated to incremental development.",
+    example=[["P0", "The main code ..."], ["P0", "The game algorithm ..."]],
+)
+
+UI_DESIGN_DRAFT = ActionNode(
+    key="UI Design draft",
+    expected_type=str,
+    instruction="Provide a simple description of UI elements, functions, style, and layout.",
+    example="Basic function description with a simple style and layout.",
+)
+
+ANYTHING_UNCLEAR = ActionNode(
+    key="Anything UNCLEAR",
+    expected_type=str,
+    instruction="Mention any aspects of the project that are unclear and try to clarify them.",
+    example="Currently, all aspects of the project are clear.",
+)
+
+ISSUE_TYPE = ActionNode(
+    key="issue_type",
+    expected_type=str,
+    instruction="Answer BUG/REQUIREMENT. If it is a bugfix, answer BUG; otherwise answer REQUIREMENT.",
+    example="BUG",
+)
+
+IS_RELATIVE = ActionNode(
+    key="is_relative",
+    expected_type=str,
+    instruction="Answer YES/NO. If the requirement is related to the old PRD, answer YES; otherwise NO.",
+    example="YES",
+)
+
+REASON = ActionNode(
+    key="reason", expected_type=str, instruction="Explain the reasoning process from question to answer", example="..."
+) + + +NODES = [ + LANGUAGE, + PROGRAMMING_LANGUAGE, + ORIGINAL_REQUIREMENTS, + PROJECT_NAME, + PRODUCT_GOALS, + USER_STORIES, + COMPETITIVE_ANALYSIS, + COMPETITIVE_QUADRANT_CHART, + REQUIREMENT_ANALYSIS, + REQUIREMENT_POOL, + UI_DESIGN_DRAFT, + ANYTHING_UNCLEAR, +] + +REFINED_NODES = [ + LANGUAGE, + PROGRAMMING_LANGUAGE, + REFINED_REQUIREMENTS, + PROJECT_NAME, + REFINED_PRODUCT_GOALS, + REFINED_USER_STORIES, + COMPETITIVE_ANALYSIS, + COMPETITIVE_QUADRANT_CHART, + REFINED_REQUIREMENT_ANALYSIS, + REFINED_REQUIREMENT_POOL, + UI_DESIGN_DRAFT, + ANYTHING_UNCLEAR, +] + +WRITE_PRD_NODE = ActionNode.from_children("WritePRD", NODES) +REFINED_PRD_NODE = ActionNode.from_children("RefinedPRD", REFINED_NODES) +WP_ISSUE_TYPE_NODE = ActionNode.from_children("WP_ISSUE_TYPE", [ISSUE_TYPE, REASON]) +WP_IS_RELATIVE_NODE = ActionNode.from_children("WP_IS_RELATIVE", [IS_RELATIVE, REASON]) diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py new file mode 100644 index 0000000000000000000000000000000000000000..68fb5d9e8da961d205378610356b4423203a7bd6 --- /dev/null +++ b/metagpt/actions/write_prd_review.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/11 17:45 +@Author : alexanderwu +@File : write_prd_review.py +""" + +from typing import Optional + +from metagpt.actions.action import Action + + +class WritePRDReview(Action): + name: str = "" + i_context: Optional[str] = None + + prd: Optional[str] = None + desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" + prd_review_prompt_template: str = """ +Given the following Product Requirement Document (PRD): +{prd} + +As a project manager, please review it and provide your feedback and suggestions. +""" + + async def run(self, prd): + self.prd = prd + prompt = self.prd_review_prompt_template.format(prd=self.prd) + review = await self._aask(prompt) + return review diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py new file mode 100644 index 0000000000000000000000000000000000000000..907a1e990107f3f573386f495997b513bb83f951 --- /dev/null +++ b/metagpt/actions/write_review.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Author : alexanderwu +@File : write_review.py +""" +from typing import List + +from metagpt.actions import Action +from metagpt.actions.action_node import ActionNode + +REVIEW = ActionNode( + key="Review", + expected_type=List[str], + instruction="Act as an experienced Reviewer and review the given output. Ask a series of critical questions, " + "concisely and clearly, to help the writer improve their work.", + example=[ + "This is a good PRD, but I think it can be improved by adding more details.", + ], +) + +LGTM = ActionNode( + key="LGTM", + expected_type=str, + instruction="LGTM/LBTM. 
If the output is good enough, give a LGTM (Looks Good To Me) to the writer, " + "else LBTM (Looks Bad To Me).", + example="LGTM", +) + +WRITE_REVIEW_NODE = ActionNode.from_children("WRITE_REVIEW_NODE", [REVIEW, LGTM]) + + +class WriteReview(Action): + """Write a review for the given context.""" + + name: str = "WriteReview" + + async def run(self, context): + return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py new file mode 100644 index 0000000000000000000000000000000000000000..c5f70ae05b7972adff24a22b8e231dc88c65fb14 --- /dev/null +++ b/metagpt/actions/write_teaching_plan.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/7/27 +@Author : mashenquan +@File : write_teaching_plan.py +""" +from typing import Optional + +from metagpt.actions import Action +from metagpt.context import Context +from metagpt.logs import logger + + +class WriteTeachingPlanPart(Action): + """Write Teaching Plan Part""" + + i_context: Optional[str] = None + topic: str = "" + language: str = "Chinese" + rsp: Optional[str] = None + + async def run(self, with_message=None, **kwargs): + statement_patterns = TeachingPlanBlock.TOPIC_STATEMENTS.get(self.topic, []) + statements = [] + for p in statement_patterns: + s = self.format_value(p, context=self.context) + statements.append(s) + formatter = ( + TeachingPlanBlock.PROMPT_TITLE_TEMPLATE + if self.topic == TeachingPlanBlock.COURSE_TITLE + else TeachingPlanBlock.PROMPT_TEMPLATE + ) + prompt = formatter.format( + formation=TeachingPlanBlock.FORMATION, + role=self.prefix, + statements="\n".join(statements), + lesson=self.i_context, + topic=self.topic, + language=self.language, + ) + + logger.debug(prompt) + rsp = await self._aask(prompt=prompt) + logger.debug(rsp) + self._set_result(rsp) + return self.rsp + + def _set_result(self, rsp): + if TeachingPlanBlock.DATA_BEGIN_TAG in rsp: + ix = rsp.index(TeachingPlanBlock.DATA_BEGIN_TAG) + rsp = rsp[ix + len(TeachingPlanBlock.DATA_BEGIN_TAG) :] + if TeachingPlanBlock.DATA_END_TAG in rsp: + ix = rsp.index(TeachingPlanBlock.DATA_END_TAG) + rsp = rsp[0:ix] + self.rsp = rsp.strip() + if self.topic != TeachingPlanBlock.COURSE_TITLE: + return + if "#" not in self.rsp or self.rsp.index("#") != 0: + self.rsp = "# " + self.rsp + + def __str__(self): + """Return `topic` value when str()""" + return self.topic + + def __repr__(self): + """Show `topic` value when debug""" + return self.topic + + @staticmethod + def format_value(value, context: Context): + """Fill parameters inside `value` with `options`.""" + if not isinstance(value, str): + return value + if "{" not in value: + return value + + options = context.config.model_dump() + for k, v in context.kwargs: + options[k] = v # None value is allowed to override and disable the value from config. 
+ opts = {k: v for k, v in options.items() if v is not None} + try: + return value.format(**opts) + except KeyError as e: + logger.warning(f"Parameter is missing:{e}") + + for k, v in opts.items(): + value = value.replace("{" + f"{k}" + "}", str(v)) + return value + + +class TeachingPlanBlock: + FORMATION = ( + '"Capacity and role" defines the role you are currently playing;\n' + '\t"[LESSON_BEGIN]" and "[LESSON_END]" tags enclose the content of textbook;\n' + '\t"Statement" defines the work detail you need to complete at this stage;\n' + '\t"Answer options" defines the format requirements for your responses;\n' + '\t"Constraint" defines the conditions that your responses must comply with.' + ) + + COURSE_TITLE = "Title" + TOPICS = [ + COURSE_TITLE, + "Teaching Hours", + "Teaching Objectives", + "Teaching Content", + "Teaching Methods and Strategies", + "Learning Activities", + "Teaching Time Allocation", + "Assessment and Feedback", + "Teaching Summary and Improvement", + "Vocabulary Cloze", + "Choice Questions", + "Grammar Questions", + "Translation Questions", + ] + + TOPIC_STATEMENTS = { + COURSE_TITLE: [ + "Statement: Find and return the title of the lesson only in markdown first-level header format, " + "without anything else." + ], + "Teaching Content": [ + 'Statement: "Teaching Content" must include vocabulary, analysis, and examples of various grammar ' + "structures that appear in the textbook, as well as the listening materials and key points.", + 'Statement: "Teaching Content" must include more examples.', + ], + "Teaching Time Allocation": [ + 'Statement: "Teaching Time Allocation" must include how much time is allocated to each ' + "part of the textbook content." + ], + "Teaching Methods and Strategies": [ + 'Statement: "Teaching Methods and Strategies" must include teaching focus, difficulties, materials, ' + "procedures, in detail." + ], + "Vocabulary Cloze": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create vocabulary cloze. The cloze should include 10 {language} questions with {teaching_language} " + "answers, and it should also include 10 {teaching_language} questions with {language} answers. " + "The key-related vocabulary and phrases in the textbook content must all be included in the exercises.", + ], + "Grammar Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create grammar questions. 10 questions." + ], + "Choice Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create choice questions. 10 questions." + ], + "Translation Questions": [ + 'Statement: Based on the content of the textbook enclosed by "[LESSON_BEGIN]" and "[LESSON_END]", ' + "create translation questions. The translation should include 10 {language} questions with " + "{teaching_language} answers, and it should also include 10 {teaching_language} questions with " + "{language} answers." 
+        ],
+    }
+
+    # Teaching plan title
+    PROMPT_TITLE_TEMPLATE = (
+        "Do not refer to the context of the previous conversation records, "
+        "start the conversation anew.\n\n"
+        "Formation: {formation}\n\n"
+        "{statements}\n"
+        "Constraint: Writing in {language}.\n"
+        'Answer options: Enclose the lesson title with "[TEACHING_PLAN_BEGIN]" '
+        'and "[TEACHING_PLAN_END]" tags.\n'
+        "[LESSON_BEGIN]\n"
+        "{lesson}\n"
+        "[LESSON_END]"
+    )
+
+    # Teaching plan parts:
+    PROMPT_TEMPLATE = (
+        "Do not refer to the context of the previous conversation records, "
+        "start the conversation anew.\n\n"
+        "Formation: {formation}\n\n"
+        "Capacity and role: {role}\n"
+        'Statement: Write the "{topic}" part of the teaching plan, '
+        'WITHOUT ANY content unrelated to "{topic}"!!\n'
+        "{statements}\n"
+        'Answer options: Enclose the teaching plan content with "[TEACHING_PLAN_BEGIN]" '
+        'and "[TEACHING_PLAN_END]" tags.\n'
+        "Answer options: Use proper markdown format, starting from second-level headers.\n"
+        "Constraint: Writing in {language}.\n"
+        "[LESSON_BEGIN]\n"
+        "{lesson}\n"
+        "[LESSON_END]"
+    )
+
+    DATA_BEGIN_TAG = "[TEACHING_PLAN_BEGIN]"
+    DATA_END_TAG = "[TEACHING_PLAN_END]"
diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..286d3ea1351d3e6e3d674fffab063e2884c1005a
--- /dev/null
+++ b/metagpt/actions/write_test.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/5/11 22:12
+@Author  : alexanderwu
+@File    : write_test.py
+@Modified By: mashenquan, 2023-11-27. Following the think-act principle, solidify the task parameters when creating the
+        WriteTest object, rather than passing them in when calling the run function.
+"""
+
+from typing import Optional
+
+from metagpt.actions.action import Action
+from metagpt.const import TEST_CODES_FILE_REPO
+from metagpt.logs import logger
+from metagpt.schema import Document, TestingContext
+from metagpt.utils.common import CodeParser
+
+PROMPT_TEMPLATE = """
+NOTICE
+1. Role: You are a QA engineer; the main goal is to design, develop, and execute PEP8 compliant, well-structured, maintainable test cases and scripts for Python 3.9. Your focus should be on ensuring the product quality of the entire project through systematic testing.
+2. Requirement: Based on the context, develop a comprehensive test suite that adequately covers all relevant aspects of the code file under review. Your test suite will be part of the overall project QA, so please develop complete, robust, and reusable test cases.
+3. Attention1: Use '##' to split sections, not '#', and '## ' SHOULD BE WRITTEN BEFORE the test case or script.
+4. Attention2: If there are any settings in your tests, ALWAYS SET A DEFAULT VALUE, ALWAYS USE STRONG TYPE AND EXPLICIT VARIABLE.
+5. Attention3: YOU MUST FOLLOW "Data structures and interfaces". DO NOT CHANGE ANY DESIGN. Make sure your tests respect the existing design and ensure its validity.
+6. Think before writing: What should be tested and validated in this document? What edge cases could exist? What might fail?
+7. CAREFULLY CHECK THAT YOU DON'T MISS ANY NECESSARY TEST CASES/SCRIPTS IN THIS FILE.
+Attention: Use '##' to split sections, not '#', and '## ' SHOULD BE WRITTEN BEFORE the test case or script and triple quotes.
+----- +## Given the following code, please write appropriate test cases using Python's unittest framework to verify the correctness and robustness of this code: +```python +{code_to_test} +``` +Note that the code to test is at {source_file_path}, we will put your test code at {workspace}/tests/{test_file_name}, and run your test code from {workspace}, +you should correctly import the necessary classes based on these file locations! +## {test_file_name}: Write test code with triple quote. Do your best to implement THIS ONLY ONE FILE. +""" + + +class WriteTest(Action): + name: str = "WriteTest" + i_context: Optional[TestingContext] = None + + async def write_code(self, prompt): + code_rsp = await self._aask(prompt) + + try: + code = CodeParser.parse_code(text=code_rsp) + except Exception: + # Handle the exception if needed + logger.error(f"Can't parse the code: {code_rsp}") + + # Return code_rsp in case of an exception, assuming llm just returns code as it is and doesn't wrap it inside ``` + code = code_rsp + return code + + async def run(self, *args, **kwargs) -> TestingContext: + if not self.i_context.test_doc: + self.i_context.test_doc = Document( + filename="test_" + self.i_context.code_doc.filename, root_path=TEST_CODES_FILE_REPO + ) + fake_root = "/data" + prompt = PROMPT_TEMPLATE.format( + code_to_test=self.i_context.code_doc.content, + test_file_name=self.i_context.test_doc.filename, + source_file_path=fake_root + "/" + self.i_context.code_doc.root_relative_path, + workspace=fake_root, + ) + self.i_context.test_doc.content = await self.write_code(prompt) + return self.i_context diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py new file mode 100644 index 0000000000000000000000000000000000000000..184cd8573f9d5c3426c2302ddb34f85ade4ad85b --- /dev/null +++ b/metagpt/actions/write_tutorial.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# _*_ coding: utf-8 _*_ +""" +@Time : 2023/9/4 15:40:40 +@Author : Stitch-z +@File : tutorial_assistant.py +@Describe : Actions of the tutorial assistant, including writing directories and document content. +""" + +from typing import Dict + +from metagpt.actions import Action +from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT +from metagpt.utils.common import OutputParser + + +class WriteDirectory(Action): + """Action class for writing tutorial directories. + + Args: + name: The name of the action. + language: The language to output, default is "Chinese". + """ + + name: str = "WriteDirectory" + language: str = "Chinese" + + async def run(self, topic: str, *args, **kwargs) -> Dict: + """Execute the action to generate a tutorial directory according to the topic. + + Args: + topic: The tutorial topic. + + Returns: + the tutorial directory information, including {"title": "xxx", "directory": [{"dir 1": ["sub dir 1", "sub dir 2"]}]}. + """ + prompt = DIRECTORY_PROMPT.format(topic=topic, language=self.language) + resp = await self._aask(prompt=prompt) + return OutputParser.extract_struct(resp, dict) + + +class WriteContent(Action): + """Action class for writing tutorial content. + + Args: + name: The name of the action. + directory: The content to write. + language: The language to output, default is "Chinese". + """ + + name: str = "WriteContent" + directory: dict = dict() + language: str = "Chinese" + + async def run(self, topic: str, *args, **kwargs) -> str: + """Execute the action to write document content according to the directory and topic. + + Args: + topic: The tutorial topic. 
+ + Returns: + The written tutorial content. + """ + prompt = CONTENT_PROMPT.format(topic=topic, language=self.language, directory=self.directory) + return await self._aask(prompt=prompt) diff --git a/metagpt/base/__init__.py b/metagpt/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a2fbe8eaffd5c06032e04e80535d5ce6c83699d4 --- /dev/null +++ b/metagpt/base/__init__.py @@ -0,0 +1,8 @@ +from metagpt.base.base_env import BaseEnvironment +from metagpt.base.base_role import BaseRole + + +__all__ = [ + "BaseEnvironment", + "BaseRole", +] diff --git a/metagpt/base/__pycache__/__init__.cpython-310.pyc b/metagpt/base/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0e899fea937a946460d81347d96c872217cdf01 Binary files /dev/null and b/metagpt/base/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/base/__pycache__/__init__.cpython-39.pyc b/metagpt/base/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96e06598365b613598a078499a6720769a8c16ae Binary files /dev/null and b/metagpt/base/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/base/__pycache__/base_env.cpython-310.pyc b/metagpt/base/__pycache__/base_env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148699f6fcf706d4a50e93f278bd11b4b8217222 Binary files /dev/null and b/metagpt/base/__pycache__/base_env.cpython-310.pyc differ diff --git a/metagpt/base/__pycache__/base_env.cpython-39.pyc b/metagpt/base/__pycache__/base_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdbb4c9cc186b8cb5b04e666be2bba2902a94e6a Binary files /dev/null and b/metagpt/base/__pycache__/base_env.cpython-39.pyc differ diff --git a/metagpt/base/__pycache__/base_env_space.cpython-310.pyc b/metagpt/base/__pycache__/base_env_space.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7d8600892cc8d0605f7edcea4bb35401da575fd Binary files /dev/null and b/metagpt/base/__pycache__/base_env_space.cpython-310.pyc differ diff --git a/metagpt/base/__pycache__/base_env_space.cpython-39.pyc b/metagpt/base/__pycache__/base_env_space.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3307c62fe3051184cfdc1f504fbec684bd82b38 Binary files /dev/null and b/metagpt/base/__pycache__/base_env_space.cpython-39.pyc differ diff --git a/metagpt/base/__pycache__/base_role.cpython-310.pyc b/metagpt/base/__pycache__/base_role.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7f300b0882f8f3b3f36398a645f9ad0b01ae2b3 Binary files /dev/null and b/metagpt/base/__pycache__/base_role.cpython-310.pyc differ diff --git a/metagpt/base/__pycache__/base_role.cpython-39.pyc b/metagpt/base/__pycache__/base_role.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e38e0efb23423a70ab70f43297662c6e8a6afd5 Binary files /dev/null and b/metagpt/base/__pycache__/base_role.cpython-39.pyc differ diff --git a/metagpt/base/__pycache__/base_serialization.cpython-310.pyc b/metagpt/base/__pycache__/base_serialization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987cd12bf9f7ca8caeb662ee85727443d524e939 Binary files /dev/null and b/metagpt/base/__pycache__/base_serialization.cpython-310.pyc differ diff --git a/metagpt/base/__pycache__/base_serialization.cpython-39.pyc 
b/metagpt/base/__pycache__/base_serialization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec21649fd37d57e8733abcc43ab2c1a896e3ac31
Binary files /dev/null and b/metagpt/base/__pycache__/base_serialization.cpython-39.pyc differ
diff --git a/metagpt/base/base_env.py b/metagpt/base/base_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..361b8b58f24a4913e62cb582546a0455f5db664c
--- /dev/null
+++ b/metagpt/base/base_env.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : base environment
+
+import typing
+from abc import abstractmethod
+from typing import Any, Optional
+
+from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams
+from metagpt.base.base_serialization import BaseSerialization
+
+if typing.TYPE_CHECKING:
+    from metagpt.schema import Message
+
+
+class BaseEnvironment(BaseSerialization):
+    """Base environment"""
+
+    @abstractmethod
+    def reset(
+        self,
+        *,
+        seed: Optional[int] = None,
+        options: Optional[dict[str, Any]] = None,
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Implement this to get the initial observation"""
+
+    @abstractmethod
+    def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
+        """Implement this if you want to get a partial observation from the env"""
+
+    @abstractmethod
+    def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
+        """Implement this to feed an action and then get a new observation from the env"""
+
+    @abstractmethod
+    def publish_message(self, message: "Message", peekable: bool = True) -> bool:
+        """Distribute the message to the recipients."""
+
+    @abstractmethod
+    async def run(self, k=1):
+        """Process all tasks at once"""
diff --git a/metagpt/base/base_env_space.py b/metagpt/base/base_env_space.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd0cfa399f00298d904d88982ea56c1008b1d1b2
--- /dev/null
+++ b/metagpt/base/base_env_space.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   :
+
+from enum import IntEnum
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class BaseEnvActionType(IntEnum):
+    # # NONE = 0  # no action to run, just get observation
+    pass
+
+
+class BaseEnvAction(BaseModel):
+    """env action type and its related params of action functions/apis"""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    action_type: int = Field(default=0, description="action type")
+
+
+class BaseEnvObsType(IntEnum):
+    # # NONE = 0  # get whole observation from env
+    pass
+
+
+class BaseEnvObsParams(BaseModel):
+    """observation params for different EnvObsType to get its observe result"""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    obs_type: int = Field(default=0, description="observation type")
diff --git a/metagpt/base/base_role.py b/metagpt/base/base_role.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f7f00fa2350b273a82d0888f529fd5189865d4c
--- /dev/null
+++ b/metagpt/base/base_role.py
@@ -0,0 +1,36 @@
+from abc import abstractmethod
+from typing import Optional, Union
+
+from metagpt.base.base_serialization import BaseSerialization
+
+
+class BaseRole(BaseSerialization):
+    """Abstract base class for all roles."""
+
+    name: str
+
+    @property
+    def is_idle(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def think(self):
+        """Consider what to do and decide on the next course of action."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def act(self):
+        """Perform the current action."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def react(self) -> "Message":
+        """Entry to one of three strategies by which a Role reacts to the observed Message."""
+
+    @abstractmethod
+    async def run(self, with_message: Optional[Union[str, "Message", list[str]]] = None) -> Optional["Message"]:
+        """Observe, then think and act based on the results of the observation."""
+
+    @abstractmethod
+    def get_memories(self, k: int = 0) -> list["Message"]:
+        """Return the most recent k memories of this role."""
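Note: base_serialization.py (next) is what lets these abstract bases round-trip through JSON without losing the concrete subclass. A minimal sketch of the intended behavior, with hypothetical subclasses for illustration:

    from metagpt.base.base_serialization import BaseSerialization

    class Animal(BaseSerialization, is_polymorphic_base=True):  # hypothetical base
        name: str

    class Cat(Animal):  # hypothetical concrete subclass
        meows: bool = True

    data = Cat(name="kitty").model_dump()   # serializer appends "__module_class_name"
    restored = Animal.model_validate(data)  # validator dispatches back to Cat
    assert type(restored) is Cat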
diff --git a/metagpt/base/base_serialization.py b/metagpt/base/base_serialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aff7f39e33145594d389b6a71d2faee03fd24f2
--- /dev/null
+++ b/metagpt/base/base_serialization.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pydantic import BaseModel, model_serializer, model_validator
+
+
+class BaseSerialization(BaseModel, extra="forbid"):
+    """
+    Polymorphic-subclass serialization / deserialization mixin.
+    - First of all, we need to know that pydantic is not designed for polymorphism.
+    - If Engineer is a subclass of Role, it would be serialized as Role. If we want to serialize it as Engineer, we
+      need to add the class name to Engineer, so Engineer needs to inherit this SerializationMixin.
+
+    More details:
+    - https://docs.pydantic.dev/latest/concepts/serialization/
+    - https://github.com/pydantic/pydantic/discussions/7008 discussion about avoiding `__get_pydantic_core_schema__`
+    """
+
+    __is_polymorphic_base = False
+    __subclasses_map__ = {}
+
+    @model_serializer(mode="wrap")
+    def __serialize_with_class_type__(self, default_serializer) -> Any:
+        # run the default serializer, then append the `__module_class_name` field and return
+        ret = default_serializer(self)
+        ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
+        return ret
+
+    @model_validator(mode="wrap")
+    @classmethod
+    def __convert_to_real_type__(cls, value: Any, handler):
+        if not isinstance(value, dict):
+            return handler(value)
+
+        # it is a dict, so make sure to remove the __module_class_name,
+        # because we don't allow extra keywords but want to ensure
+        # e.g. Cat.model_validate(cat.model_dump()) works
+        class_full_name = value.pop("__module_class_name", None)
+
+        # if it's not the polymorphic base, we construct via the default handler
+        if not cls.__is_polymorphic_base:
+            if class_full_name is None:
+                return handler(value)
+            elif str(cls) == f"<class '{class_full_name}'>":
+                return handler(value)
+            else:
+                # raise ValueError(
+                #     f"Trying to instantiate {class_full_name} but this is not the polymorphic base class"
+                # )
+                pass
+
+        # otherwise we look up the correct polymorphic type and construct that instead
+        if class_full_name is None:
+            raise ValueError("Missing __module_class_name field")
+
+        class_type = cls.__subclasses_map__.get(class_full_name, None)
+
+        if class_type is None:
+            # TODO could try dynamic import
+            raise TypeError(f"Trying to instantiate {class_full_name}, which has not yet been defined!")
+
+        return class_type(**value)
+
+    def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
+        cls.__is_polymorphic_base = is_polymorphic_base
+        cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
+        super().__init_subclass__(**kwargs)
diff --git a/metagpt/config2.py b/metagpt/config2.py
new file mode 100644
index 0000000000000000000000000000000000000000..02039f737991df6d9c1edc19671fbb2d101ee441
--- /dev/null
+++ b/metagpt/config2.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+# -*-
coding: utf-8 -*- +""" +@Time : 2024/1/4 01:25 +@Author : alexanderwu +@File : config2.py +""" +import os +from pathlib import Path +from typing import Dict, Iterable, List, Literal, Optional + +from pydantic import BaseModel, Field, model_validator + +from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.embedding_config import EmbeddingConfig +from metagpt.configs.exp_pool_config import ExperiencePoolConfig +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.configs.mermaid_config import MermaidConfig +from metagpt.configs.omniparse_config import OmniParseConfig +from metagpt.configs.redis_config import RedisConfig +from metagpt.configs.role_custom_config import RoleCustomConfig +from metagpt.configs.role_zero_config import RoleZeroConfig +from metagpt.configs.s3_config import S3Config +from metagpt.configs.search_config import SearchConfig +from metagpt.configs.workspace_config import WorkspaceConfig +from metagpt.const import CONFIG_ROOT, METAGPT_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class CLIParams(BaseModel): + """CLI parameters""" + + project_path: str = "" + project_name: str = "" + inc: bool = False + reqa_file: str = "" + max_auto_summarize_code: int = 0 + git_reinit: bool = False + + @model_validator(mode="after") + def check_project_path(self): + """Check project_path and project_name""" + if self.project_path: + self.inc = True + self.project_name = self.project_name or Path(self.project_path).name + return self + + +class Config(CLIParams, YamlModel): + """Configurations for MetaGPT""" + + # Key Parameters + llm: LLMConfig + + # RAG Embedding + embedding: EmbeddingConfig = EmbeddingConfig() + + # omniparse + omniparse: OmniParseConfig = OmniParseConfig() + + # Global Proxy. 
Will be used if llm.proxy is not set
+    proxy: str = ""
+
+    # Tool Parameters
+    search: SearchConfig = SearchConfig()
+    enable_search: bool = False
+    browser: BrowserConfig = BrowserConfig()
+    mermaid: MermaidConfig = MermaidConfig()
+
+    # Storage Parameters
+    s3: Optional[S3Config] = None
+    redis: Optional[RedisConfig] = None
+
+    # Misc Parameters
+    repair_llm_output: bool = False
+    prompt_schema: Literal["json", "markdown", "raw"] = "json"
+    workspace: WorkspaceConfig = Field(default_factory=WorkspaceConfig)
+    enable_longterm_memory: bool = False
+    code_validate_k_times: int = 2
+
+    # Experience Pool Parameters
+    exp_pool: ExperiencePoolConfig = Field(default_factory=ExperiencePoolConfig)
+
+    # Will be removed in the future
+    metagpt_tti_url: str = ""
+    language: str = "English"
+    redis_key: str = "placeholder"
+    iflytek_app_id: str = ""
+    iflytek_api_secret: str = ""
+    iflytek_api_key: str = ""
+    azure_tts_subscription_key: str = ""
+    azure_tts_region: str = ""
+    _extra: dict = dict()  # extra config dict
+
+    # Role's custom configuration
+    roles: Optional[List[RoleCustomConfig]] = None
+
+    # RoleZero's configuration
+    role_zero: RoleZeroConfig = Field(default_factory=RoleZeroConfig)
+
+    @classmethod
+    def from_home(cls, path):
+        """Load config from ~/.metagpt/config2.yaml"""
+        pathname = CONFIG_ROOT / path
+        if not pathname.exists():
+            return None
+        return Config.from_yaml_file(pathname)
+
+    @classmethod
+    def default(cls, reload: bool = False, **kwargs) -> "Config":
+        """Load the default config
+        - Priority: env < default_config_paths
+        - Inside default_config_paths, the latter one overwrites the former one
+        """
+        default_config_paths = (
+            METAGPT_ROOT / "config/config2.yaml",
+            CONFIG_ROOT / "config2.yaml",
+        )
+        if reload or default_config_paths not in _CONFIG_CACHE:
+            dicts = [dict(os.environ), *(Config.read_yaml(path) for path in default_config_paths), kwargs]
+            final = merge_dict(dicts)
+            _CONFIG_CACHE[default_config_paths] = Config(**final)
+        return _CONFIG_CACHE[default_config_paths]
+
+    @classmethod
+    def from_llm_config(cls, llm_config: dict):
+        """Build a Config from a user-provided llm config dict.
+        example:
+        llm_config = {"api_type": "xxx", "api_key": "xxx", "model": "xxx"}
+        gpt4 = Config.from_llm_config(llm_config)
+        A = Role(name="A", profile="Democratic candidate", goal="Win the election", actions=[a1], watch=[a2], config=gpt4)
+        """
+        llm_config = LLMConfig.model_validate(llm_config)
+        dicts = [dict(os.environ)]
+        dicts += [{"llm": llm_config}]
+        final = merge_dict(dicts)
+        return Config(**final)
+
+    def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
+        """Update the config via CLI parameters"""
+
+        # Use in the PrepareDocuments action according to Section 2.2.3.5.1 of RFC 135.
+        if project_path:
+            inc = True
+            project_name = project_name or Path(project_path).name
+        self.project_path = project_path
+        self.project_name = project_name
+        self.inc = inc
+        self.reqa_file = reqa_file
+        self.max_auto_summarize_code = max_auto_summarize_code
+
+    @property
+    def extra(self):
+        return self._extra
+
+    @extra.setter
+    def extra(self, value: dict):
+        self._extra = value
+
+    def get_openai_llm(self) -> Optional[LLMConfig]:
+        """Return the LLMConfig if the configured api_type is OpenAI; otherwise return None"""
+        if self.llm.api_type == LLMType.OPENAI:
+            return self.llm
+        return None
+
+    def get_azure_llm(self) -> Optional[LLMConfig]:
+        """Return the LLMConfig if the configured api_type is Azure; otherwise return None"""
+        if self.llm.api_type == LLMType.AZURE:
+            return self.llm
+        return None
+
+
+def merge_dict(dicts: Iterable[Dict]) -> Dict:
+    """Merge multiple dicts into one, with the latter dict overwriting the former"""
+    result = {}
+    for dictionary in dicts:
+        result.update(dictionary)
+    return result
+
+
+_CONFIG_CACHE = {}
+config = Config.default()
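Note: a short usage sketch of the pieces above (hypothetical values; the dummy api_key exists only to satisfy the validator):

    from metagpt.config2 import Config, merge_dict

    # merge_dict gives later dicts priority, which is how default() layers
    # env vars < repo config2.yaml < ~/.metagpt/config2.yaml < kwargs:
    assert merge_dict([{"llm": "a", "x": 1}, {"llm": "b"}]) == {"llm": "b", "x": 1}

    # Per-object LLM override, as in from_llm_config's docstring:
    gpt4 = Config.from_llm_config({"api_type": "openai", "api_key": "sk-example", "model": "gpt-4-turbo"})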
diff --git a/metagpt/configs/__init__.py b/metagpt/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42e6788f240b7df0abbf07410554d66641313ba
--- /dev/null
+++ b/metagpt/configs/__init__.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2024/1/4 16:33
+@Author  : alexanderwu
+@File    : __init__.py
+"""
diff --git a/metagpt/configs/__pycache__/__init__.cpython-310.pyc b/metagpt/configs/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4e8391aa7b3be8fef60d0e8e6a880ddd69c8c0a
Binary files /dev/null and b/metagpt/configs/__pycache__/__init__.cpython-310.pyc differ
diff --git a/metagpt/configs/__pycache__/__init__.cpython-39.pyc b/metagpt/configs/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9ee6626414691dde671399fa321c7f7262c6cc1
Binary files /dev/null and b/metagpt/configs/__pycache__/__init__.cpython-39.pyc differ
diff --git a/metagpt/configs/__pycache__/browser_config.cpython-310.pyc b/metagpt/configs/__pycache__/browser_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8592661164bdc20a46aa40de28dd18d5c594b7a
Binary files /dev/null and b/metagpt/configs/__pycache__/browser_config.cpython-310.pyc differ
diff --git a/metagpt/configs/__pycache__/browser_config.cpython-39.pyc b/metagpt/configs/__pycache__/browser_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c851f9683c414bc6b7577449b201f2c7b8bfb88
Binary files /dev/null and b/metagpt/configs/__pycache__/browser_config.cpython-39.pyc differ
diff --git a/metagpt/configs/__pycache__/compress_msg_config.cpython-310.pyc b/metagpt/configs/__pycache__/compress_msg_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd6303d80bf2ae96ecee472b944c24bc786631c3
Binary files /dev/null and b/metagpt/configs/__pycache__/compress_msg_config.cpython-310.pyc differ
diff --git a/metagpt/configs/__pycache__/compress_msg_config.cpython-39.pyc b/metagpt/configs/__pycache__/compress_msg_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f5065e09cbb6681b987b549af8b00e8ddad6f90
Binary files /dev/null and b/metagpt/configs/__pycache__/compress_msg_config.cpython-39.pyc differ
diff --git a/metagpt/configs/__pycache__/embedding_config.cpython-310.pyc b/metagpt/configs/__pycache__/embedding_config.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b01b7ea71fbf4f5f70577214dd062a81fe1a4967
Binary files /dev/null and b/metagpt/configs/__pycache__/embedding_config.cpython-310.pyc differ
diff --git a/metagpt/configs/__pycache__/embedding_config.cpython-39.pyc b/metagpt/configs/__pycache__/embedding_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db7c20935206fbc0cc603f024b947a6b14be4ff3
Binary files /dev/null and b/metagpt/configs/__pycache__/embedding_config.cpython-39.pyc differ
diff --git a/metagpt/configs/__pycache__/exp_pool_config.cpython-310.pyc
b/metagpt/configs/__pycache__/exp_pool_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e244d198625feb5601bf1809177883f6959a800 Binary files /dev/null and b/metagpt/configs/__pycache__/exp_pool_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/exp_pool_config.cpython-39.pyc b/metagpt/configs/__pycache__/exp_pool_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef89c5a30943da03ae5192fdad390e6b48f05acf Binary files /dev/null and b/metagpt/configs/__pycache__/exp_pool_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/llm_config.cpython-310.pyc b/metagpt/configs/__pycache__/llm_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88e0d3bcc5d0f4aa04dd6b95306664cd851508c0 Binary files /dev/null and b/metagpt/configs/__pycache__/llm_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/llm_config.cpython-39.pyc b/metagpt/configs/__pycache__/llm_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7494da54a61ae69c82b87198dee43c8ec8df26 Binary files /dev/null and b/metagpt/configs/__pycache__/llm_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/mermaid_config.cpython-310.pyc b/metagpt/configs/__pycache__/mermaid_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ff633242b29eb14896048df121ea71795e3f02d Binary files /dev/null and b/metagpt/configs/__pycache__/mermaid_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/mermaid_config.cpython-39.pyc b/metagpt/configs/__pycache__/mermaid_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b9dd0d2f24f4f5e6512d0acede7c5a84b8127be Binary files /dev/null and b/metagpt/configs/__pycache__/mermaid_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/models_config.cpython-310.pyc b/metagpt/configs/__pycache__/models_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3272893cb024252e27542f39aa00a61d8b05924f Binary files /dev/null and b/metagpt/configs/__pycache__/models_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/models_config.cpython-39.pyc b/metagpt/configs/__pycache__/models_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f79407bbad979f835d979426437c76c8d6550f91 Binary files /dev/null and b/metagpt/configs/__pycache__/models_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/omniparse_config.cpython-310.pyc b/metagpt/configs/__pycache__/omniparse_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b04008dcb1ebbfa974c55af634c1bdc6bc4e39d5 Binary files /dev/null and b/metagpt/configs/__pycache__/omniparse_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/omniparse_config.cpython-39.pyc b/metagpt/configs/__pycache__/omniparse_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59611951ddad32af1f976208d230868bf19c29a0 Binary files /dev/null and b/metagpt/configs/__pycache__/omniparse_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/redis_config.cpython-310.pyc b/metagpt/configs/__pycache__/redis_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2af9d502448afb01feda95dd422ff80f4f475db7 Binary files /dev/null and 
b/metagpt/configs/__pycache__/redis_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/redis_config.cpython-39.pyc b/metagpt/configs/__pycache__/redis_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..533b07a7d38cab0dc9f00ce5511be341d80ecf5d Binary files /dev/null and b/metagpt/configs/__pycache__/redis_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/role_custom_config.cpython-310.pyc b/metagpt/configs/__pycache__/role_custom_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f81c2d2d24d7e619e711dc1e3ad9d0b8952e8c92 Binary files /dev/null and b/metagpt/configs/__pycache__/role_custom_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/role_custom_config.cpython-39.pyc b/metagpt/configs/__pycache__/role_custom_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..946bae0d9ef0fc393de5b3e07e38a9383ceabdf7 Binary files /dev/null and b/metagpt/configs/__pycache__/role_custom_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/role_zero_config.cpython-310.pyc b/metagpt/configs/__pycache__/role_zero_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01a7552fd958f6dff63b51a6dcf32ed1184a6e09 Binary files /dev/null and b/metagpt/configs/__pycache__/role_zero_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/role_zero_config.cpython-39.pyc b/metagpt/configs/__pycache__/role_zero_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e430dec62bd85cbf1a660983f3f46e81ee85585d Binary files /dev/null and b/metagpt/configs/__pycache__/role_zero_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/s3_config.cpython-310.pyc b/metagpt/configs/__pycache__/s3_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c648a76e2cb42313221e8e8d050ded0de74405d Binary files /dev/null and b/metagpt/configs/__pycache__/s3_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/s3_config.cpython-39.pyc b/metagpt/configs/__pycache__/s3_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743919bc288f4dff98bf2d27219f8b0791226a1d Binary files /dev/null and b/metagpt/configs/__pycache__/s3_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/search_config.cpython-310.pyc b/metagpt/configs/__pycache__/search_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5f239f4acda08cb42e885cc8ffc8f115a8cd3a6 Binary files /dev/null and b/metagpt/configs/__pycache__/search_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/search_config.cpython-39.pyc b/metagpt/configs/__pycache__/search_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1adf978ed33dac9ac9490f24ddaeaa8a7f79ee4 Binary files /dev/null and b/metagpt/configs/__pycache__/search_config.cpython-39.pyc differ diff --git a/metagpt/configs/__pycache__/workspace_config.cpython-310.pyc b/metagpt/configs/__pycache__/workspace_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78177ae0489610b1eea571e65ecaebc0973af719 Binary files /dev/null and b/metagpt/configs/__pycache__/workspace_config.cpython-310.pyc differ diff --git a/metagpt/configs/__pycache__/workspace_config.cpython-39.pyc 
b/metagpt/configs/__pycache__/workspace_config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79e0e6493d174208e1f5207dadc5f9451fecfb82
Binary files /dev/null and b/metagpt/configs/__pycache__/workspace_config.cpython-39.pyc differ
diff --git a/metagpt/configs/browser_config.py b/metagpt/configs/browser_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fafbaeeb85d57029a8296a2e58deeab761835888
--- /dev/null
+++ b/metagpt/configs/browser_config.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2024/1/4 19:06
+@Author  : alexanderwu
+@File    : browser_config.py
+"""
+from enum import Enum
+from typing import Literal
+
+from metagpt.utils.yaml_model import YamlModel
+
+
+class WebBrowserEngineType(Enum):
+    PLAYWRIGHT = "playwright"
+    SELENIUM = "selenium"
+    CUSTOM = "custom"
+
+    @classmethod
+    def _missing_(cls, value):
+        # Default type conversion; Enum's fallback hook is `_missing_`, so unknown engines map to CUSTOM
+        return cls.CUSTOM
+
+
+class BrowserConfig(YamlModel):
+    """Config for Browser"""
+
+    engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
+    browser_type: Literal["chromium", "firefox", "webkit", "chrome", "edge", "ie"] = "chromium"
+    """If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium,
+    the value should be one of "chrome", "firefox", "edge", or "ie"."""
diff --git a/metagpt/configs/compress_msg_config.py b/metagpt/configs/compress_msg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46334c1257006ae769cd8505da1969094b93557
--- /dev/null
+++ b/metagpt/configs/compress_msg_config.py
@@ -0,0 +1,32 @@
+from enum import Enum
+
+
+class CompressType(Enum):
+    """
+    Compression type for messages, used to keep requests under the token limit.
+    - "": No compression. Default value.
+    - "post_cut_by_msg": Keep as many of the latest messages as possible.
+    - "post_cut_by_token": Keep as many of the latest messages as possible and truncate the earliest fit-in message.
+    - "pre_cut_by_msg": Keep as many of the earliest messages as possible.
+    - "pre_cut_by_token": Keep as many of the earliest messages as possible and truncate the latest fit-in message.
+    """
+
+    NO_COMPRESS = ""
+    POST_CUT_BY_MSG = "post_cut_by_msg"
+    POST_CUT_BY_TOKEN = "post_cut_by_token"
+    PRE_CUT_BY_MSG = "pre_cut_by_msg"
+    PRE_CUT_BY_TOKEN = "pre_cut_by_token"
+
+    @classmethod
+    def _missing_(cls, value):
+        # Enum's fallback hook is `_missing_`; unknown values map to NO_COMPRESS
+        return cls.NO_COMPRESS
+
+    @classmethod
+    def get_type(cls, type_name):
+        for member in cls:
+            if member.value == type_name:
+                return member
+        return cls.NO_COMPRESS
+
+    @classmethod
+    def cut_types(cls):
+        return [member for member in cls if "cut" in member.value]
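Note: a quick sketch of the fallback behavior these enums rely on (with `_missing_` in place, unknown values degrade gracefully instead of raising):

    from metagpt.configs.compress_msg_config import CompressType

    assert CompressType("post_cut_by_msg") is CompressType.POST_CUT_BY_MSG
    assert CompressType("no_such_mode") is CompressType.NO_COMPRESS     # via _missing_
    assert CompressType.get_type("bogus") is CompressType.NO_COMPRESS   # explicit helper
    assert CompressType.NO_COMPRESS not in CompressType.cut_types()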
diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b41b9dc9afcee0bf34448f7263b7f640d8ed0e
--- /dev/null
+++ b/metagpt/configs/embedding_config.py
@@ -0,0 +1,54 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic import field_validator
+
+from metagpt.utils.yaml_model import YamlModel
+
+
+class EmbeddingType(Enum):
+    OPENAI = "openai"
+    AZURE = "azure"
+    GEMINI = "gemini"
+    OLLAMA = "ollama"
+
+
+class EmbeddingConfig(YamlModel):
+    """Config for Embedding.
+
+    Examples:
+    ---------
+    api_type: "openai"
+    api_key: "YOUR_API_KEY"
+    dimensions: "YOUR_MODEL_DIMENSIONS"
+
+    api_type: "azure"
+    api_key: "YOUR_API_KEY"
+    base_url: "YOUR_BASE_URL"
+    api_version: "YOUR_API_VERSION"
+    dimensions: "YOUR_MODEL_DIMENSIONS"
+
+    api_type: "gemini"
+    api_key: "YOUR_API_KEY"
+
+    api_type: "ollama"
+    base_url: "YOUR_BASE_URL"
+    model: "YOUR_MODEL"
+    dimensions: "YOUR_MODEL_DIMENSIONS"
+    """
+
+    api_type: Optional[EmbeddingType] = None
+    api_key: Optional[str] = None
+    base_url: Optional[str] = None
+    api_version: Optional[str] = None
+
+    model: Optional[str] = None
+    embed_batch_size: Optional[int] = None
+    dimensions: Optional[int] = None  # output dimension of the embedding model
+
+    @field_validator("api_type", mode="before")
+    @classmethod
+    def check_api_type(cls, v):
+        if v == "":
+            return None
+        return v
diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4a2d5d41788c1e567aedac32f720444fa0691ea
--- /dev/null
+++ b/metagpt/configs/exp_pool_config.py
@@ -0,0 +1,25 @@
+from enum import Enum
+
+from pydantic import Field
+
+from metagpt.utils.yaml_model import YamlModel
+
+
+class ExperiencePoolRetrievalType(Enum):
+    BM25 = "bm25"
+    CHROMA = "chroma"
+
+
+class ExperiencePoolConfig(YamlModel):
+    enabled: bool = Field(
+        default=False,
+        description="Flag to enable or disable the experience pool. When disabled, both reading and writing are ineffective.",
+    )
+    enable_read: bool = Field(default=False, description="Enable reading from the experience pool.")
+    enable_write: bool = Field(default=False, description="Enable writing to the experience pool.")
+    persist_path: str = Field(default=".chroma_exp_data", description="The persist path for the experience pool.")
+    retrieval_type: ExperiencePoolRetrievalType = Field(
+        default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for the experience pool."
+    )
+    use_llm_ranker: bool = Field(default=True, description="Use the LLM reranker to get better results.")
+    collection_name: str = Field(default="experience_pool", description="The collection name in chromadb.")
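Note: a minimal sketch of enabling the pool programmatically (hypothetical values; the same fields map 1:1 onto the `exp_pool` section of config2.yaml):

    from metagpt.configs.exp_pool_config import ExperiencePoolConfig, ExperiencePoolRetrievalType

    exp_pool = ExperiencePoolConfig(
        enabled=True,  # master switch; reads/writes are no-ops without it
        enable_read=True,
        enable_write=True,
        retrieval_type=ExperiencePoolRetrievalType.CHROMA,  # vector retrieval instead of BM25
    )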
diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f17904d8fc173c50715331c90cda254d4c4da772
--- /dev/null
+++ b/metagpt/configs/llm_config.py
@@ -0,0 +1,136 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2024/1/4 16:33
+@Author  : alexanderwu
+@File    : llm_config.py
+"""
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import field_validator
+
+from metagpt.configs.compress_msg_config import CompressType
+from metagpt.const import CONFIG_ROOT, LLM_API_TIMEOUT, METAGPT_ROOT
+from metagpt.utils.yaml_model import YamlModel
+
+
+class LLMType(Enum):
+    OPENAI = "openai"
+    ANTHROPIC = "anthropic"
+    CLAUDE = "claude"  # alias name of anthropic
+    SPARK = "spark"
+    ZHIPUAI = "zhipuai"
+    FIREWORKS = "fireworks"
+    OPEN_LLM = "open_llm"
+    GEMINI = "gemini"
+    METAGPT = "metagpt"
+    AZURE = "azure"
+    OLLAMA = "ollama"  # /chat at ollama api
+    OLLAMA_GENERATE = "ollama.generate"  # /generate at ollama api
+    OLLAMA_EMBEDDINGS = "ollama.embeddings"  # /embeddings at ollama api
+    OLLAMA_EMBED = "ollama.embed"  # /embed at ollama api
+    QIANFAN = "qianfan"  # Baidu BCE
+    DASHSCOPE = "dashscope"  # Aliyun LingJi DashScope
+    MOONSHOT = "moonshot"
+    MISTRAL = "mistral"
+    YI = "yi"  # lingyiwanwu
+    OPEN_ROUTER = "open_router"
+    DEEPSEEK = "deepseek"
+    SILICONFLOW = "siliconflow"
+    OPENROUTER = "openrouter"
+    OPENROUTER_REASONING = "openrouter_reasoning"
+    BEDROCK = "bedrock"
+    ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk
+    LLAMA_API = "llama_api"
+
+    @classmethod
+    def _missing_(cls, value):
+        # Enum's fallback hook is `_missing_`; unknown api_type values default to OPENAI
+        return cls.OPENAI
+
+
+class LLMConfig(YamlModel):
+    """Config for LLM
+
+    OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/resources/chat/completions.py#L681
+    Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
+    """
+
+    api_key: str = "sk-"
+    api_type: LLMType = LLMType.OPENAI
+    base_url: str = "https://api.openai.com/v1"
+    api_version: Optional[str] = None
+
+    model: Optional[str] = None  # also stands for DEPLOYMENT_NAME
+    pricing_plan: Optional[str] = None  # Cost Settlement Plan Parameters.
+
+    # For Cloud Service Providers like Baidu / Alibaba
+    access_key: Optional[str] = None
+    secret_key: Optional[str] = None
+    session_token: Optional[str] = None
+    endpoint: Optional[str] = None  # for self-deployed model on the cloud
+
+    # For Spark (Xunfei), maybe remove later
+    app_id: Optional[str] = None
+    api_secret: Optional[str] = None
+    domain: Optional[str] = None
+
+    # For Chat Completion
+    max_token: int = 4096
+    temperature: float = 0.0
+    top_p: float = 1.0
+    top_k: int = 0
+    repetition_penalty: float = 1.0
+    stop: Optional[str] = None
+    presence_penalty: float = 0.0
+    frequency_penalty: float = 0.0
+    best_of: Optional[int] = None
+    n: Optional[int] = None
+    stream: bool = True
+    seed: Optional[int] = None
+    # https://cookbook.openai.com/examples/using_logprobs
+    logprobs: Optional[bool] = None
+    top_logprobs: Optional[int] = None
+    timeout: int = 600
+    context_length: Optional[int] = None  # Max input tokens
+
+    # For Amazon Bedrock
+    region_name: Optional[str] = None
+
+    # For Network
+    proxy: Optional[str] = None
+
+    # Cost Control
+    calc_usage: bool = True
+
+    # Compress request messages under the token limit
+    compress_type: CompressType = CompressType.NO_COMPRESS
+
+    # For Messages Control
+    use_system_prompt: bool = True
+
+    # reasoning / thinking switch
+    reasoning: bool = False
+    reasoning_max_token: int = 4000  # reasoning budget tokens to generate, usually smaller than max_token
+
+    @field_validator("api_key")
+    @classmethod
+    def check_llm_key(cls, v):
+        if v in ["", None, "YOUR_API_KEY"]:
+            repo_config_path = METAGPT_ROOT / "config/config2.yaml"
+            root_config_path = CONFIG_ROOT / "config2.yaml"
+            if root_config_path.exists():
+                raise ValueError(
+                    f"Please set your API key in {root_config_path}. If you also set your config in {repo_config_path},\n"
+                    f"the former will overwrite the latter. This may cause unexpected results.\n"
+                )
+            elif repo_config_path.exists():
+                raise ValueError(f"Please set your API key in {repo_config_path}")
+            else:
+                raise ValueError("Please set your API key in config2.yaml")
+        return v
+
+    @field_validator("timeout")
+    @classmethod
+    def check_timeout(cls, v):
+        return v or LLM_API_TIMEOUT
diff --git a/metagpt/configs/mermaid_config.py b/metagpt/configs/mermaid_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..47f14f4cd0953d4649ccfb66275e9380b2095645
--- /dev/null
+++ b/metagpt/configs/mermaid_config.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2024/1/4 19:07
+@Author  : alexanderwu
+@File    : mermaid_config.py
+"""
+from typing import Literal
+
+from metagpt.utils.yaml_model import YamlModel
+
+
+class MermaidConfig(YamlModel):
+    """Config for Mermaid"""
+
+    engine: Literal["nodejs", "ink", "playwright", "pyppeteer", "none"] = "nodejs"
+    path: str = "mmdc"  # mmdc
+    puppeteer_config: str = ""
+    pyppeteer_path: str = "/usr/bin/google-chrome-stable"
diff --git a/metagpt/configs/models_config.py b/metagpt/configs/models_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc4897fec52e85763f5d6ea09aed2f015c713d36
--- /dev/null
+++ b/metagpt/configs/models_config.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+models_config.py
+
+This module defines the ModelsConfig class for handling configuration of LLM models.
+
+Attributes:
+    CONFIG_ROOT (Path): Root path for configuration files.
+    METAGPT_ROOT (Path): Root path for MetaGPT files.
+
+Classes:
+    ModelsConfig (YamlModel): Configuration class for LLM models.
+""" +from pathlib import Path +from typing import Dict, List, Optional + +from pydantic import Field, field_validator + +from metagpt.config2 import merge_dict +from metagpt.configs.llm_config import LLMConfig +from metagpt.const import CONFIG_ROOT, METAGPT_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class ModelsConfig(YamlModel): + """ + Configuration class for `models` in `config2.yaml`. + + Attributes: + models (Dict[str, LLMConfig]): Dictionary mapping model names or types to LLMConfig objects. + + Methods: + update_llm_model(cls, value): Validates and updates LLM model configurations. + from_home(cls, path): Loads configuration from ~/.metagpt/config2.yaml. + default(cls): Loads default configuration from predefined paths. + get(self, name_or_type: str) -> Optional[LLMConfig]: Retrieves LLMConfig by name or API type. + """ + + models: Dict[str, LLMConfig] = Field(default_factory=dict) + + @field_validator("models", mode="before") + @classmethod + def update_llm_model(cls, value): + """ + Validates and updates LLM model configurations. + + Args: + value (Dict[str, Union[LLMConfig, dict]]): Dictionary of LLM configurations. + + Returns: + Dict[str, Union[LLMConfig, dict]]: Updated dictionary of LLM configurations. + """ + for key, config in value.items(): + if isinstance(config, LLMConfig): + config.model = config.model or key + elif isinstance(config, dict): + config["model"] = config.get("model") or key + return value + + @classmethod + def from_home(cls, path): + """ + Loads configuration from ~/.metagpt/config2.yaml. + + Args: + path (str): Relative path to configuration file. + + Returns: + Optional[ModelsConfig]: Loaded ModelsConfig object or None if file doesn't exist. + """ + pathname = CONFIG_ROOT / path + if not pathname.exists(): + return None + return ModelsConfig.from_yaml_file(pathname) + + @classmethod + def default(cls): + """ + Loads default configuration from predefined paths. + + Returns: + ModelsConfig: Default ModelsConfig object. + """ + default_config_paths: List[Path] = [ + METAGPT_ROOT / "config/config2.yaml", + CONFIG_ROOT / "config2.yaml", + ] + + dicts = [ModelsConfig.read_yaml(path) for path in default_config_paths] + final = merge_dict(dicts) + return ModelsConfig(**final) + + def get(self, name_or_type: str) -> Optional[LLMConfig]: + """ + Retrieves LLMConfig object by name or API type. + + Args: + name_or_type (str): Name or API type of the LLM model. + + Returns: + Optional[LLMConfig]: LLMConfig object if found, otherwise None. 
+ """ + if not name_or_type: + return None + model = self.models.get(name_or_type) + if model: + return model + for m in self.models.values(): + if m.api_type == name_or_type: + return m + return None diff --git a/metagpt/configs/omniparse_config.py b/metagpt/configs/omniparse_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8f38f9f518982c2d82131b2257acc52e97668ace --- /dev/null +++ b/metagpt/configs/omniparse_config.py @@ -0,0 +1,7 @@ +from metagpt.utils.yaml_model import YamlModel + + +class OmniParseConfig(YamlModel): + api_key: str = "" + base_url: str = "" + timeout: int = 600 diff --git a/metagpt/configs/redis_config.py b/metagpt/configs/redis_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c4cfb6764dd1006f17a64e9863da43e6832d8ef3 --- /dev/null +++ b/metagpt/configs/redis_config.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:06 +@Author : alexanderwu +@File : redis_config.py +""" +from metagpt.utils.yaml_model import YamlModelWithoutDefault + + +class RedisConfig(YamlModelWithoutDefault): + host: str + port: int + username: str = "" + password: str + db: str + + def to_url(self): + return f"redis://{self.host}:{self.port}" + + def to_kwargs(self): + return { + "username": self.username, + "password": self.password, + "db": self.db, + } diff --git a/metagpt/configs/role_custom_config.py b/metagpt/configs/role_custom_config.py new file mode 100644 index 0000000000000000000000000000000000000000..581de605e62d2fedcdd9718238f698d69c6bedbd --- /dev/null +++ b/metagpt/configs/role_custom_config.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/4/22 16:33 +@Author : Justin +@File : role_custom_config.py +""" +from metagpt.configs.llm_config import LLMConfig +from metagpt.utils.yaml_model import YamlModel + + +class RoleCustomConfig(YamlModel): + """custom config for roles + role: role's className or role's role_id + To be expanded + """ + + role: str = "" + llm: LLMConfig diff --git a/metagpt/configs/role_zero_config.py b/metagpt/configs/role_zero_config.py new file mode 100644 index 0000000000000000000000000000000000000000..91d554b2f46c270a5abf7224295d65bc096f90e1 --- /dev/null +++ b/metagpt/configs/role_zero_config.py @@ -0,0 +1,11 @@ +from pydantic import Field + +from metagpt.utils.yaml_model import YamlModel + + +class RoleZeroConfig(YamlModel): + enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.") + longterm_memory_persist_path: str = Field(default=".role_memory_data", description="The directory to save data.") + memory_k: int = Field(default=200, description="The capacity of short-term memory.") + similarity_top_k: int = Field(default=5, description="The number of long-term memories to retrieve.") + use_llm_ranker: bool = Field(default=False, description="Whether to use LLM Reranker to get better result.") diff --git a/metagpt/configs/s3_config.py b/metagpt/configs/s3_config.py new file mode 100644 index 0000000000000000000000000000000000000000..72b81fae46c73e0e352b61cdd0d1f7991acea9f6 --- /dev/null +++ b/metagpt/configs/s3_config.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:07 +@Author : alexanderwu +@File : s3_config.py +""" +from metagpt.utils.yaml_model import YamlModelWithoutDefault + + +class S3Config(YamlModelWithoutDefault): + access_key: str + secret_key: str + endpoint: str + bucket: str diff --git a/metagpt/configs/search_config.py 
b/metagpt/configs/search_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2c773b685b211657f026a74ecc87bd9faed32dab --- /dev/null +++ b/metagpt/configs/search_config.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:06 +@Author : alexanderwu +@File : search_config.py +""" +from enum import Enum +from typing import Callable, Optional + +from pydantic import ConfigDict, Field + +from metagpt.utils.yaml_model import YamlModel + + +class SearchEngineType(Enum): + SERPAPI_GOOGLE = "serpapi" + SERPER_GOOGLE = "serper" + DIRECT_GOOGLE = "google" + DUCK_DUCK_GO = "ddg" + CUSTOM_ENGINE = "custom" + BING = "bing" + + +class SearchConfig(YamlModel): + """Config for Search""" + + model_config = ConfigDict(extra="allow") + + api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO + api_key: str = "" + cse_id: str = "" # for google + search_func: Optional[Callable] = None + params: dict = Field( + default_factory=lambda: { + "engine": "google", + "google_domain": "google.com", + "gl": "us", + "hl": "en", + } + ) diff --git a/metagpt/configs/workspace_config.py b/metagpt/configs/workspace_config.py new file mode 100644 index 0000000000000000000000000000000000000000..df7aeaef9bf2f13ecaec248de2d78de21160e75d --- /dev/null +++ b/metagpt/configs/workspace_config.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 19:09 +@Author : alexanderwu +@File : workspace_config.py +""" +from datetime import datetime +from pathlib import Path +from uuid import uuid4 + +from pydantic import field_validator, model_validator + +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class WorkspaceConfig(YamlModel): + path: Path = DEFAULT_WORKSPACE_ROOT + use_uid: bool = False + uid: str = "" + + @field_validator("path") + @classmethod + def check_workspace_path(cls, v): + if isinstance(v, str): + v = Path(v) + return v + + @model_validator(mode="after") + def check_uid_and_update_path(self): + if self.use_uid and not self.uid: + self.uid = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid4().hex[-8:]}" + self.path = self.path / self.uid + + # Create workspace path if not exists + self.path.mkdir(parents=True, exist_ok=True) + return self diff --git a/metagpt/const.py b/metagpt/const.py new file mode 100644 index 0000000000000000000000000000000000000000..94a7d8529be3df35b6f3837380b9648f8096139a --- /dev/null +++ b/metagpt/const.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +from pathlib import Path + +from loguru import logger + +import metagpt + + +def get_metagpt_package_root(): + """Get the root directory of the installed package.""" + package_root = Path(metagpt.__file__).parent.parent + logger.info(f"Package root set to {str(package_root)}") + return package_root + + +def get_metagpt_root(): + """Get the project root directory.""" + # Check if a project root is specified in the environment variable + project_root_env = os.getenv("METAGPT_PROJECT_ROOT") + if project_root_env: + project_root = Path(project_root_env) + logger.info(f"PROJECT_ROOT set from environment variable to {str(project_root)}") + else: + # Fallback to package root if no environment variable is set + project_root = get_metagpt_package_root() + for i in (".git", ".project_root", ".gitignore"): + if (project_root / i).exists(): + break + else: + project_root = Path.cwd() + + return project_root + + +# METAGPT PROJECT ROOT AND VARS +CONFIG_ROOT = Path.home() / 
".metagpt" +METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT +DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace" + +EXAMPLE_PATH = METAGPT_ROOT / "examples" +EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data" +DATA_PATH = METAGPT_ROOT / "data" +DABENCH_PATH = EXAMPLE_PATH / "di/InfiAgent-DABench/data" +EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm" +TEST_DATA_PATH = METAGPT_ROOT / "tests/data" +RESEARCH_PATH = DATA_PATH / "research" +TUTORIAL_PATH = DATA_PATH / "tutorial_docx" +INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" + +UT_PATH = DATA_PATH / "ut" +SWAGGER_PATH = UT_PATH / "files/api/" +UT_PY_PATH = UT_PATH / "files/ut/" +API_QUESTIONS_PATH = UT_PATH / "files/question/" + +SERDESER_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project + +TMP = METAGPT_ROOT / "tmp" + +SOURCE_ROOT = METAGPT_ROOT / "metagpt" +PROMPT_PATH = SOURCE_ROOT / "prompts" +SKILL_DIRECTORY = SOURCE_ROOT / "skills" +TOOL_SCHEMA_PATH = METAGPT_ROOT / "metagpt/tools/schemas" +TOOL_LIBS_PATH = METAGPT_ROOT / "metagpt/tools/libs" + +# TEMPLATE PATH +TEMPLATE_FOLDER_PATH = METAGPT_ROOT / "template" +VUE_TEMPLATE_PATH = TEMPLATE_FOLDER_PATH / "vue_template" +REACT_TEMPLATE_PATH = TEMPLATE_FOLDER_PATH / "react_template" + +# REAL CONSTS + +MEM_TTL = 24 * 30 * 3600 + +MESSAGE_ROUTE_FROM = "sent_from" +MESSAGE_ROUTE_TO = "send_to" +MESSAGE_ROUTE_CAUSE_BY = "cause_by" +MESSAGE_META_ROLE = "role" +MESSAGE_ROUTE_TO_ALL = "<all>" +MESSAGE_ROUTE_TO_NONE = "<none>" +MESSAGE_ROUTE_TO_SELF = "<self>" # Add this tag to replace `ActionOutput` + + +REQUIREMENT_FILENAME = "requirement.txt" +BUGFIX_FILENAME = "bugfix.txt" +PACKAGE_REQUIREMENTS_FILENAME = "requirements.txt" + +DOCS_FILE_REPO = "docs" +PRDS_FILE_REPO = "docs/prd" +SYSTEM_DESIGN_FILE_REPO = "docs/system_design" +TASK_FILE_REPO = "docs/task" +CODE_PLAN_AND_CHANGE_FILE_REPO = "docs/code_plan_and_change" +COMPETITIVE_ANALYSIS_FILE_REPO = "resources/competitive_analysis" +DATA_API_DESIGN_FILE_REPO = "resources/data_api_design" +SEQ_FLOW_FILE_REPO = "resources/seq_flow" +SYSTEM_DESIGN_PDF_FILE_REPO = "resources/system_design" +PRD_PDF_FILE_REPO = "resources/prd" +TASK_PDF_FILE_REPO = "resources/api_spec_and_task" +CODE_PLAN_AND_CHANGE_PDF_FILE_REPO = "resources/code_plan_and_change" +TEST_CODES_FILE_REPO = "tests" +TEST_OUTPUTS_FILE_REPO = "test_outputs" +CODE_SUMMARIES_FILE_REPO = "docs/code_summary" +CODE_SUMMARIES_PDF_FILE_REPO = "resources/code_summary" +RESOURCES_FILE_REPO = "resources" +SD_OUTPUT_FILE_REPO = DEFAULT_WORKSPACE_ROOT +GRAPH_REPO_FILE_REPO = "docs/graph_repo" +VISUAL_GRAPH_REPO_FILE_REPO = "resources/graph_db" +CLASS_VIEW_FILE_REPO = "docs/class_view" + +YAPI_URL = "http://yapi.deepwisdomai.com/" +SD_URL = "http://172.31.0.51:49094" + +DEFAULT_LANGUAGE = "English" +DEFAULT_MAX_TOKENS = 1500 +COMMAND_TOKENS = 500 +BRAIN_MEMORY = "BRAIN_MEMORY" +SKILL_PATH = "SKILL_PATH" +SERPER_API_KEY = "SERPER_API_KEY" +DEFAULT_TOKEN_SIZE = 500 + +# format +BASE64_FORMAT = "base64" + +# REDIS +REDIS_KEY = "REDIS_KEY" + +# Message id +IGNORED_MESSAGE_ID = "0" + +# Class Relationship +GENERALIZATION = "Generalize" +COMPOSITION = "Composite" +AGGREGATION = "Aggregate" + +# Timeout +USE_CONFIG_TIMEOUT = 0 # Using llm.timeout configuration.
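A hedged sketch of how these two timeout sentinels are meant to interact with the `check_timeout` validator in `llm_config.py`; `resolve_timeout` is illustrative and not part of `metagpt.const`:

```python
USE_CONFIG_TIMEOUT = 0  # sentinel: "no per-call override, use llm.timeout"
LLM_API_TIMEOUT = 300   # global default when the config leaves timeout unset


def resolve_timeout(call_timeout: int, config_timeout: int) -> int:
    # A zero per-call timeout falls through to the configured llm.timeout,
    # which itself falls back to LLM_API_TIMEOUT (mirroring check_timeout).
    return call_timeout or config_timeout or LLM_API_TIMEOUT


assert resolve_timeout(USE_CONFIG_TIMEOUT, 600) == 600  # config value wins
assert resolve_timeout(USE_CONFIG_TIMEOUT, 0) == 300    # global default
assert resolve_timeout(30, 600) == 30                   # per-call override wins
```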
+LLM_API_TIMEOUT = 300 + +# Assistant alias +ASSISTANT_ALIAS = "response" + +# Markdown +MARKDOWN_TITLE_PREFIX = "## " + +# Reporter +METAGPT_REPORTER_DEFAULT_URL = os.environ.get("METAGPT_REPORTER_URL", "") + +# Metadata defines +AGENT = "agent" +IMAGES = "images" + +# SWE agent +SWE_SETUP_PATH = get_metagpt_package_root() / "metagpt/tools/swe_agent_commands/setup_default.sh" + +# experience pool +EXPERIENCE_MASK = "" + +# TeamLeader's name +TEAMLEADER_NAME = "Mike" + +DEFAULT_MIN_TOKEN_COUNT = 10000 +DEFAULT_MAX_TOKEN_COUNT = 100000000 diff --git a/metagpt/context.py b/metagpt/context.py new file mode 100644 index 0000000000000000000000000000000000000000..0769f78eb66b179b9daa42a68c55910d18ebeea8 --- /dev/null +++ b/metagpt/context.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/4 16:32 +@Author : alexanderwu +@File : context.py +""" +from __future__ import annotations + +import os +from typing import Any, Dict, Optional + +from pydantic import BaseModel, ConfigDict, Field + +from metagpt.config2 import Config +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.provider.base_llm import BaseLLM +from metagpt.provider.llm_provider_registry import create_llm_instance +from metagpt.utils.cost_manager import ( + CostManager, + FireworksCostManager, + TokenCostManager, +) + + +class AttrDict(BaseModel): + """A dict-like object that allows access to keys as attributes, compatible with Pydantic.""" + + model_config = ConfigDict(extra="allow") + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.__dict__.update(kwargs) + + def __getattr__(self, key): + return self.__dict__.get(key, None) + + def __setattr__(self, key, value): + self.__dict__[key] = value + + def __delattr__(self, key): + if key in self.__dict__: + del self.__dict__[key] + else: + raise AttributeError(f"No such attribute: {key}") + + def set(self, key, val: Any): + self.__dict__[key] = val + + def get(self, key, default: Any = None): + return self.__dict__.get(key, default) + + def remove(self, key): + if key in self.__dict__: + self.__delattr__(key) + + +class Context(BaseModel): + """Env context for MetaGPT""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + kwargs: AttrDict = AttrDict() + config: Config = Field(default_factory=Config.default) + + cost_manager: CostManager = CostManager() + + _llm: Optional[BaseLLM] = None + + def new_environ(self): + """Return a new os.environ object""" + env = os.environ.copy() + # i = self.options + # env.update({k: v for k, v in i.items() if isinstance(v, str)}) + return env + + def _select_costmanager(self, llm_config: LLMConfig) -> CostManager: + """Return a CostManager instance""" + if llm_config.api_type == LLMType.FIREWORKS: + return FireworksCostManager() + elif llm_config.api_type == LLMType.OPEN_LLM: + return TokenCostManager() + else: + return self.cost_manager + + def llm(self) -> BaseLLM: + """Return a LLM instance, fixme: support cache""" + # if self._llm is None: + self._llm = create_llm_instance(self.config.llm) + if self._llm.cost_manager is None: + self._llm.cost_manager = self._select_costmanager(self.config.llm) + return self._llm + + def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM: + """Return a LLM instance, fixme: support cache""" + # if self._llm is None: + llm = create_llm_instance(llm_config) + if llm.cost_manager is None: + llm.cost_manager = self._select_costmanager(llm_config) + return llm + + def serialize(self) -> Dict[str, Any]: + 
"""Serialize the object's attributes into a dictionary. + + Returns: + Dict[str, Any]: A dictionary containing serialized data. + """ + return { + "kwargs": {k: v for k, v in self.kwargs.__dict__.items()}, + "cost_manager": self.cost_manager.model_dump_json(), + } + + def deserialize(self, serialized_data: Dict[str, Any]): + """Deserialize the given serialized data and update the object's attributes accordingly. + + Args: + serialized_data (Dict[str, Any]): A dictionary containing serialized data. + """ + if not serialized_data: + return + kwargs = serialized_data.get("kwargs") + if kwargs: + for k, v in kwargs.items(): + self.kwargs.set(k, v) + cost_manager = serialized_data.get("cost_manager") + if cost_manager: + self.cost_manager.model_validate_json(cost_manager) diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..59daa692f6e1785f902e1e5eea0c7c889d90f27f --- /dev/null +++ b/metagpt/context_mixin.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/11 17:25 +@Author : alexanderwu +@File : context_mixin.py +""" +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from metagpt.config2 import Config +from metagpt.context import Context +from metagpt.provider.base_llm import BaseLLM + + +class ContextMixin(BaseModel): + """Mixin class for context and config""" + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + # Pydantic has bug on _private_attr when using inheritance, so we use private_* instead + # - https://github.com/pydantic/pydantic/issues/7142 + # - https://github.com/pydantic/pydantic/issues/7083 + # - https://github.com/pydantic/pydantic/issues/7091 + + # Env/Role/Action will use this context as private context, or use self.context as public context + private_context: Optional[Context] = Field(default=None, exclude=True) + # Env/Role/Action will use this config as private config, or use self.context.config as public config + private_config: Optional[Config] = Field(default=None, exclude=True) + + # Env/Role/Action will use this llm as private llm, or use self.context._llm instance + private_llm: Optional[BaseLLM] = Field(default=None, exclude=True) + + @model_validator(mode="after") + def validate_context_mixin_extra(self): + self._process_context_mixin_extra() + return self + + def _process_context_mixin_extra(self): + """Process the extra field""" + kwargs = self.model_extra or {} + self.set_context(kwargs.pop("context", None)) + self.set_config(kwargs.pop("config", None)) + self.set_llm(kwargs.pop("llm", None)) + + def set(self, k, v, override=False): + """Set attribute""" + if override or not self.__dict__.get(k): + self.__dict__[k] = v + + def set_context(self, context: Context, override=True): + """Set context""" + self.set("private_context", context, override) + + def set_config(self, config: Config, override=False): + """Set config""" + self.set("private_config", config, override) + if config is not None: + _ = self.llm # init llm + + def set_llm(self, llm: BaseLLM, override=False): + """Set llm""" + self.set("private_llm", llm, override) + + @property + def config(self) -> Config: + """Role config: role config > context config""" + if self.private_config: + return self.private_config + return self.context.config + + @config.setter + def config(self, config: Config) -> None: + """Set config""" + self.set_config(config) + + @property + def context(self) -> Context: + """Role 
diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..59daa692f6e1785f902e1e5eea0c7c889d90f27f --- /dev/null +++ b/metagpt/context_mixin.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/1/11 17:25 +@Author : alexanderwu +@File : context_mixin.py +""" +from typing import Optional + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from metagpt.config2 import Config +from metagpt.context import Context +from metagpt.provider.base_llm import BaseLLM + + +class ContextMixin(BaseModel): + """Mixin class for context and config""" + + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") + + # Pydantic has a bug with _private_attr when using inheritance, so we use private_* instead + # - https://github.com/pydantic/pydantic/issues/7142 + # - https://github.com/pydantic/pydantic/issues/7083 + # - https://github.com/pydantic/pydantic/issues/7091 + + # Env/Role/Action will use this context as private context, or use self.context as public context + private_context: Optional[Context] = Field(default=None, exclude=True) + # Env/Role/Action will use this config as private config, or use self.context.config as public config + private_config: Optional[Config] = Field(default=None, exclude=True) + + # Env/Role/Action will use this llm as private llm, or use self.context._llm instance + private_llm: Optional[BaseLLM] = Field(default=None, exclude=True) + + @model_validator(mode="after") + def validate_context_mixin_extra(self): + self._process_context_mixin_extra() + return self + + def _process_context_mixin_extra(self): + """Process the extra field""" + kwargs = self.model_extra or {} + self.set_context(kwargs.pop("context", None)) + self.set_config(kwargs.pop("config", None)) + self.set_llm(kwargs.pop("llm", None)) + + def set(self, k, v, override=False): + """Set attribute""" + if override or not self.__dict__.get(k): + self.__dict__[k] = v + + def set_context(self, context: Context, override=True): + """Set context""" + self.set("private_context", context, override) + + def set_config(self, config: Config, override=False): + """Set config""" + self.set("private_config", config, override) + if config is not None: + _ = self.llm # init llm + + def set_llm(self, llm: BaseLLM, override=False): + """Set llm""" + self.set("private_llm", llm, override) + + @property + def config(self) -> Config: + """Role config: role config > context config""" + if self.private_config: + return self.private_config + return self.context.config + + @config.setter + def config(self, config: Config) -> None: + """Set config""" + self.set_config(config) + + @property + def context(self) -> Context: + """Role context: role context > context""" + if self.private_context: + return self.private_context + return Context() + + @context.setter + def context(self, context: Context) -> None: + """Set context""" + self.set_context(context) + + @property + def llm(self) -> BaseLLM: + """Role llm: if not set, init from role.config""" + # print(f"class:{self.__class__.__name__}({self.name}), llm: {self._llm}, llm_config: {self._llm_config}") + if not self.private_llm: + self.private_llm = self.context.llm_with_cost_manager_from_llm_config(self.config.llm) + return self.private_llm + + @llm.setter + def llm(self, llm: BaseLLM) -> None: + """Set llm""" + self.private_llm = llm diff --git a/metagpt/document.py b/metagpt/document.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8bb68d5cc3fced399cf7134b9d248ed36c3b7a --- /dev/null +++ b/metagpt/document.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/6/8 14:03 +@Author : alexanderwu +@File : document.py +@Desc : Classes and Operations Related to Files in the File System. +""" +from enum import Enum +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +from llama_index.core import Document, SimpleDirectoryReader +from llama_index.core.node_parser import SimpleNodeParser +from llama_index.readers.file import PDFReader +from pydantic import BaseModel, ConfigDict, Field +from tqdm import tqdm + +from metagpt.logs import logger +from metagpt.repo_parser import RepoParser + + +def validate_cols(content_col: str, df: pd.DataFrame): + if content_col not in df.columns: + raise ValueError("Content column not found in DataFrame.") + + +def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]: + suffix = data_path.suffix + if ".xlsx" == suffix: + data = pd.read_excel(data_path) + elif ".csv" == suffix: + data = pd.read_csv(data_path) + elif ".json" == suffix: + data = pd.read_json(data_path) + elif suffix in (".docx", ".doc"): + data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data() + elif ".txt" == suffix: + data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data() + node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0) + data = node_parser.get_nodes_from_documents(data) + elif ".pdf" == suffix: + data = PDFReader().load_data(file=data_path) # load_data is an instance method, so instantiate the reader + else: + raise NotImplementedError("File format not supported.") + return data + + +class DocumentStatus(Enum): + """Indicates document status, a mechanism similar to RFC/PEP""" + + DRAFT = "draft" + UNDERREVIEW = "underreview" + APPROVED = "approved" + DONE = "done" + + +class Document(BaseModel): + """ + Document: Handles operations related to document files. + """ + + path: Path = Field(default=None) + name: str = Field(default="") + content: str = Field(default="") + + # metadata? in content perhaps. + author: str = Field(default="") + status: DocumentStatus = Field(default=DocumentStatus.DRAFT) + reviews: list = Field(default_factory=list) + + @classmethod + def from_path(cls, path: Path): + """ + Create a Document instance from a file path. + """ + if not path.exists(): + raise FileNotFoundError(f"File {path} not found.") + content = path.read_text() + return cls(content=content, path=path) + + @classmethod + def from_text(cls, text: str, path: Optional[Path] = None): + """ + Create a Document from a text string. 
+ """ + return cls(content=text, path=path) + + def to_path(self, path: Optional[Path] = None): + """ + Save content to the specified file path. + """ + if path is not None: + self.path = path + + if self.path is None: + raise ValueError("File path is not set.") + + self.path.parent.mkdir(parents=True, exist_ok=True) + # TODO: excel, csv, json, etc. + self.path.write_text(self.content, encoding="utf-8") + + def persist(self): + """ + Persist document to disk. + """ + return self.to_path() + + +class IndexableDocument(Document): + """ + Advanced document handling: For vector databases or search engines. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + data: Union[pd.DataFrame, list] + content_col: Optional[str] = Field(default="") + meta_col: Optional[str] = Field(default="") + + @classmethod + def from_path(cls, data_path: Path, content_col="content", meta_col="metadata"): + if not data_path.exists(): + raise FileNotFoundError(f"File {data_path} not found.") + data = read_data(data_path) + if isinstance(data, pd.DataFrame): + validate_cols(content_col, data) + return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col) + try: + content = data_path.read_text() + except Exception as e: + logger.debug(f"Load {str(data_path)} error: {e}") + content = "" + return cls(data=data, content=content, content_col=content_col, meta_col=meta_col) + + def _get_docs_and_metadatas_by_df(self) -> (list, list): + df = self.data + docs = [] + metadatas = [] + for i in tqdm(range(len(df))): + docs.append(df[self.content_col].iloc[i]) + if self.meta_col: + metadatas.append({self.meta_col: df[self.meta_col].iloc[i]}) + else: + metadatas.append({}) + return docs, metadatas + + def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list): + data = self.data + docs = [i.text for i in data] + metadatas = [i.metadata for i in data] + return docs, metadatas + + def get_docs_and_metadatas(self) -> (list, list): + if isinstance(self.data, pd.DataFrame): + return self._get_docs_and_metadatas_by_df() + elif isinstance(self.data, list): + return self._get_docs_and_metadatas_by_llamaindex() + else: + raise NotImplementedError("Data type not supported for metadata extraction.") + + +class RepoMetadata(BaseModel): + name: str = Field(default="") + n_docs: int = Field(default=0) + n_chars: int = Field(default=0) + symbols: list = Field(default_factory=list) + + +class Repo(BaseModel): + # Name of this repo. 
+ name: str = Field(default="") + # metadata: RepoMetadata = Field(default=RepoMetadata) + docs: dict[Path, Document] = Field(default_factory=dict) + codes: dict[Path, Document] = Field(default_factory=dict) + assets: dict[Path, Document] = Field(default_factory=dict) + path: Path = Field(default=None) + + def _path(self, filename): + return self.path / filename + + @classmethod + def from_path(cls, path: Path): + """Load documents, code, and assets from a repository path.""" + path.mkdir(parents=True, exist_ok=True) + repo = Repo(path=path, name=path.name) + for file_path in path.rglob("*"): + # FIXME: These judgments are difficult to support multiple programming languages and need to be more general + if file_path.is_file() and file_path.suffix in [".json", ".txt", ".md", ".py", ".js", ".css", ".html"]: + repo._set(file_path.read_text(), file_path) + return repo + + def to_path(self): + """Persist all documents, code, and assets to the given repository path.""" + for doc in self.docs.values(): + doc.to_path() + for code in self.codes.values(): + code.to_path() + for asset in self.assets.values(): + asset.to_path() + + def _set(self, content: str, path: Path): + """Add a document to the appropriate category based on its file extension.""" + suffix = path.suffix + doc = Document(content=content, path=path, name=str(path.relative_to(self.path))) + + # FIXME: These judgments are difficult to support multiple programming languages and need to be more general + if suffix.lower() == ".md": + self.docs[path] = doc + elif suffix.lower() in [".py", ".js", ".css", ".html"]: + self.codes[path] = doc + else: + self.assets[path] = doc + return doc + + def set(self, filename: str, content: str): + """Set a document and persist it to disk.""" + path = self._path(filename) + doc = self._set(content, path) + doc.to_path() + + def get(self, filename: str) -> Optional[Document]: + """Get a document by its filename.""" + path = self._path(filename) + return self.docs.get(path) or self.codes.get(path) or self.assets.get(path) + + def get_text_documents(self) -> list[Document]: + return list(self.docs.values()) + list(self.codes.values()) + + def eda(self) -> RepoMetadata: + n_docs = sum(len(i) for i in [self.docs, self.codes, self.assets]) + n_chars = sum(sum(len(j.content) for j in i.values()) for i in [self.docs, self.codes, self.assets]) + symbols = RepoParser(base_directory=self.path).generate_symbols() + return RepoMetadata(name=self.name, n_docs=n_docs, n_chars=n_chars, symbols=symbols) diff --git a/metagpt/document_store/__init__.py b/metagpt/document_store/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..766e141a5e90079de122fda03fa5ff3a5e833f54 --- /dev/null +++ b/metagpt/document_store/__init__.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/25 10:20 +@Author : alexanderwu +@File : __init__.py +""" + +from metagpt.document_store.faiss_store import FaissStore + +__all__ = ["FaissStore"] diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py new file mode 100644 index 0000000000000000000000000000000000000000..6aafc57bb0e6d6e91954244ed7e1b778eab4eb6b --- /dev/null +++ b/metagpt/document_store/base_store.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/28 00:01 +@Author : alexanderwu +@File : base_store.py +""" +from abc import ABC, abstractmethod +from pathlib import Path + + +class BaseStore(ABC): + """FIXME: consider add_index, set_index and think about 
granularity.""" + + @abstractmethod + def search(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def write(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def add(self, *args, **kwargs): + raise NotImplementedError + + +class LocalStore(BaseStore, ABC): + def __init__(self, raw_data_path: Path, cache_dir: Path = None): + if not raw_data_path: + raise FileNotFoundError + self.raw_data_path = raw_data_path + self.fname = self.raw_data_path.stem + if not cache_dir: + cache_dir = raw_data_path.parent + self.cache_dir = cache_dir + self.store = self._load() + if not self.store: + self.store = self.write() + + def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"): + index_file = self.cache_dir / "default__vector_store" / index_ext + store_file = self.cache_dir / "docstore" / docstore_ext + return index_file, store_file + + @abstractmethod + def _load(self): + raise NotImplementedError + + @abstractmethod + def _write(self, docs, metadatas): + raise NotImplementedError diff --git a/metagpt/document_store/chromadb_store.py b/metagpt/document_store/chromadb_store.py new file mode 100644 index 0000000000000000000000000000000000000000..1d3a014ee63cf4a1def6dd2de22dc30313ff8b03 --- /dev/null +++ b/metagpt/document_store/chromadb_store.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/29 14:46 +@Author : alexanderwu +@File : chromadb_store.py +""" +import chromadb + + +class ChromaStore: + """If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange.""" + + def __init__(self, name: str, get_or_create: bool = False): + client = chromadb.Client() + collection = client.create_collection(name, get_or_create=get_or_create) + self.client = client + self.collection = collection + + def search(self, query, n_results=2, metadata_filter=None, document_filter=None): + # kwargs can be used for optional filtering + results = self.collection.query( + query_texts=[query], + n_results=n_results, + where=metadata_filter, # optional filter + where_document=document_filter, # optional filter + ) + return results + + def persist(self): + """Chroma recommends using server mode and not persisting locally.""" + raise NotImplementedError + + def write(self, documents, metadatas, ids): + # This function is similar to add(), but it's for more generalized updates + # It assumes you're passing in lists of docs, metadatas, and ids + return self.collection.add( + documents=documents, + metadatas=metadatas, + ids=ids, + ) + + def add(self, document, metadata, _id): + # This function is for adding individual documents + # It assumes you're passing in a single doc, metadata, and id + return self.collection.add( + documents=[document], + metadatas=[metadata], + ids=[_id], + ) + + def delete(self, _id): + return self.collection.delete([_id]) diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py new file mode 100644 index 0000000000000000000000000000000000000000..b196bef270665ae7a72a6f5b25d32ff5ac39f497 --- /dev/null +++ b/metagpt/document_store/faiss_store.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/25 10:20 +@Author : alexanderwu +@File : faiss_store.py +""" +import asyncio +from pathlib import Path +from typing import Any, Optional + +import faiss +from llama_index.core import VectorStoreIndex, load_index_from_storage +from llama_index.core.embeddings import BaseEmbedding +from 
diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py new file mode 100644 index 0000000000000000000000000000000000000000..b196bef270665ae7a72a6f5b25d32ff5ac39f497 --- /dev/null +++ b/metagpt/document_store/faiss_store.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/5/25 10:20 +@Author : alexanderwu +@File : faiss_store.py +""" +import asyncio +from pathlib import Path +from typing import Any, Optional + +import faiss +from llama_index.core import VectorStoreIndex, load_index_from_storage +from llama_index.core.embeddings import BaseEmbedding +from llama_index.core.schema import Document, QueryBundle, TextNode +from llama_index.core.storage import StorageContext +from llama_index.vector_stores.faiss import FaissVectorStore + +from metagpt.document import IndexableDocument +from metagpt.document_store.base_store import LocalStore +from metagpt.logs import logger +from metagpt.utils.embedding import get_embedding + + +class FaissStore(LocalStore): + def __init__( + self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None + ): + self.meta_col = meta_col + self.content_col = content_col + self.embedding = embedding or get_embedding() + self.store: VectorStoreIndex + super().__init__(raw_data, cache_dir) + + def _load(self) -> Optional["VectorStoreIndex"]: + index_file, store_file = self._get_index_and_store_fname() + + if not (index_file.exists() and store_file.exists()): + logger.info("Missing at least one of index_file/store_file; load failed, returning None") + return None + vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir) + storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store) + index = load_index_from_storage(storage_context, embed_model=self.embedding) + + return index + + def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex: + assert len(docs) == len(metadatas) + documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)] + + vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536)) # 1536-dim flat L2 index (OpenAI-style embedding size) + storage_context = StorageContext.from_defaults(vector_store=vector_store) + index = VectorStoreIndex.from_documents( + documents=documents, storage_context=storage_context, embed_model=self.embedding + ) + + return index + + def persist(self): + self.store.storage_context.persist(self.cache_dir) + + def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs): + retriever = self.store.as_retriever(similarity_top_k=k) + rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query))) + + logger.debug(rsp) + if expand_cols: + return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp])) + else: + return str(sep.join([f"{x.node.text}" for x in rsp])) + + async def asearch(self, *args, **kwargs): + return await asyncio.to_thread(self.search, *args, **kwargs) + + def write(self): + """Initialize the index and library based on the Document (JSON / XLSX, etc.) 
file provided by the user.""" + if not self.raw_data_path.exists(): + raise FileNotFoundError + doc = IndexableDocument.from_path(self.raw_data_path, self.content_col, self.meta_col) + docs, metadatas = doc.get_docs_and_metadatas() + + self.store = self._write(docs, metadatas) + self.persist() + return self.store + + def add(self, texts: list[str], *args, **kwargs) -> list[str]: + """FIXME: Currently, the store is not updated after adding.""" + texts_embeds = self.embedding.get_text_embedding_batch(texts) + nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)] + self.store.insert_nodes(nodes) + + return [] + + def delete(self, *args, **kwargs): + """Currently, faiss does not provide a delete interface.""" + raise NotImplementedError diff --git a/metagpt/document_store/lancedb_store.py b/metagpt/document_store/lancedb_store.py new file mode 100644 index 0000000000000000000000000000000000000000..99c4575a6ce76f06511f9538c66c7daf6f8f120b --- /dev/null +++ b/metagpt/document_store/lancedb_store.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/8/9 15:42 +@Author : unkn-wn (Leon Yee) +@File : lancedb_store.py +""" +import os +import shutil + +import lancedb + + +class LanceStore: + def __init__(self, name): + db = lancedb.connect("./data/lancedb") + self.db = db + self.name = name + self.table = None + + def search(self, query, n_results=2, metric="L2", nprobes=20, **kwargs): + # This assumes query is a vector embedding + # kwargs can be used for optional filtering + # .select - only searches the specified columns + # .where - SQL syntax filtering for metadata (e.g. where("price > 100")) + # .metric - specifies the distance metric to use + # .nprobes - values will yield better recall (more likely to find vectors if they exist) at the expense of latency. + if self.table is None: + raise Exception("Table not created yet, please add data first.") + + results = ( + self.table.search(query) + .limit(n_results) + .select(kwargs.get("select")) + .where(kwargs.get("where")) + .metric(metric) + .nprobes(nprobes) + .to_df() + ) + return results + + def persist(self): + raise NotImplementedError + + def write(self, data, metadatas, ids): + # This function is similar to add(), but it's for more generalized updates + # "data" is the list of embeddings + # Inserts into table by expanding metadatas into a dataframe: [{'vector', 'id', 'meta', 'meta2'}, ...] + + documents = [] + for i in range(len(data)): + row = {"vector": data[i], "id": ids[i]} + row.update(metadatas[i]) + documents.append(row) + + if self.table is not None: + self.table.add(documents) + else: + self.table = self.db.create_table(self.name, documents) + + def add(self, data, metadata, _id): + # This function is for adding individual documents + # It assumes you're passing in a single vector embedding, metadata, and id + + row = {"vector": data, "id": _id} + row.update(metadata) + + if self.table is not None: + self.table.add([row]) + else: + self.table = self.db.create_table(self.name, [row]) + + def delete(self, _id): + # This function deletes a row by id. + # LanceDB delete syntax uses SQL syntax, so you can use "in" or "=" + if self.table is None: + raise Exception("Table not created yet, please add data first") + + if isinstance(_id, str): + return self.table.delete(f"id = '{_id}'") + else: + return self.table.delete(f"id = {_id}") + + def drop(self, name): + # This function drops a table, if it exists. 
+ + path = os.path.join(self.db.uri, name + ".lance") + if os.path.exists(path): + shutil.rmtree(path) diff --git a/metagpt/document_store/milvus_store.py b/metagpt/document_store/milvus_store.py new file mode 100644 index 0000000000000000000000000000000000000000..e4d6d985e21d373976fbffcad0114be509b67a99 --- /dev/null +++ b/metagpt/document_store/milvus_store.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from metagpt.document_store.base_store import BaseStore + + +@dataclass +class MilvusConnection: + """ + Args: + uri: milvus url + token: milvus token + """ + + uri: Optional[str] = None + token: Optional[str] = None + + +class MilvusStore(BaseStore): + def __init__(self, connect: MilvusConnection): + try: + from pymilvus import MilvusClient + except ImportError: + raise Exception("Please install pymilvus first.") + if not connect.uri: + raise Exception("Please check MilvusConnection: uri must be set.") + self.client = MilvusClient(uri=connect.uri, token=connect.token) + + def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True): + from pymilvus import DataType + + if self.client.has_collection(collection_name=collection_name): + self.client.drop_collection(collection_name=collection_name) + + schema = self.client.create_schema( + auto_id=False, + enable_dynamic_field=False, + ) + schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36) + schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim) + + index_params = self.client.prepare_index_params() + index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE") + + self.client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_params, + enable_dynamic_schema=enable_dynamic_schema, + ) + + @staticmethod + def build_filter(key, value) -> str: + if isinstance(value, str): + filter_expression = f'{key} == "{value}"' + else: + if isinstance(value, list): + filter_expression = f"{key} in {value}" + else: + filter_expression = f"{key} == {value}" + + return filter_expression + + def search( + self, + collection_name: str, + query: List[float], + filter: Optional[Dict] = None, + limit: int = 10, + output_fields: Optional[List[str]] = None, + ) -> List[dict]: + # Guard against a missing filter dict; iterating None would raise. + filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()]) if filter else "" + + res = self.client.search( + collection_name=collection_name, + data=[query], + filter=filter_expression, + limit=limit, + output_fields=output_fields, + )[0] + + return res + + def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]): + # Build one row per id; mutating a single dict in the loop would keep only the last row. + data = [{"id": id_, "vector": vector[i], "metadata": metadata[i]} for i, id_ in enumerate(_ids)] + + self.client.upsert(collection_name=collection_name, data=data) + + def delete(self, collection_name: str, _ids: List[str]): + self.client.delete(collection_name=collection_name, ids=_ids) + + def write(self, *args, **kwargs): + pass
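For clarity, the filter-expression mapping that `build_filter` implements, restated standalone with the same logic as above:

```python
def build_filter(key, value) -> str:
    # str -> quoted equality; list -> membership; everything else -> bare equality.
    if isinstance(value, str):
        return f'{key} == "{value}"'
    if isinstance(value, list):
        return f"{key} in {value}"
    return f"{key} == {value}"


filters = {"category": "doc", "version": 2, "tag": ["a", "b"]}
print(" and ".join(build_filter(k, v) for k, v in filters.items()))
# category == "doc" and version == 2 and tag in ['a', 'b']
```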
diff --git a/metagpt/document_store/qdrant_store.py b/metagpt/document_store/qdrant_store.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9637aa7abe40f6bc19d5faaea20d1233791b9d --- /dev/null +++ b/metagpt/document_store/qdrant_store.py @@ -0,0 +1,124 @@ +from dataclasses import dataclass +from typing import List, Optional + +from qdrant_client import QdrantClient +from qdrant_client.models import Filter, PointStruct, VectorParams + +from metagpt.document_store.base_store import BaseStore + + +@dataclass +class QdrantConnection: + """ + Args: + url: qdrant url + host: qdrant host + port: qdrant port + memory: qdrant service use memory mode + api_key: qdrant cloud api_key + """ + + url: Optional[str] = None + host: Optional[str] = None + port: Optional[int] = None + memory: bool = False + api_key: Optional[str] = None + + +class QdrantStore(BaseStore): + def __init__(self, connect: QdrantConnection): + if connect.memory: + self.client = QdrantClient(":memory:") + elif connect.url: + self.client = QdrantClient(url=connect.url, api_key=connect.api_key) + elif connect.host and connect.port: + self.client = QdrantClient(host=connect.host, port=connect.port, api_key=connect.api_key) + else: + raise Exception("Please check QdrantConnection: one of memory, url, or host+port must be set.") + + def create_collection( + self, + collection_name: str, + vectors_config: VectorParams, + force_recreate=False, + **kwargs, + ): + """ + Create a collection. + Args: + collection_name: collection name + vectors_config: VectorParams object, detail in https://github.com/qdrant/qdrant-client + force_recreate: default is False; if True, delete the existing collection, then recreate it + **kwargs: + + Returns: + + """ + try: + self.client.get_collection(collection_name) + if force_recreate: + res = self.client.recreate_collection(collection_name, vectors_config=vectors_config, **kwargs) + return res + return True + except: # noqa: E722 + return self.client.recreate_collection(collection_name, vectors_config=vectors_config, **kwargs) + + def has_collection(self, collection_name: str): + try: + self.client.get_collection(collection_name) + return True + except: # noqa: E722 + return False + + def delete_collection(self, collection_name: str, timeout=60): + res = self.client.delete_collection(collection_name, timeout=timeout) + if not res: + raise Exception(f"Delete collection {collection_name} failed.") + + def add(self, collection_name: str, points: List[PointStruct]): + """ + Add some vector data to qdrant. + Args: + collection_name: collection name + points: list of PointStruct object, about PointStruct detail in https://github.com/qdrant/qdrant-client + + Returns: None + + """ + # self.client.upload_records() + self.client.upsert( + collection_name, + points, + ) + + def search( + self, + collection_name: str, + query: List[float], + query_filter: Optional[Filter] = None, + k=10, + return_vector=False, + ): + """ + Vector search. + Args: + collection_name: qdrant collection name + query: input vector + query_filter: Filter object, detail in https://github.com/qdrant/qdrant-client + k: return the most similar k pieces of data + return_vector: whether to return vectors + + Returns: list of dict + + """ + hits = self.client.search( + collection_name=collection_name, + query_vector=query, + query_filter=query_filter, + limit=k, + with_vectors=return_vector, + ) + return [hit.__dict__ for hit in hits] + + def write(self, *args, **kwargs): + pass
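An in-memory `QdrantStore` round trip as a usage sketch; the vector size and payload are illustrative:

```python
from qdrant_client.models import Distance, PointStruct, VectorParams

from metagpt.document_store.qdrant_store import QdrantConnection, QdrantStore

store = QdrantStore(QdrantConnection(memory=True))
store.create_collection("demo", VectorParams(size=4, distance=Distance.COSINE))
store.add("demo", [PointStruct(id=1, vector=[0.1, 0.2, 0.3, 0.4], payload={"doc": "a"})])

hits = store.search("demo", query=[0.1, 0.2, 0.3, 0.4], k=1)
print(hits[0]["payload"])  # {'doc': 'a'}
```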
diff --git a/metagpt/environment/.DS_Store b/metagpt/environment/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..396c10e38bd4745f8615a8a1aaa47cc56084f7e0 Binary files /dev/null and b/metagpt/environment/.DS_Store differ diff --git a/metagpt/environment/README.md b/metagpt/environment/README.md new file mode 100644 index 0000000000000000000000000000000000000000..bb7d50d5013966513258603f788364cd3cf7145e --- /dev/null +++ b/metagpt/environment/README.md @@ -0,0 +1,38 @@ +Here is an environment description of the MetaGPT env for different situations. +For now, the code only defines the environments; there are still some todos, like migrating roles/actions to the current version. + +## Function +- Define `ExtEnv` (base class), which helps users integrate with external environments like games, either through APIs or by constructing the game logic. +- Define `Environment` (base class), which is the env that MetaGPT uses directly, and which includes roles and so on. +- Define the `EnvAPIRegistry` to mark the read/write APIs through which `ExtEnv` provides observe/step abilities. Users can then call a particular one to get an observation from the env or send feedback to it. + +## Usage + +Init the environment: +``` +android_env = env.create(EnvType.ANDROID) + +assistant = Role(name="Bob", profile="android assistant") +team = Team(investment=10.0, env=android_env, roles=[assistant]) +``` + +Observe & step inside a role's actions: +``` +from metagpt.environment.api.env_api import EnvAPIAbstract + +# get screenshot from ExtEnv +screenshot_path: Path = await env.observe( + EnvAPIAbstract( + api_name="get_screenshot", kwargs={"ss_name": f"{round_count}_before", "local_save_dir": task_dir} + ) + ) + +# do a `tap` action on the screen +res = env.step(EnvAPIAbstract("system_tap", kwargs={"x": x, "y": y})) +``` + +## TODO +- add android app operation assistant under `examples/android_assistant` +- migrate roles/actions of werewolf game from old version into current version +- migrate roles/actions of minecraft game from old version into current version +- migrate roles/actions of stanford_town game from old version into current version diff --git a/metagpt/environment/__init__.py b/metagpt/environment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d77b3a83de28c92e0e11cc0e6e6a365b2b1c76 --- /dev/null +++ b/metagpt/environment/__init__.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.environment.base_env import Environment + +# from metagpt.environment.android.android_env import AndroidEnv +from metagpt.environment.werewolf.werewolf_env import WerewolfEnv +from metagpt.environment.stanford_town.stanford_town_env import StanfordTownEnv +from metagpt.environment.software.software_env import SoftwareEnv + + +# AndroidEnv stays out of __all__ while its import above is disabled. +__all__ = ["WerewolfEnv", "StanfordTownEnv", "SoftwareEnv", "Environment"] diff --git a/metagpt/environment/android/__init__.py b/metagpt/environment/android/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/android/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/android/android_env.py b/metagpt/environment/android/android_env.py new file mode 100644 index 0000000000000000000000000000000000000000..66672d219e8cec2beaa1a79466b15e4882420880 --- /dev/null +++ b/metagpt/environment/android/android_env.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Android Env + +from pydantic import Field + +from metagpt.environment.android.android_ext_env import AndroidExtEnv +from metagpt.environment.base_env import Environment + + +class AndroidEnv(AndroidExtEnv, Environment): + """in order to use actual `reset`&`observe`, inherited order: AndroidExtEnv, Environment""" + + rows: int = Field(default=0, description="rows of a grid on the screenshot") + cols: int = Field(default=0, description="cols of a grid on the screenshot") diff --git a/metagpt/environment/android/android_ext_env.py 
b/metagpt/environment/android/android_ext_env.py new file mode 100644 index 0000000000000000000000000000000000000000..63a421fa2f0dd5f0b164d6a9bfd2ee879f2630b1 --- /dev/null +++ b/metagpt/environment/android/android_ext_env.py @@ -0,0 +1,375 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The Android external environment to integrate with Android apps +import subprocess +import time +from pathlib import Path +from typing import Any, Optional + +import clip +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from PIL import Image +from pydantic import Field + +from metagpt.const import DEFAULT_WORKSPACE_ROOT +from metagpt.environment.android.const import ADB_EXEC_FAIL +from metagpt.environment.android.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, + EnvObsValType, +) +from metagpt.environment.android.text_icon_localization import ( + clip_for_icon, + crop_for_clip, + det, + load_model, + ocr, +) +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.logs import logger +from metagpt.utils.common import download_model + + +def load_cv_model(device: str = "cpu") -> Any: # typing.Any, not the builtin `any` + ocr_detection = pipeline(Tasks.ocr_detection, model="damo/cv_resnet18_ocr-detection-line-level_damo") + ocr_recognition = pipeline(Tasks.ocr_recognition, model="damo/cv_convnextTiny_ocr-recognition-document_damo") + file_url = "https://huggingface.co/ShilongLiu/GroundingDINO/blob/main/groundingdino_swint_ogc.pth" + target_folder = Path(f"{DEFAULT_WORKSPACE_ROOT}/weights") + file_path = download_model(file_url, target_folder) + groundingdino_model = load_model(file_path, device=device).eval() + return ocr_detection, ocr_recognition, groundingdino_model + + +class AndroidExtEnv(ExtEnv): + device_id: Optional[str] = Field(default=None) + screenshot_dir: Optional[Path] = Field(default=None) + xml_dir: Optional[Path] = Field(default=None) + width: int = Field(default=720, description="device screen width") + height: int = Field(default=1080, description="device screen height") + ocr_detection: Any = Field(default=None, description="ocr detection model") + ocr_recognition: Any = Field(default=None, description="ocr recognition model") + groundingdino_model: Any = Field(default=None, description="clip groundingdino model") + + def __init__(self, **data: Any): + super().__init__(**data) + device_id = data.get("device_id") + self.ocr_detection, self.ocr_recognition, self.groundingdino_model = load_cv_model() + if device_id: + devices = self.list_devices() + if device_id not in devices: + raise RuntimeError(f"device-id: {device_id} not found") + (width, height) = self.device_shape + self.width = data.get("width", width) + self.height = data.get("height", height) + self.create_device_path(self.screenshot_dir) + self.create_device_path(self.xml_dir) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + super().reset(seed=seed, options=options) + + obs = self._get_obs() + + return obs, {} + + def _get_obs(self) -> dict[str, EnvObsValType]: + pass + + def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any: + obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE + obs = None # initialize so EnvObsType.NONE does not leave `obs` unbound + if obs_type == EnvObsType.NONE: + pass + elif obs_type == EnvObsType.GET_SCREENSHOT: + obs = self.get_screenshot(ss_name=obs_params.ss_name, local_save_dir=obs_params.local_save_dir) + elif obs_type == EnvObsType.GET_XML: + obs = 
self.get_xml(xml_name=obs_params.xml_name, local_save_dir=obs_params.local_save_dir) + return obs + + def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + res = self._execute_env_action(action) + + obs = {} + + ret = (obs, 1.0, False, False, {"res": res}) + return ret + + def _execute_env_action(self, action: EnvAction): + action_type = action.action_type + res = None + if action_type == EnvActionType.NONE: + pass + elif action_type == EnvActionType.SYSTEM_BACK: + res = self.system_back() + elif action_type == EnvActionType.SYSTEM_TAP: + res = self.system_tap(x=action.coord[0], y=action.coord[1]) + elif action_type == EnvActionType.USER_INPUT: + res = self.user_input(input_txt=action.input_txt) + elif action_type == EnvActionType.USER_LONGPRESS: + res = self.user_longpress(x=action.coord[0], y=action.coord[1]) + elif action_type == EnvActionType.USER_SWIPE: + res = self.user_swipe(x=action.coord[0], y=action.coord[1], orient=action.orient, dist=action.dist) + elif action_type == EnvActionType.USER_SWIPE_TO: + res = self.user_swipe_to(start=action.coord, end=action.tgt_coord) + return res + + @property + def adb_prefix_si(self): + """adb cmd prefix with `device_id` and `shell input`""" + return f"adb -s {self.device_id} shell input " + + @property + def adb_prefix_shell(self): + """adb cmd prefix with `device_id` and `shell`""" + return f"adb -s {self.device_id} shell " + + @property + def adb_prefix(self): + """adb cmd prefix with `device_id`""" + return f"adb -s {self.device_id} " + + def execute_adb_with_cmd(self, adb_cmd: str) -> str: + adb_cmd = adb_cmd.replace("\\", "/") + res = subprocess.run(adb_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + exec_res = ADB_EXEC_FAIL + if not res.returncode: + exec_res = res.stdout.strip() + return exec_res + + def create_device_path(self, folder_path: Path): + adb_cmd = f"{self.adb_prefix_shell} mkdir {folder_path} -p" + res = self.execute_adb_with_cmd(adb_cmd) + if res == ADB_EXEC_FAIL: + raise RuntimeError(f"create device path: {folder_path} failed") + + @property + def device_shape(self) -> tuple[int, int]: + adb_cmd = f"{self.adb_prefix_shell} wm size" + shape = (0, 0) + shape_res = self.execute_adb_with_cmd(adb_cmd) + if shape_res != ADB_EXEC_FAIL: + shape = tuple(map(int, shape_res.split(": ")[1].split("x"))) + return shape + + def list_devices(self): + adb_cmd = "adb devices" + res = self.execute_adb_with_cmd(adb_cmd) + devices = [] + if res != ADB_EXEC_FAIL: + devices = res.split("\n")[1:] + devices = [device.split()[0] for device in devices] + return devices + + @mark_as_readable + def get_screenshot(self, ss_name: str, local_save_dir: Path) -> Path: + """ + ss_name: screenshot file name + local_save_dir: local dir to store image from virtual machine + """ + assert self.screenshot_dir + ss_remote_path = Path(self.screenshot_dir).joinpath(f"{ss_name}.png") + ss_cmd = f"{self.adb_prefix_shell} screencap -p {ss_remote_path}" + ss_res = self.execute_adb_with_cmd(ss_cmd) + time.sleep(0.1) + res = ADB_EXEC_FAIL + if ss_res != ADB_EXEC_FAIL: + ss_local_path = Path(local_save_dir).joinpath(f"{ss_name}.png") + pull_cmd = f"{self.adb_prefix} pull {ss_remote_path} {ss_local_path}" + pull_res = self.execute_adb_with_cmd(pull_cmd) + time.sleep(0.1) + if pull_res != ADB_EXEC_FAIL: + res = ss_local_path + else: + ss_cmd = f"{self.adb_prefix_shell} rm /sdcard/{ss_name}.png" + ss_res = self.execute_adb_with_cmd(ss_cmd) + time.sleep(0.1) + ss_cmd = f"{self.adb_prefix_shell} screencap 
-p /sdcard/{ss_name}.png" + ss_res = self.execute_adb_with_cmd(ss_cmd) + time.sleep(0.1) + ss_cmd = f"{self.adb_prefix} pull /sdcard/{ss_name}.png {self.screenshot_dir}" + ss_res = self.execute_adb_with_cmd(ss_cmd) + image_path = Path(f"{self.screenshot_dir}/{ss_name}.png") + res = image_path + return Path(res) + + @mark_as_readable + def get_xml(self, xml_name: str, local_save_dir: Path) -> Path: + xml_remote_path = Path(self.xml_dir).joinpath(f"{xml_name}.xml") + dump_cmd = f"{self.adb_prefix_shell} uiautomator dump {xml_remote_path}" + xml_res = self.execute_adb_with_cmd(dump_cmd) + + res = ADB_EXEC_FAIL + if xml_res != ADB_EXEC_FAIL: + xml_local_path = Path(local_save_dir).joinpath(f"{xml_name}.xml") + pull_cmd = f"{self.adb_prefix} pull {xml_remote_path} {xml_local_path}" + pull_res = self.execute_adb_with_cmd(pull_cmd) + if pull_res != ADB_EXEC_FAIL: + res = xml_local_path + return Path(res) + + @mark_as_writeable + def system_back(self) -> str: + adb_cmd = f"{self.adb_prefix_si} keyevent KEYCODE_BACK" + back_res = self.execute_adb_with_cmd(adb_cmd) + return back_res + + @mark_as_writeable + def system_tap(self, x: int, y: int) -> str: + adb_cmd = f"{self.adb_prefix_si} tap {x} {y}" + tap_res = self.execute_adb_with_cmd(adb_cmd) + return tap_res + + @mark_as_writeable + def user_input(self, input_txt: str) -> str: + input_txt = input_txt.replace(" ", "%s").replace("'", "") + adb_cmd = f"{self.adb_prefix_si} text {input_txt}" + input_res = self.execute_adb_with_cmd(adb_cmd) + return input_res + + @mark_as_writeable + def user_longpress(self, x: int, y: int, duration: int = 500) -> str: + adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x} {y} {duration}" + press_res = self.execute_adb_with_cmd(adb_cmd) + return press_res + + @mark_as_writeable + def user_swipe(self, x: int, y: int, orient: str = "up", dist: str = "medium", if_quick: bool = False) -> str: + dist_unit = int(self.width / 10) + if dist == "long": + dist_unit *= 3 + elif dist == "medium": + dist_unit *= 2 + + if orient == "up": + offset = 0, -2 * dist_unit + elif orient == "down": + offset = 0, 2 * dist_unit + elif orient == "left": + offset = -1 * dist_unit, 0 + elif orient == "right": + offset = dist_unit, 0 + else: + return ADB_EXEC_FAIL + + duration = 100 if if_quick else 400 + adb_cmd = f"{self.adb_prefix_si} swipe {x} {y} {x + offset[0]} {y + offset[1]} {duration}" + swipe_res = self.execute_adb_with_cmd(adb_cmd) + return swipe_res + + @mark_as_writeable + def user_swipe_to(self, start: tuple[int, int], end: tuple[int, int], duration: int = 400) -> str: + adb_cmd = f"{self.adb_prefix_si} swipe {start[0]} {start[1]} {end[0]} {end[1]} {duration}" + swipe_res = self.execute_adb_with_cmd(adb_cmd) + return swipe_res + + @mark_as_writeable + def user_exit(self) -> str: + adb_cmd = f"{self.adb_prefix_shell} am start -a android.intent.action.MAIN -c android.intent.category.HOME" + exit_res = self.execute_adb_with_cmd(adb_cmd) + return exit_res + + def _ocr_text(self, text: str) -> list: + image = self.get_screenshot("screenshot", self.screenshot_dir) + iw, ih = Image.open(image).size + x, y = self.device_shape + if iw > ih: + x, y = y, x + iw, ih = ih, iw + in_coordinate, out_coordinate = ocr(image, text, self.ocr_detection, self.ocr_recognition, iw, ih) + output_list = [in_coordinate, out_coordinate, x, y, iw, ih, image] + return output_list + + @mark_as_writeable + def user_open_app(self, app_name: str) -> str: + ocr_result = self._ocr_text(app_name) + in_coordinate, _, x, y, iw, ih = ( + ocr_result[0], + ocr_result[1], + 
ocr_result[2], + ocr_result[3], + ocr_result[4], + ocr_result[5], + ) + if len(in_coordinate) == 0: + logger.info(f"No App named {app_name}.") + return "no app here" + else: + tap_coordinate = [ + (in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2, + ] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + return self.system_tap(tap_coordinate[0] * x, (tap_coordinate[1] - round(50 / y, 2)) * y) + + @mark_as_writeable + def user_click_text(self, text: str) -> str: + ocr_result = self._ocr_text(text) + in_coordinate, out_coordinate, x, y, iw, ih, _ = ( + ocr_result[0], + ocr_result[1], + ocr_result[2], + ocr_result[3], + ocr_result[4], + ocr_result[5], + ocr_result[6], + ) + if len(out_coordinate) == 0: + logger.info( + f'Failed to execute action click text ({text}). The text "{text}" is not detected in the screenshot.' + ) + elif len(out_coordinate) == 1: + tap_coordinate = [ + (in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2, + ] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) + else: + logger.info( + f'Failed to execute action click text ({text}). The text "{text}" appears too many times in the screenshot.' + ) + + @mark_as_writeable + def user_stop(self): + logger.info("Successful execution of tasks") + + @mark_as_writeable + def user_click_icon(self, icon_shape_color: str) -> str: + screenshot_path = self.get_screenshot("screenshot", self.screenshot_dir) + image = screenshot_path + iw, ih = Image.open(image).size + x, y = self.device_shape + if iw > ih: + x, y = y, x + iw, ih = ih, iw + in_coordinate, out_coordinate = det(image, "icon", self.groundingdino_model) # detect icons + if len(out_coordinate) == 1: # only one icon + tap_coordinate = [ + (in_coordinate[0][0] + in_coordinate[0][2]) / 2, + (in_coordinate[0][1] + in_coordinate[0][3]) / 2, + ] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) + + else: + temp_file = Path(f"{DEFAULT_WORKSPACE_ROOT}/temp") + temp_file.mkdir(parents=True, exist_ok=True) + hash_table, clip_filter = [], [] + for i, (td, box) in enumerate(zip(in_coordinate, out_coordinate)): + if crop_for_clip(image, td, i, temp_file): + hash_table.append(td) + crop_image = f"{i}.png" + clip_filter.append(temp_file.joinpath(crop_image)) + clip_model, clip_preprocess = clip.load("ViT-B/32") # FIXME: device=device + clip_filter = clip_for_icon(clip_model, clip_preprocess, clip_filter, icon_shape_color) + final_box = hash_table[clip_filter] + tap_coordinate = [(final_box[0] + final_box[2]) / 2, (final_box[1] + final_box[3]) / 2] + tap_coordinate = [round(tap_coordinate[0] / iw, 2), round(tap_coordinate[1] / ih, 2)] + logger.info(f"tap icon at ({tap_coordinate[0] * x}, {tap_coordinate[1] * y})") + return self.system_tap(tap_coordinate[0] * x, tap_coordinate[1] * y) diff --git a/metagpt/environment/android/const.py b/metagpt/environment/android/const.py new file mode 100644 index 0000000000000000000000000000000000000000..8811289bf097f478c89f0ab07cfb8aa55d20e7a6 --- /dev/null +++ b/metagpt/environment/android/const.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +# For Android Assistant Agent +ADB_EXEC_FAIL = "FAILED" diff --git a/metagpt/environment/android/env_space.py b/metagpt/environment/android/env_space.py new file mode 100644 index
0000000000000000000000000000000000000000..8225f0127880247deaddb53bf5ab2fc19de90c89 --- /dev/null +++ b/metagpt/environment/android/env_space.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from pathlib import Path +from typing import Union + +import numpy as np +import numpy.typing as npt +from gymnasium import spaces +from pydantic import ConfigDict, Field, field_validator + +from metagpt.base.base_env_space import ( + BaseEnvAction, + BaseEnvActionType, + BaseEnvObsParams, + BaseEnvObsType, +) + + +class EnvActionType(BaseEnvActionType): + NONE = 0 # no action to run, just get observation + + SYSTEM_BACK = 1 + SYSTEM_TAP = 2 + USER_INPUT = 3 + USER_LONGPRESS = 4 + USER_SWIPE = 5 + USER_SWIPE_TO = 6 + + +class EnvAction(BaseEnvAction): + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_type: int = Field(default=EnvActionType.NONE, description="action type") + coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="operation coordinate" + ) + tgt_coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="target operation coordinate" + ) + input_txt: str = Field(default="", description="user input text") + orient: str = Field(default="up", description="swipe orient") + dist: str = Field(default="medium", description="swipe dist") + + @field_validator("coord", "tgt_coord", mode="before") + @classmethod + def check_coord(cls, coord) -> npt.NDArray[np.int64]: + if not isinstance(coord, np.ndarray): + coord = np.array(coord) + return coord + + +class EnvObsType(BaseEnvObsType): + NONE = 0 # get whole observation from env + + GET_SCREENSHOT = 1 + GET_XML = 2 + + +class EnvObsParams(BaseEnvObsParams): + model_config = ConfigDict(arbitrary_types_allowed=True) + + obs_type: int = Field(default=EnvObsType.NONE, description="observation type") + ss_name: str = Field(default="", description="screenshot file name") + xml_name: str = Field(default="", description="xml file name") + local_save_dir: Union[str, Path] = Field(default="", description="local dir to save file") + + +EnvObsValType = str + + +def get_observation_space() -> spaces.Dict: + space = spaces.Dict({"screenshot": spaces.Text(256), "xml": spaces.Text(256)}) + return space + + +def get_action_space(device_shape: tuple[int, int]) -> spaces.Dict: + space = spaces.Dict( + { + "action_type": spaces.Discrete(len(EnvActionType)), + "coord": spaces.Box( + np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64) + ), + "tgt_coord": spaces.Box( + np.array([0, 0], dtype=np.int64), np.array([device_shape[0], device_shape[1]], dtype=np.int64) + ), + "input_txt": spaces.Text(256), + "orient": spaces.Text(16), + "dist": spaces.Text(16), + } + ) + return space diff --git a/metagpt/environment/android/grounding_dino_config.py b/metagpt/environment/android/grounding_dino_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9158d5f6260ec74bded95377d382387430d7cd70 --- /dev/null +++ b/metagpt/environment/android/grounding_dino_config.py @@ -0,0 +1,43 @@ +batch_size = 1 +modelname = "groundingdino" +backbone = "swin_T_224_1k" +position_embedding = "sine" +pe_temperatureH = 20 +pe_temperatureW = 20 +return_interm_indices = [1, 2, 3] +backbone_freeze_keywords = None +enc_layers = 6 +dec_layers = 6 +pre_norm = False +dim_feedforward = 2048 +hidden_dim = 256 +dropout = 0.0 +nheads = 8 +num_queries = 900 +query_dim = 4 +num_patterns = 0 +num_feature_levels = 4
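As a quick, illustrative aside on the `env_space.py` module above: a minimal sketch of constructing actions against these spaces. The device shape and coordinates here are invented, and the final `env.step()` call assumes an Android ext-env instance built elsewhere in this diff:

```python
import numpy as np

from metagpt.environment.android.env_space import (
    EnvAction,
    EnvActionType,
    get_action_space,
    get_observation_space,
)

# hypothetical 1080x1920 device; the real shape is read via `adb shell wm size`
action_space = get_action_space(device_shape=(1080, 1920))
obs_space = get_observation_space()

# the `mode="before"` validator coerces plain sequences into numpy int arrays
tap = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=np.array([540, 960], dtype=np.int64))
swipe = EnvAction(
    action_type=EnvActionType.USER_SWIPE,
    coord=np.array([540, 1600], dtype=np.int64),
    orient="up",
    dist="long",
)

# obs, reward, terminated, truncated, info = env.step(tap)  # dispatches to system_tap()
```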
+enc_n_points = 4 +dec_n_points = 4 +two_stage_type = "standard" +two_stage_bbox_embed_share = False +two_stage_class_embed_share = False +transformer_activation = "relu" +dec_pred_bbox_embed_share = True +dn_box_noise_scale = 1.0 +dn_label_noise_ratio = 0.5 +dn_label_coef = 1.0 +dn_bbox_coef = 1.0 +embed_init_tgt = True +dn_labelbook_size = 2000 +max_text_len = 256 +text_encoder_type = "bert-base-uncased" +use_text_enhancer = True +use_fusion_layer = True +use_checkpoint = True +use_transformer_ckpt = True +use_text_cross_attention = True +text_dropout = 0.0 +fusion_dropout = 0.0 +fusion_droppath = 0.1 +sub_sentence_present = True diff --git a/metagpt/environment/android/text_icon_localization.py b/metagpt/environment/android/text_icon_localization.py new file mode 100644 index 0000000000000000000000000000000000000000..e8886b540af596bc895f22a200a485e80b678450 --- /dev/null +++ b/metagpt/environment/android/text_icon_localization.py @@ -0,0 +1,368 @@ +# The code in this file was modified by MobileAgent +# https://github.com/X-PLUG/MobileAgent.git + +import math +from pathlib import Path + +import clip +import cv2 +import groundingdino.datasets.transforms as T +import numpy as np +import torch +from groundingdino.models import build_model +from groundingdino.util.slconfig import SLConfig +from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap +from PIL import Image + +################################## text_localization using ocr ####################### + + +def crop_image(img: any, position: any) -> any: + def distance(x1, y1, x2, y2): + return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2)) + + position = position.tolist() + for i in range(4): + for j in range(i + 1, 4): + if position[i][0] > position[j][0]: + tmp = position[j] + position[j] = position[i] + position[i] = tmp + if position[0][1] > position[1][1]: + tmp = position[0] + position[0] = position[1] + position[1] = tmp + + if position[2][1] > position[3][1]: + tmp = position[2] + position[2] = position[3] + position[3] = tmp + + x1, y1 = position[0][0], position[0][1] + x2, y2 = position[2][0], position[2][1] + x3, y3 = position[3][0], position[3][1] + x4, y4 = position[1][0], position[1][1] + + corners = np.zeros((4, 2), np.float32) + corners[0] = [x1, y1] + corners[1] = [x2, y2] + corners[2] = [x4, y4] + corners[3] = [x3, y3] + + img_width = distance((x1 + x4) / 2, (y1 + y4) / 2, (x2 + x3) / 2, (y2 + y3) / 2) + img_height = distance((x1 + x2) / 2, (y1 + y2) / 2, (x4 + x3) / 2, (y4 + y3) / 2) + + corners_trans = np.zeros((4, 2), np.float32) + corners_trans[0] = [0, 0] + corners_trans[1] = [img_width - 1, 0] + corners_trans[2] = [0, img_height - 1] + corners_trans[3] = [img_width - 1, img_height - 1] + + transform = cv2.getPerspectiveTransform(corners, corners_trans) + dst = cv2.warpPerspective(img, transform, (int(img_width), int(img_height))) + return dst + + +def calculate_size(box: any) -> any: + return (box[2] - box[0]) * (box[3] - box[1]) + + +def order_point(cooperation: any) -> any: + arr = np.array(cooperation).reshape([4, 2]) + sum_ = np.sum(arr, 0) + centroid = sum_ / arr.shape[0] + theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0]) + sort_points = arr[np.argsort(theta)] + sort_points = sort_points.reshape([4, -1]) + if sort_points[0][0] > centroid[0]: + sort_points = np.concatenate([sort_points[3:], sort_points[:3]]) + sort_points = sort_points.reshape([4, 2]).astype("float32") + return sort_points + + +def longest_common_substring_length(str1: str, str2: str) -> int: + m = 
len(str1) + n = len(str2) + dp = [[0] * (n + 1) for _ in range(m + 1)] + for i in range(1, m + 1): + for j in range(1, n + 1): + if str1[i - 1] == str2[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + 1 + else: + dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) + + return dp[m][n] + + +def ocr(image_path: Path, prompt: str, ocr_detection: any, ocr_recognition: any, x: int, y: int) -> any: + text_data = [] + coordinate = [] + image = Image.open(image_path) + iw, ih = image.size + + image_full = cv2.imread(str(image_path)) + det_result = ocr_detection(image_full) + det_result = det_result["polygons"] + for i in range(det_result.shape[0]): + pts = order_point(det_result[i]) + image_crop = crop_image(image_full, pts) + result = ocr_recognition(image_crop)["text"][0] + + if result == prompt: + box = [int(e) for e in list(pts.reshape(-1))] + box = [box[0], box[1], box[4], box[5]] + + if calculate_size(box) > 0.05 * iw * ih: + continue + + text_data.append( + [ + int(max(0, box[0] - 10) * x / iw), + int(max(0, box[1] - 10) * y / ih), + int(min(box[2] + 10, iw) * x / iw), + int(min(box[3] + 10, ih) * y / ih), + ] + ) + coordinate.append( + [ + int(max(0, box[0] - 300) * x / iw), + int(max(0, box[1] - 400) * y / ih), + int(min(box[2] + 300, iw) * x / iw), + int(min(box[3] + 400, ih) * y / ih), + ] + ) + + max_length = 0 + if len(text_data) == 0: + for i in range(det_result.shape[0]): + pts = order_point(det_result[i]) + image_crop = crop_image(image_full, pts) + result = ocr_recognition(image_crop)["text"][0] + + if len(result) < 0.3 * len(prompt): + continue + + if result in prompt: + now_length = len(result) + else: + now_length = longest_common_substring_length(result, prompt) + + if now_length > max_length: + max_length = now_length + box = [int(e) for e in list(pts.reshape(-1))] + box = [box[0], box[1], box[4], box[5]] + + text_data = [ + [ + int(max(0, box[0] - 10) * x / iw), + int(max(0, box[1] - 10) * y / ih), + int(min(box[2] + 10, iw) * x / iw), + int(min(box[3] + 10, ih) * y / ih), + ] + ] + coordinate = [ + [ + int(max(0, box[0] - 300) * x / iw), + int(max(0, box[1] - 400) * y / ih), + int(min(box[2] + 300, iw) * x / iw), + int(min(box[3] + 400, ih) * y / ih), + ] + ] + + if len(prompt) <= 10: + if max_length >= 0.8 * len(prompt): + return text_data, coordinate + else: + return [], [] + elif (len(prompt) > 10) and (len(prompt) <= 20): + if max_length >= 0.5 * len(prompt): + return text_data, coordinate + else: + return [], [] + else: + if max_length >= 0.4 * len(prompt): + return text_data, coordinate + else: + return [], [] + + else: + return text_data, coordinate + + +################################## icon_localization using clip ####################### + + +def calculate_iou(box1: list, box2: list) -> float: + x_a = max(box1[0], box2[0]) + y_a = max(box1[1], box2[1]) + x_b = min(box1[2], box2[2]) + y_b = min(box1[3], box2[3]) + + inter_area = max(0, x_b - x_a) * max(0, y_b - y_a) + box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) + box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union_area = box1_area + box2_area - inter_area + iou = inter_area / union_area + + return iou + + +def in_box(box: list, target: list) -> bool: + if (box[0] > target[0]) and (box[1] > target[1]) and (box[2] < target[2]) and (box[3] < target[3]): + return True + else: + return False + + +def crop_for_clip(image: any, box: any, i: int, temp_file: Path) -> bool: + image = Image.open(image) + w, h = image.size + bound = [0, 0, w, h] + if in_box(box, bound): + cropped_image = image.crop(box) + 
cropped_image.save(temp_file.joinpath(f"{i}.png")) + return True + else: + return False + + +def clip_for_icon(clip_model: any, clip_preprocess: any, images: any, prompt: str) -> any: + image_features = [] + for image_file in images: + image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device) + image_feature = clip_model.encode_image(image) + image_features.append(image_feature) + image_features = torch.cat(image_features) + + text = clip.tokenize([prompt]).to(next(clip_model.parameters()).device) + text_features = clip_model.encode_text(text) + + image_features /= image_features.norm(dim=-1, keepdim=True) + text_features /= text_features.norm(dim=-1, keepdim=True) + similarity = (100.0 * image_features @ text_features.T).softmax(dim=0).squeeze(0) + _, max_pos = torch.max(similarity, dim=0) + pos = max_pos.item() + + return pos + + +def transform_image(image_pil: any) -> any: + transform = T.Compose( + [ + T.RandomResize([800], max_size=1333), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] + ) + image, _ = transform(image_pil, None) # 3, h, w + return image + + +def load_model(model_checkpoint_path: Path, device: str) -> any: + model_config_path = "grounding_dino_config.py" + args = SLConfig.fromfile(model_config_path) + args.device = device + model = build_model(args) + checkpoint = torch.load(model_checkpoint_path, map_location="cpu") + load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) + print(load_res) + _ = model.eval() + return model + + +def get_grounding_output( + model: any, image: any, caption: str, box_threshold: any, text_threshold: any, with_logits: bool = True +) -> any: + caption = caption.lower() + caption = caption.strip() + if not caption.endswith("."): + caption = caption + "." 
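+    # note: GroundingDINO treats '.' as the separator between phrases in its text prompt,
+    # so the caption is lowercased and forced to end with a period before being fed to the model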
+ + with torch.no_grad(): + outputs = model(image[None], captions=[caption]) + logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) + boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) + logits.shape[0] + + logits_filt = logits.clone() + boxes_filt = boxes.clone() + filt_mask = logits_filt.max(dim=1)[0] > box_threshold + logits_filt = logits_filt[filt_mask] # num_filt, 256 + boxes_filt = boxes_filt[filt_mask] # num_filt, 4 + logits_filt.shape[0] + + tokenlizer = model.tokenizer + tokenized = tokenlizer(caption) + + pred_phrases = [] + scores = [] + for logit, box in zip(logits_filt, boxes_filt): + pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer) + if with_logits: + pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") + else: + pred_phrases.append(pred_phrase) + scores.append(logit.max().item()) + + return boxes_filt, torch.Tensor(scores), pred_phrases + + +def remove_boxes(boxes_filt: any, size: any, iou_threshold: float = 0.5) -> any: + boxes_to_remove = set() + + for i in range(len(boxes_filt)): + if calculate_size(boxes_filt[i]) > 0.05 * size[0] * size[1]: + boxes_to_remove.add(i) + for j in range(len(boxes_filt)): + if calculate_size(boxes_filt[j]) > 0.05 * size[0] * size[1]: + boxes_to_remove.add(j) + if i == j: + continue + if i in boxes_to_remove or j in boxes_to_remove: + continue + iou = calculate_iou(boxes_filt[i], boxes_filt[j]) + if iou >= iou_threshold: + boxes_to_remove.add(j) + + boxes_filt = [box for idx, box in enumerate(boxes_filt) if idx not in boxes_to_remove] + + return boxes_filt + + +def det( + input_image: any, + text_prompt: str, + groundingdino_model: any, + box_threshold: float = 0.05, + text_threshold: float = 0.5, +) -> any: + image = Image.open(input_image) + size = image.size + + image_pil = image.convert("RGB") + image = np.array(image_pil) + + transformed_image = transform_image(image_pil) + boxes_filt, scores, pred_phrases = get_grounding_output( + groundingdino_model, transformed_image, text_prompt, box_threshold, text_threshold + ) + + H, W = size[1], size[0] + for i in range(boxes_filt.size(0)): + boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) + boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 + boxes_filt[i][2:] += boxes_filt[i][:2] + + boxes_filt = boxes_filt.cpu().int().tolist() + filtered_boxes = remove_boxes(boxes_filt, size) # [:9] + coordinate = [] + image_data = [] + for box in filtered_boxes: + image_data.append( + [max(0, box[0] - 10), max(0, box[1] - 10), min(box[2] + 10, size[0]), min(box[3] + 10, size[1])] + ) + coordinate.append( + [max(0, box[0] - 25), max(0, box[1] - 25), min(box[2] + 25, size[0]), min(box[3] + 25, size[1])] + ) + + return image_data, coordinate diff --git a/metagpt/environment/api/__init__.py b/metagpt/environment/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/api/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/api/env_api.py b/metagpt/environment/api/env_api.py new file mode 100644 index 0000000000000000000000000000000000000000..924f6b1041eee9b87e0dd4b144e2adecc5626728 --- /dev/null +++ b/metagpt/environment/api/env_api.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the environment api store + +from typing import Any, Callable, Union + +from pydantic import BaseModel, Field + + +class EnvAPIAbstract(BaseModel): + """api/interface 
summary description""" + + api_name: str = Field(default="", description="the api function name or id") + args: set = Field(default_factory=set, description="the api function `args` params") + kwargs: dict = Field(default=dict(), description="the api function `kwargs` params") + + +class EnvAPIRegistry(BaseModel): + """the registry to store environment w&r api/interface""" + + registry: dict[str, Callable] = Field(default=dict(), exclude=True) + + def get(self, api_name: str): + if api_name not in self.registry: + raise KeyError(f"api_name: {api_name} not found") + return self.registry.get(api_name) + + def __getitem__(self, api_name: str) -> Callable: + return self.get(api_name) + + def __setitem__(self, api_name: str, func: Callable): + self.registry[api_name] = func + + def __len__(self): + return len(self.registry) + + def get_apis(self, as_str=True) -> dict[str, dict[str, Union[dict, Any, str]]]: + """return func schema without func instance""" + apis = dict() + for func_name, func_schema in self.registry.items(): + new_func_schema = dict() + for key, value in func_schema.items(): + if key == "func": + continue + new_func_schema[key] = str(value) if as_str else value + apis[func_name] = new_func_schema + return apis + + +class WriteAPIRegistry(EnvAPIRegistry): + """just an explicit class name""" + + pass + + +class ReadAPIRegistry(EnvAPIRegistry): + """just an explicit class name""" + + pass diff --git a/metagpt/environment/base_env.py b/metagpt/environment/base_env.py new file mode 100644 index 0000000000000000000000000000000000000000..03a4760c916b4c14fdf7f7e2bb70b9ed60c3dab4 --- /dev/null +++ b/metagpt/environment/base_env.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : base env of executing environment + +import asyncio +from abc import abstractmethod +from enum import Enum +from typing import Any, Dict, Iterable, Optional, Set, Union + +from gymnasium import spaces +from gymnasium.core import ActType, ObsType +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator + +from metagpt.base import BaseEnvironment, BaseRole +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams +from metagpt.context import Context +from metagpt.environment.api.env_api import ( + EnvAPIAbstract, + ReadAPIRegistry, + WriteAPIRegistry, +) +from metagpt.logs import logger +from metagpt.memory import Memory +from metagpt.schema import Message +from metagpt.utils.common import get_function_schema, is_coroutine_func, is_send_to +from metagpt.utils.git_repository import GitRepository + + +class EnvType(Enum): + ANDROID = "Android" + GYM = "Gym" + WEREWOLF = "Werewolf" + MINECRAFT = "Minecraft" + STANFORDTOWN = "StanfordTown" + + +env_write_api_registry = WriteAPIRegistry() +env_read_api_registry = ReadAPIRegistry() + + +def mark_as_readable(func): + """mark function as a readable one in ExtEnv, it observes something from ExtEnv""" + env_read_api_registry[func.__name__] = get_function_schema(func) + return func + + +def mark_as_writeable(func): + """mark function as a writeable one in ExtEnv, it does something to ExtEnv""" + env_write_api_registry[func.__name__] = get_function_schema(func) + return func + + +class ExtEnv(BaseEnvironment, BaseModel): + """External Env to integrate actual game environment""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_space: spaces.Space[ActType] = Field(default_factory=spaces.Space, exclude=True) + observation_space: spaces.Space[ObsType] =
Field(default_factory=spaces.Space, exclude=True) + + def _check_api_exist(self, rw_api: Optional[str] = None): + if not rw_api: + raise ValueError(f"{rw_api} does not exist") + + def get_all_available_apis(self, mode: str = "read") -> list[Any]: + """get available read/write api definitions""" + assert mode in ["read", "write"] + if mode == "read": + return env_read_api_registry.get_apis() + else: + return env_write_api_registry.get_apis() + + async def read_from_api(self, env_action: Union[str, EnvAPIAbstract]): + """get observation from particular api of ExtEnv""" + if isinstance(env_action, str): + env_read_api = env_read_api_registry.get(api_name=env_action)["func"] + self._check_api_exist(env_read_api) + if is_coroutine_func(env_read_api): + res = await env_read_api(self) + else: + res = env_read_api(self) + elif isinstance(env_action, EnvAPIAbstract): + env_read_api = env_read_api_registry.get(api_name=env_action.api_name)["func"] + self._check_api_exist(env_read_api) + if is_coroutine_func(env_read_api): + res = await env_read_api(self, *env_action.args, **env_action.kwargs) + else: + res = env_read_api(self, *env_action.args, **env_action.kwargs) + return res + + async def write_thru_api(self, env_action: Union[str, Message, EnvAPIAbstract, list[EnvAPIAbstract]]): + """execute through particular api of ExtEnv""" + res = None + if isinstance(env_action, Message): + self.publish_message(env_action) + elif isinstance(env_action, EnvAPIAbstract): + env_write_api = env_write_api_registry.get(env_action.api_name)["func"] + self._check_api_exist(env_write_api) + if is_coroutine_func(env_write_api): + res = await env_write_api(self, *env_action.args, **env_action.kwargs) + else: + res = env_write_api(self, *env_action.args, **env_action.kwargs) + + return res + + @abstractmethod + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Implement this to get init observation""" + + @abstractmethod + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + """Implement this if you want to get partial observation from the env""" + + @abstractmethod + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + """Implement this to feed an action and then get new observation from the env""" + + +class Environment(ExtEnv): + """Environment, hosting a batch of roles; roles can publish messages to the environment and can be observed by other roles + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + desc: str = Field(default="") # description of the environment + roles: dict[str, SerializeAsAny[BaseRole]] = Field(default_factory=dict, validate_default=True) + member_addrs: Dict[BaseRole, Set] = Field(default_factory=dict, exclude=True) + history: Memory = Field(default_factory=Memory) # For debug + context: Context = Field(default_factory=Context, exclude=True) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + + @model_validator(mode="after") + def init_roles(self): + self.add_roles(self.roles.values()) + return self + + def add_role(self, role: BaseRole): + """Add a role to the current environment + """ +
self.roles[role.name] = role + role.set_env(self) + role.context = self.context + + def add_roles(self, roles: Iterable[BaseRole]): + """Add a batch of roles to the current environment + """ + for role in roles: + self.roles[role.name] = role + + for role in roles: # setup system message with roles + role.context = self.context + role.set_env(self) + + def publish_message(self, message: Message, peekable: bool = True) -> bool: + """ + Distribute the message to the recipients. + In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned + in RFC 113 for the entire system, the routing information in the Message is only responsible for + specifying the message recipient, without concern for where the message recipient is located. How to + route the message to the message recipient is a problem addressed by the transport framework designed + in RFC 113. + """ + logger.debug(f"publish_message: {message.dump()}") + found = False + # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 + for role, addrs in self.member_addrs.items(): + if is_send_to(message, addrs): + role.put_message(message) + found = True + if not found: + logger.warning(f"Message has no recipients: {message.dump()}") + self.history.add(message) # For debug + + return True + + async def run(self, k=1): + """Process all Role runs at once + """ + for _ in range(k): + futures = [] + for role in self.roles.values(): + if role.is_idle: + continue + future = role.run() + futures.append(future) + + if futures: + await asyncio.gather(*futures) + logger.debug(f"is idle: {self.is_idle}") + + def get_roles(self) -> dict[str, BaseRole]: + """Get all roles in the environment + """ + return self.roles + + def get_role(self, name: str) -> BaseRole: + """Get the specified role in the environment + """ + return self.roles.get(name, None) + + def role_names(self) -> list[str]: + return [i.name for i in self.roles.values()] + + @property + def is_idle(self): + """If true, all actions have been executed.""" + for r in self.roles.values(): + if not r.is_idle: + return False + return True + + def get_addresses(self, obj): + """Get the addresses of the object.""" + return self.member_addrs.get(obj, {}) + + def set_addresses(self, obj, addresses): + """Set the addresses of the object""" + self.member_addrs[obj] = addresses + + def archive(self, auto_archive=True): + if auto_archive and self.context.kwargs.get("project_path"): + git_repo = GitRepository(self.context.kwargs.project_path) + git_repo.archive() diff --git a/metagpt/environment/mgx/__init__.py b/metagpt/environment/mgx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/mgx/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/mgx/mgx_env.py b/metagpt/environment/mgx/mgx_env.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fc0df9f46e8f5fa97b7911bd28adf835fc5b08 --- /dev/null +++ b/metagpt/environment/mgx/mgx_env.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from metagpt.const import AGENT, IMAGES, MESSAGE_ROUTE_TO_ALL, TEAMLEADER_NAME +from metagpt.environment.base_env import Environment +from metagpt.logs import get_human_input +from metagpt.roles import Role +from metagpt.schema import Message, SerializationMixin +from metagpt.utils.common import extract_and_encode_images + +
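The read/write registries from `env_api.py` and the `ExtEnv` plumbing above come together roughly like this; a hedged sketch, where `env` stands for any `ExtEnv` subclass (the API names match the Android environment earlier in this diff) and all argument values are illustrative:

```python
from metagpt.environment.api.env_api import EnvAPIAbstract


async def drive(env) -> None:
    # list the schemas of everything registered via @mark_as_readable / @mark_as_writeable
    readable_apis = env.get_all_available_apis(mode="read")
    writeable_apis = env.get_all_available_apis(mode="write")
    print(list(readable_apis), list(writeable_apis))

    # act through a writeable API: dispatched by function name to system_tap(env, x=..., y=...)
    await env.write_thru_api(EnvAPIAbstract(api_name="system_tap", kwargs={"x": 100, "y": 200}))

    # observe through a readable API
    xml_path = await env.read_from_api(
        EnvAPIAbstract(api_name="get_xml", kwargs={"xml_name": "ui_dump", "local_save_dir": "/tmp"})
    )
    print(xml_path)
```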
+class MGXEnv(Environment, SerializationMixin): + """MGX Environment""" + + direct_chat_roles: set[str] = set() # record direct chat: @role_name + + is_public_chat: bool = True + + def _publish_message(self, message: Message, peekable: bool = True) -> bool: + if self.is_public_chat: + message.send_to.add(MESSAGE_ROUTE_TO_ALL) + message = self.move_message_info_to_content(message) + return super().publish_message(message, peekable) + + def publish_message(self, message: Message, user_defined_recipient: str = "", publicer: str = "") -> bool: + """let the team leader take over message publishing""" + message = self.attach_images(message) # for multi-modal message + + tl = self.get_role(TEAMLEADER_NAME) # TeamLeader's name is Mike + + if user_defined_recipient: + # human user's direct chat message to a certain role + for role_name in message.send_to: + if self.get_role(role_name).is_idle: + # User starts a new direct chat with a certain role, expecting a direct chat response from the role; Other roles including TL should not be involved. + # If the role is not idle, it means the user helps the role with its current work, in this case, we handle the role's response message as usual. + self.direct_chat_roles.add(role_name) + + self._publish_message(message) + # # bypass team leader, team leader only needs to know but not to react (commented out because TL doesn't understand the message well in actual experiments) + # tl.rc.memory.add(self.move_message_info_to_content(message)) + + elif message.sent_from in self.direct_chat_roles: + # if chat is not public, direct chat response from a certain role to human user, team leader and other roles in the env should not be involved, no need to publish + self.direct_chat_roles.remove(message.sent_from) + if self.is_public_chat: + self._publish_message(message) + + elif publicer == tl.profile: + if message.send_to == {"no one"}: + # skip the dummy message from team leader + return True + # message processed by team leader can be published now + self._publish_message(message) + + else: + # every regular message goes through team leader + message.send_to.add(tl.name) + self._publish_message(message) + + self.history.add(message) + + return True + + async def ask_human(self, question: str, sent_from: Role = None) -> str: + # NOTE: Can be overwritten in remote setting + rsp = await get_human_input(question) + return "Human response: " + rsp + + async def reply_to_human(self, content: str, sent_from: Role = None) -> str: + # NOTE: Can be overwritten in remote setting + return "SUCCESS, human has received your reply. Refrain from resending duplicate messages. If you no longer need to take action, use the command ‘end’ to stop." + + def move_message_info_to_content(self, message: Message) -> Message: + """Two things here: + 1. Convert role, since role field must be reserved for LLM API, and is limited to, for example, one of ["user", "assistant", "system"] + 2. Add sender and recipient info to content, making TL aware, since LLM API only takes content as input + """ + converted_msg = message.model_copy(deep=True) + if converted_msg.role not in ["system", "user", "assistant"]: + converted_msg.role = "assistant" + sent_from = converted_msg.metadata[AGENT] if AGENT in converted_msg.metadata else converted_msg.sent_from + # When displaying send_to, change it to those who need to react and exclude those who only need to be aware, e.g.: + # send_to={} -> Mike; send_to={Alice} -> Alice; send_to={Alice, } -> Alice. 
+ if converted_msg.send_to == {MESSAGE_ROUTE_TO_ALL}: + send_to = TEAMLEADER_NAME + else: + send_to = ", ".join({role for role in converted_msg.send_to if role != MESSAGE_ROUTE_TO_ALL}) + converted_msg.content = f"[Message] from {sent_from or 'User'} to {send_to}: {converted_msg.content}" + return converted_msg + + def attach_images(self, message: Message) -> Message: + if message.role == "user": + images = extract_and_encode_images(message.content) + if images: + message.add_metadata(IMAGES, images) + return message + + def __repr__(self): + return "MGXEnv()" diff --git a/metagpt/environment/minecraft/__init__.py b/metagpt/environment/minecraft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/minecraft/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/minecraft/const.py b/metagpt/environment/minecraft/const.py new file mode 100644 index 0000000000000000000000000000000000000000..8ac15decc874ac6d18797c371265611eb0435f7e --- /dev/null +++ b/metagpt/environment/minecraft/const.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.const import METAGPT_ROOT + +# For Minecraft Game Agent +MC_CKPT_DIR = METAGPT_ROOT / "data/minecraft/ckpt" +MC_LOG_DIR = METAGPT_ROOT / "logs" +MC_DEFAULT_WARMUP = { + "context": 15, + "biome": 10, + "time": 15, + "nearby_blocks": 0, + "other_blocks": 10, + "nearby_entities": 5, + "health": 15, + "hunger": 15, + "position": 0, + "equipment": 0, + "inventory": 0, + "optional_inventory_items": 7, + "chests": 0, + "completed_tasks": 0, + "failed_tasks": 0, +} +MC_CURRICULUM_OB = [ + "context", + "biome", + "time", + "nearby_blocks", + "other_blocks", + "nearby_entities", + "health", + "hunger", + "position", + "equipment", + "inventory", + "chests", + "completed_tasks", + "failed_tasks", +] +MC_CORE_INVENTORY_ITEMS = ( + r".*_log|.*_planks|stick|crafting_table|furnace" + r"|cobblestone|dirt|coal|.*_pickaxe|.*_sword|.*_axe" +) # curriculum_agent: only show these items in inventory before optional_inventory_items reached in warm up diff --git a/metagpt/environment/minecraft/minecraft_env.py b/metagpt/environment/minecraft/minecraft_env.py new file mode 100644 index 0000000000000000000000000000000000000000..9c42949a6fb1addca2e7cebcc1fc514677bc88f0 --- /dev/null +++ b/metagpt/environment/minecraft/minecraft_env.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Minecraft Env +# refs to `voyager voyager.py` + +import json +import re +import time +from typing import Any, Iterable + +from llama_index.vector_stores.chroma import ChromaVectorStore +from pydantic import ConfigDict, Field + +from metagpt.config2 import Config +from metagpt.environment.base_env import Environment +from metagpt.environment.minecraft.const import MC_CKPT_DIR +from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv +from metagpt.logs import logger +from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file + + +class MinecraftEnv(MinecraftExtEnv, Environment): + """MinecraftEnv, including shared memory of cache and information between roles""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + event: dict[str, Any] = Field(default_factory=dict) + current_task: str = Field(default="Mine 1 wood log") + task_execution_time: float = Field(default=0.0) + context: str = Field(default="You can mine one
of oak, birch, spruce, jungle, acacia, dark oak, or mangrove logs.") + code: str = Field(default="") + program_code: str = Field(default="") # write in skill/code/*.js + program_name: str = Field(default="") + critique: str = Field(default="") + skills: dict = Field(default_factory=dict) # for skills.json + retrieve_skills: list[str] = Field(default_factory=list) + event_summary: str = Field(default="") + + qa_cache: dict[str, str] = Field(default_factory=dict) + completed_tasks: list[str] = Field(default_factory=list) # Critique things + failed_tasks: list[str] = Field(default_factory=list) + + skill_desp: str = Field(default="") + + chest_memory: dict[str, Any] = Field(default_factory=dict) # eg: {'(1344, 64, 1381)': 'Unknown'} + chest_observation: str = Field(default="") # eg: "Chests: None\n\n" + + runtime_status: bool = False # equal to action execution status: success or failed + + vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore) + + qa_cache_questions_vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore) + + @property + def progress(self): + # return len(self.completed_tasks) + 10 # Test only + return len(self.completed_tasks) + + @property + def programs(self): + programs = "" + if self.code == "": + return programs # TODO: maybe fix 10054 now, a better way is isolating env.step() like voyager + for skill_name, entry in self.skills.items(): + programs += f"{entry['code']}\n\n" + for primitives in load_mc_skills_code(): # TODO add skills_dir + programs += f"{primitives}\n\n" + return programs + + def set_mc_port(self, mc_port): + super().set_mc_port(mc_port) + self.set_mc_resume() + + def set_mc_resume(self): + self.qa_cache_questions_vectordb = ChromaVectorStore( + collection_name="qa_cache_questions_vectordb", + persist_dir=f"{MC_CKPT_DIR}/curriculum/vectordb", + ) + + self.vectordb = ChromaVectorStore( + collection_name="skill_vectordb", + persist_dir=f"{MC_CKPT_DIR}/skill/vectordb", + ) + + if Config.default().resume: + logger.info(f"Loading Action Developer from {MC_CKPT_DIR}/action") + self.chest_memory = read_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json") + + logger.info(f"Loading Curriculum Agent from {MC_CKPT_DIR}/curriculum") + self.completed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json") + self.failed_tasks = read_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json") + + logger.info(f"Loading Skill Manager from {MC_CKPT_DIR}/skill\033[0m") + self.skills = read_json_file(f"{MC_CKPT_DIR}/skill/skills.json") + + logger.info(f"Loading Qa Cache from {MC_CKPT_DIR}/curriculum\033[0m") + self.qa_cache = read_json_file(f"{MC_CKPT_DIR}/curriculum/qa_cache.json") + + if self.vectordb._collection.count() == 0: + logger.info(self.vectordb._collection.count()) + # Set vdvs for skills & qa_cache + skill_desps = [skill["description"] for program_name, skill in self.skills.items()] + program_names = [program_name for program_name, skill in self.skills.items()] + metadatas = [{"name": program_name} for program_name in program_names] + # add vectordb from file + self.vectordb.add_texts( + texts=skill_desps, + ids=program_names, + metadatas=metadatas, + ) + self.vectordb.persist() + + logger.info(self.qa_cache_questions_vectordb._collection.count()) + if self.qa_cache_questions_vectordb._collection.count() == 0: + questions = [question for question, answer in self.qa_cache.items()] + + self.qa_cache_questions_vectordb.add_texts(texts=questions) + + self.qa_cache_questions_vectordb.persist() + + logger.info( + 
f"INIT_CHECK: There are {self.vectordb._collection.count()} skills in vectordb and {len(self.skills)} skills in skills.json." + ) + # Check if Skill Manager's vectordb right using + assert self.vectordb._collection.count() == len(self.skills), ( + f"Skill Manager's vectordb is not synced with skills.json.\n" + f"There are {self.vectordb._collection.count()} skills in vectordb but {len(self.skills)} skills in skills.json.\n" + f"Did you set resume=False when initializing the manager?\n" + f"You may need to manually delete the vectordb directory for running from scratch." + ) + + logger.info( + f"INIT_CHECK: There are {self.qa_cache_questions_vectordb._collection.count()} qa_cache in vectordb and {len(self.qa_cache)} questions in qa_cache.json." + ) + assert self.qa_cache_questions_vectordb._collection.count() == len(self.qa_cache), ( + f"Curriculum Agent's qa cache question vectordb is not synced with qa_cache.json.\n" + f"There are {self.qa_cache_questions_vectordb._collection.count()} questions in vectordb " + f"but {len(self.qa_cache)} questions in qa_cache.json.\n" + f"Did you set resume=False when initializing the agent?\n" + f"You may need to manually delete the qa cache question vectordb directory for running from scratch.\n" + ) + + def register_roles(self, roles: Iterable["Minecraft"]): + for role in roles: + role.set_memory(self) + + def update_event(self, event: dict): + if self.event == event: + return + self.event = event + self.update_chest_memory(event) + self.update_chest_observation() + # self.event_summary = self.summarize_chatlog(event) + + def update_task(self, task: str): + self.current_task = task + + def update_context(self, context: str): + self.context = context + + def update_program_code(self, program_code: str): + self.program_code = program_code + + def update_code(self, code: str): + self.code = code # action_developer.gen_action_code to HERE + + def update_program_name(self, program_name: str): + self.program_name = program_name + + def update_critique(self, critique: str): + self.critique = critique # critic_agent.check_task_success to HERE + + def append_skill(self, skill: dict): + self.skills[self.program_name] = skill # skill_manager.retrieve_skills to HERE + + def update_retrieve_skills(self, retrieve_skills: list): + self.retrieve_skills = retrieve_skills + + def update_skill_desp(self, skill_desp: str): + self.skill_desp = skill_desp + + async def update_qa_cache(self, qa_cache: dict): + self.qa_cache = qa_cache + + def update_chest_memory(self, events: dict): + """ + Input: events: Dict + Result: self.chest_memory update & save to json + """ + nearbyChests = events[-1][1]["nearbyChests"] + for position, chest in nearbyChests.items(): + if position in self.chest_memory: + if isinstance(chest, dict): + self.chest_memory[position] = chest + if chest == "Invalid": + logger.info(f"Action Developer removing chest {position}: {chest}") + self.chest_memory.pop(position) + else: + if chest != "Invalid": + logger.info(f"Action Developer saving chest {position}: {chest}") + self.chest_memory[position] = chest + + write_json_file(f"{MC_CKPT_DIR}/action/chest_memory.json", self.chest_memory) + + def update_chest_observation(self): + """ + update chest_memory to chest_observation. 
+ Refer to @ https://github.com/MineDojo/Voyager/blob/main/voyager/agents/action.py + """ + + chests = [] + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, dict) and len(chest) > 0: + chests.append(f"{chest_position}: {chest}") + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, dict) and len(chest) == 0: + chests.append(f"{chest_position}: Empty") + for chest_position, chest in self.chest_memory.items(): + if isinstance(chest, str): + assert chest == "Unknown" + chests.append(f"{chest_position}: Unknown items inside") + assert len(chests) == len(self.chest_memory) + if chests: + chests = "\n".join(chests) + self.chest_observation = f"Chests:\n{chests}\n\n" + else: + self.chest_observation = "Chests: None\n\n" + + def summarize_chatlog(self, events): + def filter_item(message: str): + craft_pattern = r"I cannot make \w+ because I need: (.*)" + craft_pattern2 = r"I cannot make \w+ because there is no crafting table nearby" + mine_pattern = r"I need at least a (.*) to mine \w+!" + if re.match(craft_pattern, message): + self.event_summary = re.match(craft_pattern, message).groups()[0] + elif re.match(craft_pattern2, message): + self.event_summary = "a nearby crafting table" + elif re.match(mine_pattern, message): + self.event_summary = re.match(mine_pattern, message).groups()[0] + else: + self.event_summary = "" + return self.event_summary + + chatlog = set() + for event_type, event in events: + if event_type == "onChat": + item = filter_item(event["onChat"]) + if item: + chatlog.add(item) + self.event_summary = "I also need " + ", ".join(chatlog) + "." if chatlog else "" + + def reset_block_info(self): + # revert all the placing event in the last step + pass + + def update_exploration_progress(self, success: bool): + """ + Split task into completed_tasks or failed_tasks + Args: info = { + "task": self.task, + "success": success, + "conversations": self.conversations, + } + """ + self.runtime_status = success + task = self.current_task + if task.startswith("Deposit useless items into the chest at"): + return + if success: + logger.info(f"Completed task {task}.") + self.completed_tasks.append(task) + else: + logger.info(f"Failed to complete task {task}. Skipping to next task.") + self.failed_tasks.append(task) + # when not success, below to update event! 
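+        # (the rollback below calls `givePlacedItemBack`, so blocks placed during the failed step are
+        # returned to the inventory, then the refreshed inventory/voxels overwrite the stale event record)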
+ # revert all the placing event in the last step + blocks = [] + positions = [] + for event_type, event in self.event: + if event_type == "onSave" and event["onSave"].endswith("_placed"): + block = event["onSave"].split("_placed")[0] + position = event["status"]["position"] + blocks.append(block) + positions.append(position) + new_events = self._step( + f"await givePlacedItemBack(bot, {json.dumps(blocks)}, {json.dumps(positions)})", + programs=self.programs, + ) + self.event[-1][1]["inventory"] = new_events[-1][1]["inventory"] + self.event[-1][1]["voxels"] = new_events[-1][1]["voxels"] + + self.save_sorted_tasks() + + def save_sorted_tasks(self): + updated_completed_tasks = [] + # record repeated failed tasks + updated_failed_tasks = self.failed_tasks + # dedup but keep order + for task in self.completed_tasks: + if task not in updated_completed_tasks: + updated_completed_tasks.append(task) + + # remove completed tasks from failed tasks + for task in updated_completed_tasks: + while task in updated_failed_tasks: + updated_failed_tasks.remove(task) + + self.completed_tasks = updated_completed_tasks + self.failed_tasks = updated_failed_tasks + + # dump to json + write_json_file(f"{MC_CKPT_DIR}/curriculum/completed_tasks.json", self.completed_tasks) + write_json_file(f"{MC_CKPT_DIR}/curriculum/failed_tasks.json", self.failed_tasks) + + async def on_event_retrieve(self, *args): + """ + Retrieve Minecraft events. + + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. + """ + try: + self._reset( + options={ + "mode": "soft", + "wait_ticks": 20, + } + ) + # difficulty = "easy" if len(self.completed_tasks) > 15 else "peaceful" + difficulty = "peaceful" + + events = self._step("bot.chat(`/time set ${getNextTime()}`);\n" + f"bot.chat('/difficulty {difficulty}');") + self.update_event(events) + return events + except Exception as e: + time.sleep(3) # wait for mineflayer to exit + # reset bot status here + events = self._reset( + options={ + "mode": "hard", + "wait_ticks": 20, + "inventory": self.event[-1][1]["inventory"], + "equipment": self.event[-1][1]["status"]["equipment"], + "position": self.event[-1][1]["status"]["position"], + } + ) + self.update_event(events) + logger.error(f"Failed to retrieve Minecraft events: {str(e)}") + return events + + async def on_event_execute(self, *args): + """ + Execute Minecraft events. + + This function is used to obtain events from the Minecraft environment. Check the implementation in + the 'voyager/env/bridge.py step()' function to capture events generated within the game. + + Returns: + list: A list of Minecraft events. + + Raises: + Exception: If there is an issue retrieving events. 
+ """ + try: + events = self._step( + code=self.code, + programs=self.programs, + ) + self.update_event(events) + return events + except Exception as e: + time.sleep(3) # wait for mineflayer to exit + # reset bot status here + events = self._reset( + options={ + "mode": "hard", + "wait_ticks": 20, + "inventory": self.event[-1][1]["inventory"], + "equipment": self.event[-1][1]["status"]["equipment"], + "position": self.event[-1][1]["status"]["position"], + } + ) + self.update_event(events) + logger.error(f"Failed to execute Minecraft events: {str(e)}") + return events diff --git a/metagpt/environment/minecraft/minecraft_ext_env.py b/metagpt/environment/minecraft/minecraft_ext_env.py new file mode 100644 index 0000000000000000000000000000000000000000..fb43e97c9ec1c276a0cd39d92072ead23fd89cc0 --- /dev/null +++ b/metagpt/environment/minecraft/minecraft_ext_env.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The Minecraft external environment to integrate with Minecraft game +# refs to `voyager bridge.py` + +import json +import time +from typing import Any, Optional + +import requests +from pydantic import ConfigDict, Field, model_validator + +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvObsParams +from metagpt.environment.base_env import ExtEnv, mark_as_writeable +from metagpt.environment.minecraft.const import ( + MC_CKPT_DIR, + MC_CORE_INVENTORY_ITEMS, + MC_CURRICULUM_OB, + MC_DEFAULT_WARMUP, + METAGPT_ROOT, +) +from metagpt.environment.minecraft.process_monitor import SubprocessMonitor +from metagpt.logs import logger + + +class MinecraftExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + mc_port: Optional[int] = Field(default=None) + server_host: str = Field(default="http://127.0.0.1") + server_port: str = Field(default=3000) + request_timeout: int = Field(default=600) + + mineflayer: Optional[SubprocessMonitor] = Field(default=None, validate_default=True) + + has_reset: bool = Field(default=False) + reset_options: Optional[dict] = Field(default=None) + connected: bool = Field(default=False) + server_paused: bool = Field(default=False) + warm_up: dict = Field(default=dict()) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + pass + + def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + pass + + @property + def server(self) -> str: + return f"{self.server_host}:{self.server_port}" + + @model_validator(mode="after") + def _post_init_ext_env(self): + if not self.mineflayer: + self.mineflayer = SubprocessMonitor( + commands=[ + "node", + METAGPT_ROOT.joinpath("metagpt", "environment", "minecraft", "mineflayer", "index.js"), + str(self.server_port), + ], + name="mineflayer", + ready_match=r"Server started on port (\d+)", + ) + if not self.warm_up: + warm_up = MC_DEFAULT_WARMUP + if "optional_inventory_items" in warm_up: + assert MC_CORE_INVENTORY_ITEMS is not None + # self.core_inv_items_regex = re.compile(MC_CORE_INVENTORY_ITEMS) + self.warm_up["optional_inventory_items"] = warm_up["optional_inventory_items"] + else: + self.warm_up["optional_inventory_items"] = 0 + for key in MC_CURRICULUM_OB: + self.warm_up[key] = warm_up.get(key, MC_DEFAULT_WARMUP[key]) + self.warm_up["nearby_blocks"] = 0 + self.warm_up["inventory"] = 0 + self.warm_up["completed_tasks"] = 0 + self.warm_up["failed_tasks"] = 0 
+ + # init ckpt sub-folders + MC_CKPT_DIR.joinpath("curriculum/vectordb").mkdir(parents=True, exist_ok=True) + MC_CKPT_DIR.joinpath("action").mkdir(exist_ok=True) + MC_CKPT_DIR.joinpath("skill/code").mkdir(parents=True, exist_ok=True) + MC_CKPT_DIR.joinpath("skill/description").mkdir(exist_ok=True) + MC_CKPT_DIR.joinpath("skill/vectordb").mkdir(exist_ok=True) + + def set_mc_port(self, mc_port: int): + self.mc_port = mc_port + + @mark_as_writeable + def close(self) -> bool: + self.unpause() + if self.connected: + res = requests.post(f"{self.server}/stop") + if res.status_code == 200: + self.connected = False + self.mineflayer.stop() + return not self.connected + + @mark_as_writeable + def check_process(self) -> dict: + retry = 0 + while not self.mineflayer.is_running: + logger.info("Mineflayer process has exited, restarting") + self.mineflayer.run() + if not self.mineflayer.is_running: + if retry > 3: + logger.error("Mineflayer process failed to start") + raise RuntimeError("Mineflayer process failed to start") + else: + retry += 1 + continue + logger.info(self.mineflayer.ready_line) + res = requests.post( + f"{self.server}/start", + json=self.reset_options, + timeout=self.request_timeout, + ) + if res.status_code != 200: + self.mineflayer.stop() + logger.error(f"Minecraft server replied with code {res.status_code}") + raise RuntimeError(f"Minecraft server replied with code {res.status_code}") + return res.json() + + @mark_as_writeable + def _reset(self, *, seed=None, options=None) -> dict: + if options is None: + options = {} + if options.get("inventory", {}) and options.get("mode", "hard") != "hard": + logger.error("inventory can only be set when mode is hard") + raise ValueError("inventory can only be set when mode is hard") + + self.reset_options = { + "port": self.mc_port, + "reset": options.get("mode", "hard"), + "inventory": options.get("inventory", {}), + "equipment": options.get("equipment", []), + "spread": options.get("spread", False), + "waitTicks": options.get("wait_ticks", 5), + "position": options.get("position", None), + } + + self.unpause() + self.mineflayer.stop() + time.sleep(1) # wait for mineflayer to exit + + returned_data = self.check_process() + self.has_reset = True + self.connected = True + # All the reset in step will be soft + self.reset_options["reset"] = "soft" + self.pause() + return json.loads(returned_data) + + @mark_as_writeable + def _step(self, code: str, programs: str = "") -> dict: + if not self.has_reset: + raise RuntimeError("Environment has not been reset yet") + self.check_process() + self.unpause() + data = { + "code": code, + "programs": programs, + } + res = requests.post(f"{self.server}/step", json=data, timeout=self.request_timeout) + if res.status_code != 200: + raise RuntimeError("Failed to step Minecraft server") + returned_data = res.json() + self.pause() + return json.loads(returned_data) + + @mark_as_writeable + def pause(self) -> bool: + if self.mineflayer.is_running and not self.server_paused: + res = requests.post(f"{self.server}/pause") + if res.status_code == 200: + self.server_paused = True + return self.server_paused + + @mark_as_writeable + def unpause(self) -> bool: + if self.mineflayer.is_running and self.server_paused: + res = requests.post(f"{self.server}/pause") + if res.status_code == 200: + self.server_paused = False + else: + logger.info(f"mineflayer pause result: {res.json()}") + return self.server_paused diff --git a/metagpt/environment/minecraft/mineflayer/.gitignore b/metagpt/environment/minecraft/mineflayer/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0fd46841034a3366d29b838bd070304a78e31337 --- /dev/null +++
b/metagpt/environment/minecraft/mineflayer/.gitignore @@ -0,0 +1 @@ +!/lib \ No newline at end of file diff --git a/metagpt/environment/minecraft/mineflayer/.prettierignore b/metagpt/environment/minecraft/mineflayer/.prettierignore new file mode 100644 index 0000000000000000000000000000000000000000..1b07c39e9b4cf3756f6e3ea23f7ab6ea22a87f15 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/.prettierignore @@ -0,0 +1,3 @@ +# Ignore artifacts: +build +coverage \ No newline at end of file diff --git a/metagpt/environment/minecraft/mineflayer/.prettierrc.json b/metagpt/environment/minecraft/mineflayer/.prettierrc.json new file mode 100644 index 0000000000000000000000000000000000000000..0a02bcefdab2e1654666e9d5effedc14501e98db --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/.prettierrc.json @@ -0,0 +1,3 @@ +{ + "tabWidth": 4 +} diff --git a/metagpt/environment/minecraft/mineflayer/index.js b/metagpt/environment/minecraft/mineflayer/index.js new file mode 100644 index 0000000000000000000000000000000000000000..7fb0a8787f87596b9be31818d022c8f0eb0d5951 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/index.js @@ -0,0 +1,425 @@ +const fs = require("fs"); +const express = require("express"); +const bodyParser = require("body-parser"); +const mineflayer = require("mineflayer"); + +const skills = require("./lib/skillLoader"); +const { initCounter, getNextTime } = require("./lib/utils"); +const obs = require("./lib/observation/base"); +const OnChat = require("./lib/observation/onChat"); +const OnError = require("./lib/observation/onError"); +const { Voxels, BlockRecords } = require("./lib/observation/voxels"); +const Status = require("./lib/observation/status"); +const Inventory = require("./lib/observation/inventory"); +const OnSave = require("./lib/observation/onSave"); +const Chests = require("./lib/observation/chests"); +const { plugin: tool } = require("mineflayer-tool"); + +let bot = null; + +const app = express(); + +app.use(bodyParser.json({ limit: "50mb" })); +app.use(bodyParser.urlencoded({ limit: "50mb", extended: false })); + +app.post("/start", (req, res) => { + if (bot) onDisconnect("Restarting bot"); + bot = null; + console.log(req.body); + bot = mineflayer.createBot({ + host: "localhost", // minecraft server ip + port: req.body.port, // minecraft server port + username: "bot", + disableChatSigning: true, + checkTimeoutInterval: 60 * 60 * 1000, + }); + bot.once("error", onConnectionFailed); + + // Event subscriptions + bot.waitTicks = req.body.waitTicks; + bot.globalTickCounter = 0; + bot.stuckTickCounter = 0; + bot.stuckPosList = []; + bot.iron_pickaxe = false; + + bot.on("kicked", onDisconnect); + + // mounting will cause physicsTick to stop + bot.on("mount", () => { + bot.dismount(); + }); + + bot.once("spawn", async () => { + bot.removeListener("error", onConnectionFailed); + let itemTicks = 1; + if (req.body.reset === "hard") { + bot.chat("/clear @s"); + bot.chat("/kill @s"); + const inventory = req.body.inventory ? req.body.inventory : {}; + const equipment = req.body.equipment + ? 
req.body.equipment + : [null, null, null, null, null, null]; + for (let key in inventory) { + bot.chat(`/give @s minecraft:${key} ${inventory[key]}`); + itemTicks += 1; + } + const equipmentNames = [ + "armor.head", + "armor.chest", + "armor.legs", + "armor.feet", + "weapon.mainhand", + "weapon.offhand", + ]; + for (let i = 0; i < 6; i++) { + if (i === 4) continue; + if (equipment[i]) { + bot.chat( + `/item replace entity @s ${equipmentNames[i]} with minecraft:${equipment[i]}` + ); + itemTicks += 1; + } + } + } + + if (req.body.position) { + bot.chat( + `/tp @s ${req.body.position.x} ${req.body.position.y} ${req.body.position.z}` + ); + } + + // if iron_pickaxe is in bot's inventory + if ( + bot.inventory.items().find((item) => item.name === "iron_pickaxe") + ) { + bot.iron_pickaxe = true; + } + + const { pathfinder } = require("mineflayer-pathfinder"); + const tool = require("mineflayer-tool").plugin; + const collectBlock = require("mineflayer-collectblock").plugin; + const pvp = require("mineflayer-pvp").plugin; + const minecraftHawkEye = require("minecrafthawkeye"); + bot.loadPlugin(pathfinder); + bot.loadPlugin(tool); + bot.loadPlugin(collectBlock); + bot.loadPlugin(pvp); + bot.loadPlugin(minecraftHawkEye); + + // bot.collectBlock.movements.digCost = 0; + // bot.collectBlock.movements.placeCost = 0; + + obs.inject(bot, [ + OnChat, + OnError, + Voxels, + Status, + Inventory, + OnSave, + Chests, + BlockRecords, + ]); + skills.inject(bot); + + if (req.body.spread) { + bot.chat(`/spreadplayers ~ ~ 0 300 under 80 false @s`); + await bot.waitForTicks(bot.waitTicks); + } + + await bot.waitForTicks(bot.waitTicks * itemTicks); + res.json(bot.observe()); + + initCounter(bot); + bot.chat("/gamerule keepInventory true"); + bot.chat("/gamerule doDaylightCycle false"); + }); + + function onConnectionFailed(e) { + console.log(e); + bot = null; + res.status(400).json({ error: e }); + } + function onDisconnect(message) { + if (bot.viewer) { + bot.viewer.close(); + } + bot.end(); + console.log(message); + bot = null; + } +}); + +app.post("/step", async (req, res) => { + // import useful package + let response_sent = false; + function otherError(err) { + console.log("Uncaught Error"); + bot.emit("error", handleError(err)); + bot.waitForTicks(bot.waitTicks).then(() => { + if (!response_sent) { + response_sent = true; + res.json(bot.observe()); + } + }); + } + + process.on("uncaughtException", otherError); + + const mcData = require("minecraft-data")(bot.version); + mcData.itemsByName["leather_cap"] = mcData.itemsByName["leather_helmet"]; + mcData.itemsByName["leather_tunic"] = + mcData.itemsByName["leather_chestplate"]; + mcData.itemsByName["leather_pants"] = + mcData.itemsByName["leather_leggings"]; + mcData.itemsByName["leather_boots"] = mcData.itemsByName["leather_boots"]; + mcData.itemsByName["lapis_lazuli_ore"] = mcData.itemsByName["lapis_ore"]; + mcData.blocksByName["lapis_lazuli_ore"] = mcData.blocksByName["lapis_ore"]; + const { + Movements, + goals: { + Goal, + GoalBlock, + GoalNear, + GoalXZ, + GoalNearXZ, + GoalY, + GoalGetToBlock, + GoalLookAtBlock, + GoalBreakBlock, + GoalCompositeAny, + GoalCompositeAll, + GoalInvert, + GoalFollow, + GoalPlaceBlock, + }, + pathfinder, + Move, + ComputedPath, + PartiallyComputedPath, + XZCoordinates, + XYZCoordinates, + SafeBlock, + GoalPlaceBlockOptions, + } = require("mineflayer-pathfinder"); + const { Vec3 } = require("vec3"); + + // Set up pathfinder + const movements = new Movements(bot, mcData); + bot.pathfinder.setMovements(movements); + + 
bot.globalTickCounter = 0; + bot.stuckTickCounter = 0; + bot.stuckPosList = []; + + function onTick() { + bot.globalTickCounter++; + if (bot.pathfinder.isMoving()) { + bot.stuckTickCounter++; + if (bot.stuckTickCounter >= 100) { + onStuck(1.5); + bot.stuckTickCounter = 0; + } + } + } + + bot.on("physicTick", onTick); + + // initialize fail count + let _craftItemFailCount = 0; + let _killMobFailCount = 0; + let _mineBlockFailCount = 0; + let _placeItemFailCount = 0; + let _smeltItemFailCount = 0; + + // Retrieve the code and programs from the post body + const code = req.body.code; + const programs = req.body.programs; + bot.cumulativeObs = []; + await bot.waitForTicks(bot.waitTicks); + const r = await evaluateCode(code, programs); + process.off("uncaughtException", otherError); + if (r !== "success") { + bot.emit("error", handleError(r)); + } + await returnItems(); + // wait for last message + await bot.waitForTicks(bot.waitTicks); + if (!response_sent) { + response_sent = true; + res.json(bot.observe()); + } + bot.removeListener("physicTick", onTick); + + async function evaluateCode(code, programs) { + // Echo the code produced for players to see it. Don't echo when the bot code is already producing dialog or it will double echo + try { + await eval("(async () => {" + programs + "\n" + code + "})()"); + return "success"; + } catch (err) { + return err; + } + } + + function onStuck(posThreshold) { + const currentPos = bot.entity.position; + bot.stuckPosList.push(currentPos); + + // Check if the list is full + if (bot.stuckPosList.length === 5) { + const oldestPos = bot.stuckPosList[0]; + const posDifference = currentPos.distanceTo(oldestPos); + + if (posDifference < posThreshold) { + teleportBot(); // execute the function + } + + // Remove the oldest position from the list + bot.stuckPosList.shift(); + } + } + + function teleportBot() { + const blocks = bot.findBlocks({ + matching: (block) => { + return block.type === 0; + }, + maxDistance: 1, + count: 27, + }); + + if (blocks && blocks.length > 0) { + // console.log(blocks.length); + const randomIndex = Math.floor(Math.random() * blocks.length); + const block = blocks[randomIndex]; + bot.chat(`/tp @s ${block.x} ${block.y} ${block.z}`); + } else { + bot.chat("/tp @s ~ ~1.25 ~"); + } + } + + function returnItems() { + bot.chat("/gamerule doTileDrops false"); + const crafting_table = bot.findBlock({ + matching: mcData.blocksByName.crafting_table.id, + maxDistance: 128, + }); + if (crafting_table) { + bot.chat( + `/setblock ${crafting_table.position.x} ${crafting_table.position.y} ${crafting_table.position.z} air destroy` + ); + bot.chat("/give @s crafting_table"); + } + const furnace = bot.findBlock({ + matching: mcData.blocksByName.furnace.id, + maxDistance: 128, + }); + if (furnace) { + bot.chat( + `/setblock ${furnace.position.x} ${furnace.position.y} ${furnace.position.z} air destroy` + ); + bot.chat("/give @s furnace"); + } + if (bot.inventoryUsed() >= 32) { + // if chest is not in bot's inventory + if (!bot.inventory.items().find((item) => item.name === "chest")) { + bot.chat("/give @s chest"); + } + } + // if iron_pickaxe not in bot's inventory and bot.iron_pickaxe + if ( + bot.iron_pickaxe && + !bot.inventory.items().find((item) => item.name === "iron_pickaxe") + ) { + bot.chat("/give @s iron_pickaxe"); + } + bot.chat("/gamerule doTileDrops true"); + } + + function handleError(err) { + let stack = err.stack; + if (!stack) { + return err; + } + console.log(stack); + const final_line = stack.split("\n")[1]; + const regex = /:(\d+):\d+\)/; + + const programs_length = 
programs.split("\n").length; + let match_line = null; + for (const line of stack.split("\n")) { + const match = regex.exec(line); + if (match) { + const line_num = parseInt(match[1]); + if (line_num >= programs_length) { + match_line = line_num - programs_length; + break; + } + } + } + if (!match_line) { + return err.message; + } + let f_line = final_line.match( + /\((?<file>.*):(?<line>\d+):(?<pos>\d+)\)/ + ); + if (f_line && f_line.groups && fs.existsSync(f_line.groups.file)) { + const { file, line, pos } = f_line.groups; + const f = fs.readFileSync(file, "utf8").split("\n"); + // let filename = file.match(/(?<=node_modules\\)(.*)/)[1]; + let source = file + `:${line}\n${f[line - 1].trim()}\n `; + + const code_source = + "at " + + code.split("\n")[match_line - 1].trim() + + " in your code"; + return source + err.message + "\n" + code_source; + } else if ( + f_line && + f_line.groups && + f_line.groups.file.includes("<anonymous>") + ) { + const { file, line, pos } = f_line.groups; + let source = + "Your code" + + `:${match_line}\n${code.split("\n")[match_line - 1].trim()}\n `; + let code_source = ""; + if (line < programs_length) { + source = + "In your program code: " + + programs.split("\n")[line - 1].trim() + + "\n"; + code_source = `at line ${match_line}:${code + .split("\n") + [match_line - 1].trim()} in your code`; + } + return source + err.message + "\n" + code_source; + } + return err.message; + } +}); + +app.post("/stop", (req, res) => { + bot.end(); + res.json({ + message: "Bot stopped", + }); +}); + +app.post("/pause", (req, res) => { + if (!bot) { + res.status(400).json({ error: "Bot not spawned" }); + return; + } + bot.chat("/pause"); + bot.waitForTicks(bot.waitTicks).then(() => { + res.json({ message: "Success" }); + }); +}); + +// Server listens on PORT (default 3000) + +const DEFAULT_PORT = 3000; +const PORT = process.argv[2] || DEFAULT_PORT; +app.listen(PORT, () => { + console.log(`Server started on port ${PORT}`); +}); diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/base.js b/metagpt/environment/minecraft/mineflayer/lib/observation/base.js new file mode 100644 index 0000000000000000000000000000000000000000..b661a24b57c1a61b9ff09b9254ce72002212f5d3 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/base.js @@ -0,0 +1,45 @@ +class Observation { + constructor(bot) { + if (new.target === Observation) { + throw new TypeError( + "Cannot instantiate abstract class Observation" + ); + } + + this.bot = bot; + this.name = "Observation"; + } + + observe() { + throw new TypeError("Method 'observe()' must be implemented."); + } + + reset() {} +} + +function inject(bot, obs_list) { + bot.obsList = []; + bot.cumulativeObs = []; + bot.eventMemory = {}; + obs_list.forEach((obs) => { + bot.obsList.push(new obs(bot)); + }); + bot.event = function (event_name) { + let result = {}; + bot.obsList.forEach((obs) => { + if (obs.name.startsWith("on") && obs.name !== event_name) { + return; + } + result[obs.name] = obs.observe(); + }); + bot.cumulativeObs.push([event_name, result]); + }; + bot.observe = function () { + bot.event("observe"); + const result = bot.cumulativeObs; + bot.cumulativeObs = []; + return JSON.stringify(result); + }; +} + +module.exports = { Observation, inject };
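The `Observation` base class and `inject` helper above form the extension point for every observation provider the server exposes: a subclass sets a unique `name`, implements `observe()`, and, for event-style providers whose names start with `on`, calls `bot.event(name)` when the underlying event fires so the result lands in `bot.cumulativeObs`. A minimal sketch of a custom provider written against that contract follows; the `OnHealthChange` class and its use of mineflayer's `health` event are illustrative additions, not part of this patch.

```js
const { Observation } = require("./base");

// Hypothetical provider: reports health/food changes between steps.
class OnHealthChange extends Observation {
    constructor(bot) {
        super(bot);
        // Names starting with "on" are only sampled when bot.event() is
        // fired with this exact name (see inject() above).
        this.name = "onHealthChange";
        this.obs = null;
        bot.on("health", () => {
            this.obs = { health: bot.health, food: bot.food };
            this.bot.event(this.name);
        });
    }

    observe() {
        // Consume the buffered value so each change is reported once.
        const result = this.obs;
        this.obs = null;
        return result;
    }
}

module.exports = OnHealthChange;
```

Registering such a provider would just mean adding the class to the `obs.inject(bot, [...])` list in `index.js`.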
diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/chests.js b/metagpt/environment/minecraft/mineflayer/lib/observation/chests.js new file mode 100644 index 0000000000000000000000000000000000000000..842bd171d579d77a328615787e0309d0b40eb1fe --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/chests.js @@ -0,0 +1,31 @@ +const { Observation } = require("./base"); + +class Chests extends Observation { + constructor(bot) { + super(bot); + this.name = "nearbyChests"; + this.chestsItems = {}; + bot.on("closeChest", (chestItems, position) => { + this.chestsItems[position] = chestItems; + }); + bot.on("removeChest", (chestPosition) => { + this.chestsItems[chestPosition] = "Invalid"; + }); + } + + observe() { + const chests = this.bot.findBlocks({ + matching: this.bot.registry.blocksByName.chest.id, + maxDistance: 16, + count: 999, + }); + chests.forEach((chest) => { + if (!this.chestsItems.hasOwnProperty(chest)) { + this.chestsItems[chest] = "Unknown"; + } + }); + return this.chestsItems; + } +} + +module.exports = Chests; diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/inventory.js b/metagpt/environment/minecraft/mineflayer/lib/observation/inventory.js new file mode 100644 index 0000000000000000000000000000000000000000..0645d1bfa0803e155e3987d3d526f2b43d8f5936 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/inventory.js @@ -0,0 +1,39 @@ +const { Observation } = require("./base"); + +class Inventory extends Observation { + constructor(bot) { + super(bot); + this.name = "inventory"; + } + + observe() { + return listItems(this.bot); + } +} + +function listItems(bot) { + const items = getInventoryItems(bot); + return items.reduce(itemToDict, {}); +} + +function getInventoryItems(bot) { + const inventory = bot.currentWindow || bot.inventory; + return inventory.items(); +} + +function itemToDict(acc, cur) { + if (cur.name && cur.count) { + //if both name and count property are defined + if (acc[cur.name]) { + //if the item is already in the dict + acc[cur.name] += cur.count; + } else { + //if the item is not in the dict + acc[cur.name] = cur.count; + } + } + return acc; +} + +//export modules +module.exports = Inventory; diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/onChat.js b/metagpt/environment/minecraft/mineflayer/lib/observation/onChat.js new file mode 100644 index 0000000000000000000000000000000000000000..54b411e2ad903ca54e4cdbf2b9d8732df82a55f8 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/onChat.js @@ -0,0 +1,26 @@ +const Observation = require("./base.js").Observation; + +class onChat extends Observation { + constructor(bot) { + super(bot); + this.name = "onChat"; + this.obs = ""; + bot.on("chatEvent", (username, message) => { + // Save entity status to local variable + if (message.startsWith("/")) { + return; + } + + this.obs += message; + this.bot.event(this.name); + }); + } + + observe() { + const result = this.obs; + this.obs = ""; + return result; + } +} + +module.exports = onChat; diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/onError.js b/metagpt/environment/minecraft/mineflayer/lib/observation/onError.js new file mode 100644 index 0000000000000000000000000000000000000000..ac8fed9e51937c33105068e2c45800fe1c022c89 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/onError.js @@ -0,0 +1,22 @@ +const Observation = require("./base.js").Observation; + +class onError extends Observation { + constructor(bot) { + super(bot); + this.name = "onError"; + this.obs = null; + bot.on("error", (err) => { + // Save entity status to local variable + this.obs = err; + this.bot.event(this.name); + }); + } + + observe() { + const result = this.obs; + this.obs = null; + return result; + 
} +} + +module.exports = onError; diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/onSave.js b/metagpt/environment/minecraft/mineflayer/lib/observation/onSave.js new file mode 100644 index 0000000000000000000000000000000000000000..e5983590ff7b5829b7a9679fee7a11f04f3cc5a7 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/onSave.js @@ -0,0 +1,22 @@ +const Observation = require("./base.js").Observation; + +class onSave extends Observation { + constructor(bot) { + super(bot); + this.name = "onSave"; + this.obs = null; + bot.on("save", (eventName) => { + // Save entity status to local variable + this.obs = eventName; + this.bot.event(this.name); + }); + } + + observe() { + const result = this.obs; + this.obs = null; + return result; + } +} + +module.exports = onSave; diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/status.js b/metagpt/environment/minecraft/mineflayer/lib/observation/status.js new file mode 100644 index 0000000000000000000000000000000000000000..b031fbcf20d307bdd7895de1b29e589b10d33b40 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/status.js @@ -0,0 +1,103 @@ +const Observation = require("./base.js").Observation; + +class Status extends Observation { + constructor(bot) { + super(bot); + this.name = "status"; + } + + observe() { + return { + health: this.bot.health, + food: this.bot.food, + saturation: this.bot.foodSaturation, + oxygen: this.bot.oxygenLevel, + position: this.bot.entity.position, + velocity: this.bot.entity.velocity, + yaw: this.bot.entity.yaw, + pitch: this.bot.entity.pitch, + onGround: this.bot.entity.onGround, + equipment: this.getEquipment(), + name: this.bot.entity.username, + timeSinceOnGround: this.bot.entity.timeSinceOnGround, + isInWater: this.bot.entity.isInWater, + isInLava: this.bot.entity.isInLava, + isInWeb: this.bot.entity.isInWeb, + isCollidedHorizontally: this.bot.entity.isCollidedHorizontally, + isCollidedVertically: this.bot.entity.isCollidedVertically, + biome: this.bot.blockAt(this.bot.entity.position) + ? 
this.bot.blockAt(this.bot.entity.position).biome.name + : "None", + entities: this.getEntities(), + timeOfDay: this.getTime(), + inventoryUsed: this.bot.inventoryUsed(), + elapsedTime: this.bot.globalTickCounter, + }; + } + + itemToObs(item) { + if (!item) return null; + return item.name; + } + + getTime() { + const timeOfDay = this.bot.time.timeOfDay; + let time = ""; + if (timeOfDay < 1000) { + time = "sunrise"; + } else if (timeOfDay < 6000) { + time = "day"; + } else if (timeOfDay < 12000) { + time = "noon"; + } else if (timeOfDay < 13000) { + time = "sunset"; + } else if (timeOfDay < 18000) { + time = "night"; + } else if (timeOfDay < 22000) { + time = "midnight"; + } else { + time = "sunrise"; + } + return time; + } + + // For each item in equipment, if it exists, return the name of the item + // otherwise return null + getEquipment() { + const slots = this.bot.inventory.slots; + const mainHand = this.bot.heldItem; + return slots + .slice(5, 9) + .concat(mainHand, slots[45]) + .map(this.itemToObs); + } + + getEntities() { + const entities = this.bot.entities; + if (!entities) return {}; + // keep all monsters in one list, keep other mobs in another list + const mobs = {}; + for (const id in entities) { + const entity = entities[id]; + if (!entity.displayName) continue; + if (entity.name === "player" || entity.name === "item") continue; + if (entity.position.distanceTo(this.bot.entity.position) < 32) { + if (!mobs[entity.name]) { + mobs[entity.name] = entity.position.distanceTo( + this.bot.entity.position + ); + } else if ( + mobs[entity.name] > + entity.position.distanceTo(this.bot.entity.position) + ) { + mobs[entity.name] = entity.position.distanceTo( + this.bot.entity.position + ); + } + } + } + return mobs; + } +} + +module.exports = Status; diff --git a/metagpt/environment/minecraft/mineflayer/lib/observation/voxels.js b/metagpt/environment/minecraft/mineflayer/lib/observation/voxels.js new file mode 100644 index 0000000000000000000000000000000000000000..ecb0c14b70d4b48034fd4af452bb7572073db878 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/observation/voxels.js @@ -0,0 +1,67 @@ +// Blocks = require("./blocks") +const { Observation } = require("./base"); + +class Voxels extends Observation { + constructor(bot) { + super(bot); + this.name = "voxels"; + } + + observe() { + return Array.from(getSurroundingBlocks(this.bot, 8, 2, 8)); + } +} + +class BlockRecords extends Observation { + constructor(bot) { + super(bot); + this.name = "blockRecords"; + this.records = new Set(); + this.tick = 0; + bot.on("physicsTick", () => { + this.tick++; + if (this.tick >= 100) { + const items = getInventoryItems(this.bot); + getSurroundingBlocks(this.bot, 8, 2, 8).forEach((block) => { + if (!items.has(block)) this.records.add(block); + }); + this.tick = 0; + } + }); + } + + observe() { + return Array.from(this.records); + } + + reset() { + this.records = new Set(); + } +} + +function getSurroundingBlocks(bot, x_distance, y_distance, z_distance) { + const surroundingBlocks = new Set(); + + for (let x = -x_distance; x <= x_distance; x++) { + for (let y = -y_distance; y <= y_distance; y++) { + for (let z = -z_distance; z <= z_distance; z++) { + const block = bot.blockAt(bot.entity.position.offset(x, y, z)); + if (block && block.type !== 0) { + surroundingBlocks.add(block.name); + } + } + } + } + // console.log(surroundingBlocks); + return surroundingBlocks; +} + +function getInventoryItems(bot) { + const items = new Set(); + bot.inventory.items().forEach((item) => { + if (item) 
items.add(item.name); + }); + return items; +} + +module.exports = { Voxels, BlockRecords }; diff --git a/metagpt/environment/minecraft/mineflayer/lib/skillLoader.js b/metagpt/environment/minecraft/mineflayer/lib/skillLoader.js new file mode 100644 index 0000000000000000000000000000000000000000..d78cf782093b213b35d3d4c719490e3a86a7878b --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/skillLoader.js @@ -0,0 +1,79 @@ +function inject(bot) { + bot._sleep = bot.sleep; + bot.sleep = async (bedBlock) => { + await bot.waitForTicks(20); + await bot._sleep(bedBlock); + await bot.waitForTicks(135); + }; + + bot._fish = bot.fish; + bot.fish = async () => { + if (bot.heldItem?.name !== "fishing_rod") { + bot.chat("I'm not holding a fishing rod!"); + return; + } + let timeout = null; + await Promise.race([ + bot._fish(), + new Promise( + (resolve, reject) => + (timeout = setTimeout(() => { + bot.activateItem(); + reject( + new Error( + "Fishing timeout, make sure you get to and look at a water block!" + ) + ); + }, 60000)) + ), + ]); + clearTimeout(timeout); + await bot.waitForTicks(20); + }; + + bot._consume = bot.consume; + bot.consume = async () => { + // action_count.activateItem++; + await bot._consume(); + await bot.waitForTicks(20); + }; + + bot._useOn = bot.useOn; + bot.useOn = async (entity) => { + if (entity.position.distanceTo(bot.entity.position) > 6) { + bot.chat("Please go to a place near the entity first!"); + return; + } + await bot._useOn(entity); + await bot.waitForTicks(20); + }; + + bot._activateBlock = bot.activateBlock; + bot.activateBlock = async (block) => { + if (block.position.distanceTo(bot.entity.position) > 6) { + bot.chat("Please go to a place near the block first!"); + return; + } + // action_count.activateBlock++; + await bot._activateBlock(block); + }; + + bot._chat = bot.chat; + bot.chat = (message) => { + // action_count.chat++; + bot.emit("chatEvent", "bot", message); + bot._chat(message); + }; + + bot.inventoryUsed = () => { + return bot.inventory.slots.slice(9, 45).filter((item) => item !== null) .length; + }; + + bot.save = function (eventName) { + bot.emit("save", eventName); + }; +} + +// export all control_primitives +module.exports = { inject }; diff --git a/metagpt/environment/minecraft/mineflayer/lib/utils.js b/metagpt/environment/minecraft/mineflayer/lib/utils.js new file mode 100644 index 0000000000000000000000000000000000000000..68af3079602ab8d88059a9c8c4055140dda32f1d --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/lib/utils.js @@ -0,0 +1,31 @@ +let gameTimeCounter = 0; +let gameTimeList = []; +const initCounter = (bot) => { + gameTimeList = []; + for (let i = 0; i < 13000; i += 1000) { + gameTimeList.push(i); + } + for (let i = 13000; i < 24000; i += 2000) { + gameTimeList.push(i); + } + const timeOfDay = bot.time.timeOfDay; + for (let i = 0; i < gameTimeList.length; i++) { + if (gameTimeList[i] > timeOfDay) { + gameTimeCounter = i - 1; + break; + } + } +}; + +const getNextTime = () => { + gameTimeCounter++; + if (gameTimeCounter >= gameTimeList.length) { + gameTimeCounter = 0; + } + return gameTimeList[gameTimeCounter]; +}; + +module.exports = { + initCounter, + getNextTime, +}; diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/.gitignore b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0578fdca3844dbbdfdabfa5c927de3a1144d7d5a --- /dev/null +++ 
b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/.gitignore @@ -0,0 +1,107 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# TypeScript v1 declaration files +typings/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variables file +.env +.env.test + +# parcel-bundler cache (https://parceljs.org/) +.cache + +# Next.js build output +.next + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and *not* Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file +.tern-port + +lib/ +package-lock.json diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/LICENSE b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f2896b56e45adc3d54cd6f98764d4b155b571217 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 TheDudeFromCI + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/README.md b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/README.md new file mode 100644 index 0000000000000000000000000000000000000000..555acb761e51efff08f372cb2525c8da2a230e57 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/README.md @@ -0,0 +1,89 @@ +

# mineflayer-collectblock + +*A small utility plugin for allowing users to collect blocks using a higher level API.*
+ +--- +## This is a modified version to better support Voyager + +## Showcase + +You can see a video of the plugin in action, [here.](https://youtu.be/5T_rcCnNnf4) +The source code of the bot in the video can be seen in the examples folder, [here.](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/examples/collector.js) + +### Description + +This plugin is a wrapper for mineflayer that allows for easier API usage when collecting blocks or item drops. This plugin is designed to reduce some of the boilerplate code based around the act of pathfinding to a block _(handled by_ ***mineflayer-pathfinder***_)_, selecting the best tool to mine that block _(handled by_ ***mineflayer-tool***_)_, actually mining it, then moving to collect the item drops from that block. This plugin allows for all of that basic concept to be wrapped up into a single API function. + +In addition to the usage above, some additional quality of life features are available in this plugin. These include the ability to automatically deposit items into a chest when the bot's inventory is full, collecting new tools from a chest if the bot doesn't currently have a required tool _(also handled by_ ***mineflayer-tool***_)_, and allowing for queueing of multiple blocks or item drops to the collection task, so they can be processed later. A sketch of these options appears after this README. + +### Getting Started + +This plugin is built using Node and can be installed using: +```bash +npm install --save mineflayer-collectblock +``` + +### Simple Bot + +A minimal bot that repeatedly finds and collects nearby grass blocks as soon as it spawns: + +```js +// Create your bot +const mineflayer = require("mineflayer") +const bot = mineflayer.createBot({ + host: 'localhost', + username: 'Player', +}) +let mcData + +// Load collect block +bot.loadPlugin(require('mineflayer-collectblock').plugin) + +async function collectGrass() { + // Find a nearby grass block + const grass = bot.findBlock({ + matching: mcData.blocksByName.grass_block.id, + maxDistance: 64 + }) + + if (grass) { + // If we found one, collect it. + try { + await bot.collectBlock.collect(grass) + collectGrass() // Collect another grass block + } catch (err) { + console.log(err) // Handle errors, if any + } + } +} + +// On spawn, start collecting all nearby grass +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) + collectGrass() +}) +``` + +### Documentation + +[API](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/docs/api.md) + +[Examples](https://github.com/TheDudeFromCI/mineflayer-collectblock/tree/master/examples) + +### License + +This project uses the [MIT](https://github.com/TheDudeFromCI/mineflayer-collectblock/blob/master/LICENSE) license. + +### Contributions + +This project is accepting PRs and Issues. See something you think can be improved? Go for it! Any and all help is highly appreciated! + +For larger changes, it is recommended to discuss these changes in the issues tab before writing any code. It's also preferred to make many smaller PRs than one large one, where applicable. 
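As a complement to the Simple Bot above, here is a sketch of the quality-of-life options mentioned in the Description: queueing a batch of targets in one call and letting the plugin deposit into chests when the inventory fills up. It assumes the same `bot` and `mcData` setup as the README example; the block type, counts, and distances are illustrative choices, not values mandated by the plugin.

```js
async function collectLogs() {
    // Queue up to 16 oak logs in a single collect call
    const positions = bot.findBlocks({
        matching: mcData.blocksByName.oak_log.id,
        maxDistance: 64,
        count: 16
    })
    const targets = positions.map(pos => bot.blockAt(pos))

    try {
        await bot.collectBlock.collect(targets, {
            // Deposit into nearby chests when the inventory becomes full
            chestLocations: bot.findBlocks({
                matching: mcData.blocksByName.chest.id,
                maxDistance: 16,
                count: 8
            }),
            // Skip unreachable targets instead of throwing
            ignoreNoPath: true
        })
    } catch (err) {
        console.log(err) // Handle errors, if any
    }
}
```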
diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/_config.yml b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/_config.yml new file mode 100644 index 0000000000000000000000000000000000000000..c4192631f25b34d77a7f159aa0da0e3ae99c4ef4 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/_config.yml @@ -0,0 +1 @@ +theme: jekyll-theme-cayman \ No newline at end of file diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/docs/api.md b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/docs/api.md new file mode 100644 index 0000000000000000000000000000000000000000..66d8a3ecc4a441ff3e989412fc1520e5ffdc1e17 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/docs/api.md @@ -0,0 +1,52 @@ +# API + +Welcome to the *mineflayer-collectblock* API documentation page. + +## Table of Contents + +- [1. Summary](#1-summary) +- [Properties](#properties) + - [`bot.collectblock.movements: Movements`](#botcollectblockmovements-movements) +- [Functions](#functions) + - [collect](#collect) + - [Options:](#options) + +## 1. Summary + +The collect block plugin is a utility plugin that can be used to help make collecting blocks and item drops very easy, using only a single API call. No need to worry about pathfinding to the block, selecting the right tool, or moving to pick up the item drop after mining. + +## Properties + +### `bot.collectblock.movements: Movements` + +The movements object used by the pathfinder plugin to define the movement configuration. This object is passed to the pathfinder plugin when any API from this plugin is called in order to control how pathfinding should work when collecting the given blocks or item. + +If set to null, the pathfinder plugin's movements are not updated. + +Defaults to a new movements object instance. + +## Functions + +### collect + +Usage: `bot.collectblock.collect(target: Collectable | Collectable[], options?: CollectOptions, cb: (err?: Error) => void): void` + +Causes the bot to collect the given block, item drop, or list of those. If the target is a block, the bot will move to the block, mine it, and pick up the item drop. If the target is an item drop, the bot will move to the item drop and pick it up. If the target is a list of collectables, the bot will move from target to target in order of closest to furthest and collect each target in turn. + +#### Options: + + * `append: boolean` + + If true, the target(s) will be appended to the existing target list instead of starting a new task. Defaults to false. + + * `ignoreNoPath: boolean` + + If true, errors will not be thrown when a path to the target block cannot be found. The bot will attempt to choose the best available position it can find, instead. Errors are still thrown if the bot cannot interact with the block from its final location. Defaults to false. + + * `chestLocations: Vec3[]` + + Gets the list of chest locations to use when storing items after the bot's inventory becomes full. If undefined, it defaults to the chest location list on the bot.collectBlock plugin. + + * `itemFilter: ItemFilter` + + When transferring items to a chest, this filter is used to determine what items are allowed to be moved, and what items aren't allowed to be moved. Defaults to the item filter specified on the bot.collectBlock plugin. + + * `count: number` + + The total number of targets to collect before the task stops, even if more targets are queued. Defaults to `Infinity`. \ No newline at end of file
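To make the option semantics above concrete, here is a short sketch combining `append` and `count`, assuming a spawned bot with the pathfinder, tool, and collectblock plugins already loaded; the ore target is illustrative. The callback form mirrors the Usage signature; omitting the callback returns a promise instead.

```js
const target = bot.findBlock({
    matching: bot.registry.blocksByName.iron_ore.id,
    maxDistance: 32
})

if (target) {
    // Add to the current task instead of cancelling it, and stop after
    // two successful collections even if more targets are queued.
    bot.collectBlock.collect(target, { append: true, count: 2 }, (err) => {
        if (err) console.log(err.message)
    })
}
```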
diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/collector.js b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/collector.js new file mode 100644 index 0000000000000000000000000000000000000000..b9bb8faf9e73762856eed9d41f0da027728e82b3 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/collector.js @@ -0,0 +1,70 @@ +/** + * This bot example shows how to direct a bot to collect a specific block type + * or a group of nearby blocks of that type. + */ + +const mineflayer = require('mineflayer') +const collectBlock = require('mineflayer-collectblock').plugin + +if (process.argv.length < 4 || process.argv.length > 6) { + console.log('Usage : node collector.js <host> <port> [<name>] [<password>]') + process.exit(1) +} + +const bot = mineflayer.createBot({ + host: process.argv[2], + port: process.argv[3], + username: process.argv[4] || 'collector', + password: process.argv[5] +}) + +bot.loadPlugin(collectBlock) + +let mcData +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) +}) + +bot.on('chat', async (username, message) => { + const args = message.split(' ') + if (args[0] !== 'collect') return + + let count = 1 + if (args.length === 3) count = parseInt(args[1]) + + let type = args[1] + if (args.length === 3) type = args[2] + + const blockType = mcData.blocksByName[type] + if (!blockType) { + return + } + + const blocks = bot.findBlocks({ + matching: blockType.id, + maxDistance: 64, + count: count + }) + + if (blocks.length === 0) { + bot.chat("I don't see that block nearby.") + return + } + + const targets = [] + for (let i = 0; i < Math.min(blocks.length, count); i++) { + targets.push(bot.blockAt(blocks[i])) + } + + bot.chat(`Found ${targets.length} ${type}(s)`) + + try { + await bot.collectBlock.collect(targets) + // All blocks have been collected. + bot.chat('Done') + } catch (err) { + // An error occurred, report it. + bot.chat(err.message) + console.log(err) + } +}) diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/oreMiner.js b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/oreMiner.js new file mode 100644 index 0000000000000000000000000000000000000000..6accac88fd3c3e29ac431c497d618d2f27f23c67 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/oreMiner.js @@ -0,0 +1,59 @@ +/** + * This bot example shows how to collect a vein of ores quickly after only finding a single block. + * This makes it easy to collect a vein of ores or mine a tree without looking for every block in the + * area. 
+ */ + +const mineflayer = require('mineflayer') +const collectBlock = require('mineflayer-collectblock').plugin + +if (process.argv.length < 4 || process.argv.length > 6) { + console.log('Usage : node oreMiner.js <host> <port> [<name>] [<password>]') + process.exit(1) +} + +const bot = mineflayer.createBot({ + host: process.argv[2], + port: process.argv[3], + username: process.argv[4] || 'oreMiner', + password: process.argv[5] +}) + +bot.loadPlugin(collectBlock) + +let mcData +bot.once('spawn', () => { + mcData = require('minecraft-data')(bot.version) +}) + +bot.on('chat', async (username, message) => { + const args = message.split(' ') + if (args[0] !== 'collect') return + + const blockType = mcData.blocksByName[args[1]] + if (!blockType) { + bot.chat(`I don't know any blocks named ${args[1]}.`) + return + } + + const block = bot.findBlock({ + matching: blockType.id, + maxDistance: 64 + }) + + if (!block) { + bot.chat("I don't see that block nearby.") + return + } + + const targets = bot.collectBlock.findFromVein(block) + try { + await bot.collectBlock.collect(targets) + // All blocks have been collected. + bot.chat('Done') + } catch (err) { + // An error occurred, report it. + bot.chat(err.message) + console.log(err) + } +}) diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/storageBot.js b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/storageBot.js new file mode 100644 index 0000000000000000000000000000000000000000..b6f9971f25103612d6dd529fc4f4b42a710f1b1f --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/examples/storageBot.js @@ -0,0 +1,107 @@ +/** + * This bot example shows how to use the chest filling mechanic of the plugin. + * Simply provide a given storage chest, and the bot will automatically try and + * store its inventory in that chest when the bot's inventory becomes full. + */ + +if (process.argv.length < 4 || process.argv.length > 6) { + console.log('Usage : node storageBot.js <host> <port> [<name>] [<password>]') + process.exit(1) +} + +// Load your libraries +const mineflayer = require('mineflayer') +const collectBlock = require('mineflayer-collectblock').plugin + +// Create your bot +const bot = mineflayer.createBot({ + host: process.argv[2], + port: parseInt(process.argv[3]), + username: process.argv[4] ? process.argv[4] : 'storageBot', + password: process.argv[5] +}) + +// Load the collect block plugin +bot.loadPlugin(collectBlock) + +// Load mcData on login +let mcData +bot.once('login', () => { + mcData = require('minecraft-data')(bot.version) +}) + +// On spawn, try to find any nearby chests and save those as storage locations. +// When the bot's inventory becomes too full, it will empty its inventory into +// these chests before collecting more resources. If a chest gets full, it moves +// to the next one in order until its inventory is empty or it runs out of chests. 
+bot.once('spawn', () => { + bot.collectBlock.chestLocations = bot.findBlocks({ + matching: mcData.blocksByName.chest.id, + maxDistance: 16, + count: 999999 // Get as many chests as we can + }) + + if (bot.collectBlock.chestLocations.length === 0) { + bot.chat("I don't see any chests nearby.") + } else { + for (const chestPos of bot.collectBlock.chestLocations) { + bot.chat(`I found a chest at ${chestPos}`) + } + } +}) + +// Wait for someone to say something +bot.on('chat', async (username, message) => { + // If the player says something that starts with "collect" + // Otherwise, do nothing + const args = message.split(' ') + if (args[0] !== 'collect') return + + // If the player specifies a number, collect that many. Otherwise, default to 1. + let count = 1 + if (args.length === 3) count = parseInt(args[1]) + + // If a number was given the item type is the 3rd arg, not the 2nd. + let type = args[1] + if (args.length === 3) type = args[2] + + // Get the id of that block type for this version of Minecraft. + const blockType = mcData.blocksByName[type] + if (!blockType) { + bot.chat(`I don't know any blocks named ${type}.`) + return + } + + // Find all nearby blocks of that type, up to the given count, within 64 blocks. + const blocks = bot.findBlocks({ + matching: blockType.id, + maxDistance: 64, + count: count + }) + + // Complain if we can't find any nearby blocks of that type. + if (blocks.length === 0) { + bot.chat("I don't see that block nearby.") + return + } + + // Convert the block position array into a block array to pass to collect block. + const targets = [] + for (let i = 0; i < Math.min(blocks.length, count); i++) { + targets.push(bot.blockAt(blocks[i])) + } + + // Announce what we found. + bot.chat(`Found ${targets.length} ${type}(s)`) + + // Tell the bot to collect all of the given blocks in the block list. + try { + await bot.collectBlock.collect(targets) + // All blocks have been collected. + bot.chat('Done') + } catch (err) { + // An error occurred, report it. 
+ bot.chat(err.message) + console.log(err) + } +}) diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/package.json b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/package.json new file mode 100644 index 0000000000000000000000000000000000000000..0f59e7aa6a1d38ed4c43923f910846d6c7998ec8 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/package.json @@ -0,0 +1,44 @@ +{ + "name": "mineflayer-collectblock", + "version": "1.4.1", + "description": "A simple utility plugin for Mineflayer that adds a higher level API for collecting blocks.", + "main": "lib/index.js", + "types": "lib/index.d.ts", + "scripts": { + "build": "ts-standard && tsc && require-self", + "clean": "rm -rf lib", + "test": "test" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/TheDudeFromCI/mineflayer-collectblock.git" + }, + "keywords": [ + "mineflayer", + "plugin", + "api", + "utility", + "helper", + "collect" + ], + "author": "TheDudeFromCI", + "license": "MIT", + "bugs": { + "url": "https://github.com/TheDudeFromCI/mineflayer-collectblock/issues" + }, + "homepage": "https://github.com/TheDudeFromCI/mineflayer-collectblock#readme", + "dependencies": { + "mineflayer": "^4.0.0", + "mineflayer-pathfinder": "^2.1.1", + "mineflayer-tool": "^1.1.0" + }, + "devDependencies": { + "@types/node": "^18.6.4", + "require-self": "^0.2.3", + "ts-standard": "^11.0.0", + "typescript": "^4.1.3" + }, + "files": [ + "lib/**/*" + ] +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/BlockVeins.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/BlockVeins.ts new file mode 100644 index 0000000000000000000000000000000000000000..ae5542ce3a693d75262bea010b72766e3042fd0b --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/BlockVeins.ts @@ -0,0 +1,35 @@ +import { Bot } from 'mineflayer' +import { Block } from 'prismarine-block' + +export function findFromVein (bot: Bot, block: Block, maxBlocks: number, maxDistance: number, floodRadius: number): Block[] { + const targets: Block[] = [] + const open: Block[] = [block] + const type = block.type + const center = block.position + + for (let i = 0; i < maxBlocks; i++) { + const next = open.pop() + if (next == null) break + + targets.push(next) + + for (let x = -floodRadius; x <= floodRadius; x++) { + for (let y = -floodRadius; y <= floodRadius; y++) { + for (let z = -floodRadius; z <= floodRadius; z++) { + const neighborPos = next.position.offset(x, y, z) + if (neighborPos.manhattanDistanceTo(center) > maxDistance) continue + + const neighbor = bot.blockAt(neighborPos) + if (neighbor == null || neighbor.type !== type) continue + + if (targets.includes(neighbor)) continue + if (open.includes(neighbor)) continue + + open.push(neighbor) + } + } + } + } + + return targets +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/CollectBlock.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/CollectBlock.ts new file mode 100644 index 0000000000000000000000000000000000000000..d2be87822f9ab6fffe64aeae777933a3f0e61d29 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/CollectBlock.ts @@ -0,0 +1,451 @@ +import { Bot } from "mineflayer"; +import { Block } from "prismarine-block"; +import { Movements, goals } from "mineflayer-pathfinder"; +import { TemporarySubscriber } from "./TemporarySubscriber"; +import { Entity } from 
"prismarine-entity"; +import { error } from "./Util"; +import { Vec3 } from "vec3"; +import { emptyInventoryIfFull, ItemFilter } from "./Inventory"; +import { findFromVein } from "./BlockVeins"; +import { Collectable, Targets } from "./Targets"; +import { Item } from "prismarine-item"; +import mcDataLoader from "minecraft-data"; +import { once } from "events"; +import { callbackify } from "util"; + +export type Callback = (err?: Error) => void; + +async function collectAll( + bot: Bot, + options: CollectOptionsFull +): Promise { + let success_count = 0; + while (!options.targets.empty) { + await emptyInventoryIfFull( + bot, + options.chestLocations, + options.itemFilter + ); + const closest = options.targets.getClosest(); + if (closest == null) break; + switch (closest.constructor.name) { + case "Block": { + try { + if (success_count >= options.count) { + break; + } + await bot.tool.equipForBlock( + closest as Block, + equipToolOptions + ); + const goal = new goals.GoalLookAtBlock( + closest.position, + bot.world + ); + await bot.pathfinder.goto(goal); + await mineBlock(bot, closest as Block, options); + success_count++; + // TODO: options.ignoreNoPath + } catch (err) { + // @ts-ignore + // console.log(err.stack) + // bot.pathfinder.stop() + // bot.waitForTicks(10) + try { + bot.pathfinder.setGoal(null); + } catch (err) {} + if (options.ignoreNoPath) { + // @ts-ignore + if (err.name === "Invalid block") { + console.log( + `Block ${closest.name} at ${closest.position} is not valid! Skip it!` + ); + } // @ts-ignore + else if (err.name === "Unsafe block") { + console.log( + `${closest.name} at ${closest.position} is not safe to break! Skip it!` + ); + // @ts-ignore + } else if (err.name === "NoItem") { + const properties = + bot.registry.blocksByName[closest.name]; + const leastTool = Object.keys( + properties.harvestTools + )[0]; + const item = bot.registry.items[leastTool]; + bot.chat( + `I need at least a ${item.name} to mine ${closest.name}! Skip it!` + ); + return; + } else if ( + // @ts-ignore + err.name === "NoPath" || + // @ts-ignore + err.name === "Timeout" + ) { + if ( + bot.entity.position.distanceTo( + closest.position + ) < 0.5 + ) { + await mineBlock(bot, closest as Block, options); + break; + } + console.log( + `No path to ${closest.name} at ${closest.position}! Skip it!` + ); + // @ts-ignore + } else if (err.message === "Digging aborted") { + console.log(`Digging aborted! 
Skip it!`); + } else { + // @ts-ignore + bot.chat(`Error: ${err.message}`); + } + break; + } + throw err; + } + break; + } + case "Entity": { + // Don't collect any entities that are marked as 'invalid' + if (!(closest as Entity).isValid) break; + try { + const tempEvents = new TemporarySubscriber(bot); + const waitForPickup = new Promise<void>( + (resolve, reject) => { + const timeout = setTimeout(() => { + // After 10 seconds, reject the promise + clearTimeout(timeout); + tempEvents.cleanup(); + reject(new Error("Failed to pickup item")); + }, 10000); + tempEvents.subscribeTo( + "entityGone", + (entity: Entity) => { + if (entity === closest) { + clearTimeout(timeout); + tempEvents.cleanup(); + resolve(); + } + } + ); + } + ); + bot.pathfinder.setGoal( + new goals.GoalFollow(closest as Entity, 0) + ); + // await bot.pathfinder.goto(new goals.GoalBlock(closest.position.x, closest.position.y, closest.position.z)) + await waitForPickup; + } catch (err) { + // @ts-ignore + console.log(err.stack); + try { + bot.pathfinder.setGoal(null); + } catch (err) {} + if (options.ignoreNoPath) { + // @ts-ignore + if (err.message === "Failed to pickup item") { + bot.chat(`Failed to pickup item! Skip it!`); + } + break; + } + throw err; + } + break; + } + default: { + throw error( + "UnknownType", + `Target ${closest.constructor.name} is not a Block or Entity!` + ); + } + } + options.targets.removeTarget(closest); + } + bot.chat(`Collecting finished!`); +} + +const equipToolOptions = { + requireHarvest: true, + getFromChest: false, + maxTools: 2, +}; + +async function mineBlock( + bot: Bot, + block: Block, + options: CollectOptionsFull +): Promise<void> { + if ( + bot.blockAt(block.position)?.type !== block.type || + bot.blockAt(block.position)?.type === 0 + ) { + options.targets.removeTarget(block); + throw error("Invalid block", "Block is not valid!"); + // @ts-expect-error + } else if (!bot.pathfinder.movements.safeToBreak(block)) { + options.targets.removeTarget(block); + throw error("Unsafe block", "Block is not safe to break!"); + } + + await bot.tool.equipForBlock(block, equipToolOptions); + + if (!block.canHarvest(bot.heldItem ? bot.heldItem.type : bot.heldItem)) { + options.targets.removeTarget(block); + throw error("NoItem", "Bot does not have a harvestable tool!"); + } + + const tempEvents = new TemporarySubscriber(bot); + tempEvents.subscribeTo("itemDrop", (entity: Entity) => { + if ( + entity.position.distanceTo(block.position.offset(0.5, 0.5, 0.5)) <= + 0.5 + ) { + options.targets.appendTarget(entity); + } + }); + try { + await bot.dig(block); + // Waiting for items to drop + await new Promise<void>((resolve) => { + let remainingTicks = 10; + tempEvents.subscribeTo("physicTick", () => { + remainingTicks--; + if (remainingTicks <= 0) { + tempEvents.cleanup(); + resolve(); + } + }); + }); + } finally { + tempEvents.cleanup(); + } +} + +/** + * A set of options to apply when collecting the given targets. + */ +export interface CollectOptions { + /** + * If true, the target(s) will be appended to the existing target list instead of + * starting a new task. Defaults to false. + */ + append?: boolean; + + /** + * If true, errors will not be thrown when a path to the target block cannot + * be found. The bot will attempt to choose the best available position it + * can find, instead. Errors are still thrown if the bot cannot interact with + * the block from its final location. Defaults to false. 
+ */ + ignoreNoPath?: boolean; + + /** + * Gets the list of chest locations to use when storing items after the bot's + * inventory becomes full. If undefined, it defaults to the chest location + * list on the bot.collectBlock plugin. + */ + chestLocations?: Vec3[]; + + /** + * When transferring items to a chest, this filter is used to determine what + * items are allowed to be moved, and what items aren't allowed to be moved. + * Defaults to the item filter specified on the bot.collectBlock plugin. + */ + itemFilter?: ItemFilter; + + /** + * The total number of items to collect + */ + count?: number; +} + +/** + * A version of collect options where all values are assigned. + */ +interface CollectOptionsFull { + append: boolean; + ignoreNoPath: boolean; + chestLocations: Vec3[]; + itemFilter: ItemFilter; + targets: Targets; + count: number; +} + +/** + * The collect block plugin. + */ +export class CollectBlock { + /** + * The bot. + */ + private readonly bot: Bot; + + /** + * The list of active targets being collected. + */ + private readonly targets: Targets; + + /** + * The movements configuration to be sent to the pathfinder plugin. + */ + movements?: Movements; + + /** + * A list of chest locations which the bot is allowed to empty its inventory into + * if it becomes full while the bot is collecting resources. + */ + chestLocations: Vec3[] = []; + + /** + * When collecting items, this filter is used to determine what items should be placed + * into a chest if the bot's inventory becomes full. By default, returns true for all + * items except for tools, weapons, and armor. + * + * @param item - The item stack in the bot's inventory to check. + * + * @returns True if the item should be moved into the chest. False otherwise. + */ + itemFilter: ItemFilter = (item: Item) => { + if (item.name.includes("helmet")) return false; + if (item.name.includes("chestplate")) return false; + if (item.name.includes("leggings")) return false; + if (item.name.includes("boots")) return false; + if (item.name.includes("shield")) return false; + if (item.name.includes("sword")) return false; + if (item.name.includes("pickaxe")) return false; + if (item.name.includes("axe")) return false; + if (item.name.includes("shovel")) return false; + if (item.name.includes("hoe")) return false; + return true; + }; + + /** + * Creates a new instance of the collect block plugin. + * + * @param bot - The bot this plugin is acting on. + */ + constructor(bot: Bot) { + this.bot = bot; + this.targets = new Targets(bot); + // @ts-ignore + this.movements = new Movements(bot, mcDataLoader(bot.version)); + } + + /** + * If target is a block: + * Causes the bot to break and collect the target block. + * + * If target is an item drop: + * Causes the bot to collect the item drop. + * + * If target is an array containing items or blocks, performs the correct action for + * all targets in that array sorting dynamically by distance. + * + * @param target - The block(s) or item(s) to collect. + * @param options - The set of options to use when handling these targets + * @param cb - The callback that is called when finished. + */ + async collect( + target: Collectable | Collectable[], + options: CollectOptions | Callback = {}, + cb?: Callback + ): Promise<void> { + if (typeof options === "function") { + cb = options; + options = {}; + } + // @ts-expect-error + if (cb != null) return callbackify(this.collect)(target, options, cb); + + const optionsFull: CollectOptionsFull = { + append: options.append ?? 
false, + ignoreNoPath: options.ignoreNoPath ?? false, + chestLocations: options.chestLocations ?? this.chestLocations, + itemFilter: options.itemFilter ?? this.itemFilter, + targets: this.targets, + count: options.count ?? Infinity, + }; + + if (this.bot.pathfinder == null) { + throw error( + "UnresolvedDependency", + "The mineflayer-collectblock plugin relies on the mineflayer-pathfinder plugin to run!" + ); + } + + if (this.bot.tool == null) { + throw error( + "UnresolvedDependency", + "The mineflayer-collectblock plugin relies on the mineflayer-tool plugin to run!" + ); + } + + if (this.movements != null) { + this.bot.pathfinder.setMovements(this.movements); + } + + if (!optionsFull.append) await this.cancelTask(); + if (Array.isArray(target)) { + this.targets.appendTargets(target); + } else { + this.targets.appendTarget(target); + } + + try { + await collectAll(this.bot, optionsFull); + this.targets.clear(); + } catch (err) { + this.targets.clear(); + // Ignore path stopped error for cancelTask to work properly (imo we shouldn't throw any pathing errors) + // @ts-expect-error + if (err.name !== "PathStopped") throw err; + } finally { + // @ts-expect-error + this.bot.emit("collectBlock_finished"); + } + } + + /** + * Loads all touching blocks of the same type to the given block and returns them as an array. + * This effectively acts as a flood fill algorithm to retrieve blocks in the same ore vein and similar. + * + * @param block - The starting block. + * @param maxBlocks - The maximum number of blocks to look for before stopping. + * @param maxDistance - The max distance from the starting block to look. + * @param floodRadius - The max distance from block A to block B to be considered "touching" + */ + findFromVein( + block: Block, + maxBlocks = 100, + maxDistance = 16, + floodRadius = 1 + ): Block[] { + return findFromVein( + this.bot, + block, + maxBlocks, + maxDistance, + floodRadius + ); + } + + /** + * Cancels the current collection task, if still active. + * + * @param cb - The callback to use when the task is stopped. 
+ */ + async cancelTask(cb?: Callback): Promise<void> { + if (this.targets.empty) { + if (cb != null) cb(); + return await Promise.resolve(); + } + this.bot.pathfinder.stop(); + if (cb != null) { + // @ts-expect-error + this.bot.once("collectBlock_finished", cb); + } + await once(this.bot, "collectBlock_finished"); + } +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Inventory.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Inventory.ts new file mode 100644 index 0000000000000000000000000000000000000000..6a17d0cc525966d26e948d627febd567abf3dbc6 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Inventory.ts @@ -0,0 +1,87 @@ +import { Bot } from 'mineflayer' +import { Callback } from './CollectBlock' +import { Vec3 } from 'vec3' +import { error } from './Util' +import { Item } from 'prismarine-item' +import { goals } from 'mineflayer-pathfinder' +import { callbackify } from 'util' + +export type ItemFilter = (item: Item) => boolean + +function getClosestChest (bot: Bot, chestLocations: Vec3[]): Vec3 | null { + let chest = null + let distance = 0 + + for (const c of chestLocations) { + const dist = c.distanceTo(bot.entity.position) + if (chest == null || dist < distance) { + chest = c + distance = dist + } + } + + if (chest != null) { + chestLocations.splice(chestLocations.indexOf(chest), 1) + } + + return chest +} + +export async function emptyInventoryIfFull (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise<void> { + // @ts-expect-error + if (cb != null) return callbackify(emptyInventoryIfFull)(bot, chestLocations, itemFilter, cb) + if (bot.inventory.emptySlotCount() > 0) return + return await emptyInventory(bot, chestLocations, itemFilter) +} + +export async function emptyInventory (bot: Bot, chestLocations: Vec3[], itemFilter: ItemFilter, cb?: Callback): Promise<void> { + // @ts-expect-error + if (cb != null) return callbackify(emptyInventory)(bot, chestLocations, itemFilter, cb) + if (chestLocations.length === 0) { + throw error('NoChests', 'There are no defined chest locations!') + } + + // Shallow clone so we can safely remove chests from the list that are full. 
+ chestLocations = [...chestLocations] + + while (true) { + const chest = getClosestChest(bot, chestLocations) + if (chest == null) { + throw error('NoChests', 'All chests are full.') + } + const hasRemaining = await tryEmptyInventory(bot, chest, itemFilter) + if (!hasRemaining) return + } +} + +async function tryEmptyInventory (bot: Bot, chestLocation: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise<boolean> { + // @ts-expect-error + if (cb != null) return callbackify(tryEmptyInventory)(bot, chestLocation, itemFilter, cb) + await gotoChest(bot, chestLocation) + return await placeItems(bot, chestLocation, itemFilter) +} + +async function gotoChest (bot: Bot, location: Vec3, cb?: Callback): Promise<void> { + // @ts-expect-error + if (cb != null) return callbackify(gotoChest)(bot, location, cb) + await bot.pathfinder.goto(new goals.GoalGetToBlock(location.x, location.y, location.z)) +} + +async function placeItems (bot: Bot, chestPos: Vec3, itemFilter: ItemFilter, cb?: (err: Error | undefined, hasRemaining: boolean) => void): Promise<boolean> { + // @ts-expect-error + if (cb != null) return callbackify(placeItems)(bot, chestPos, itemFilter, cb) + const chestBlock = bot.blockAt(chestPos) + if (chestBlock == null) { + throw error('UnloadedChunk', 'Chest is in an unloaded chunk!') + } + const chest = await bot.openChest(chestBlock) + for (const item of bot.inventory.items()) { + if (!itemFilter(item)) continue + if (chest.firstEmptyContainerSlot() === null) { + // We have items that didn't fit. + return true + } + await chest.deposit(item.type, item.metadata, item.count) + } + return false +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Targets.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Targets.ts new file mode 100644 index 0000000000000000000000000000000000000000..568d07ad98ac8b4140344ed50515ad9e6a246899 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Targets.ts @@ -0,0 +1,60 @@ +import { Bot } from 'mineflayer' +import { Block } from 'prismarine-block' +import { Entity } from 'prismarine-entity' + +export type Collectable = Block | Entity + +export class Targets { + private readonly bot: Bot + private targets: Collectable[] = [] + + constructor (bot: Bot) { + this.bot = bot + } + + appendTargets (targets: Collectable[]): void { + for (const target of targets) { + this.appendTarget(target) + } + } + + appendTarget (target: Collectable): void { + if (this.targets.includes(target)) return + this.targets.push(target) + } + + /** + * Gets the closest target to the bot in this list. + * + * @returns The closest target, or null if there are no targets.
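+ *
+ * Usage sketch (illustrative only; `someBlock` and `someItemEntity` stand in
+ * for a Block and a dropped-item Entity looked up elsewhere; both expose
+ * `position`, which is all this comparison relies on):
+ *
+ * @example
+ * const targets = new Targets(bot)
+ * targets.appendTargets([someBlock, someItemEntity])
+ * const next = targets.getClosest() // nearest of the two, or null if empty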
+ */ + getClosest (): Collectable | null { + let closest: Collectable | null = null + let distance: number = 0 + + for (const target of this.targets) { + const dist = target.position.distanceTo(this.bot.entity.position) + + if (closest == null || dist < distance) { + closest = target + distance = dist + } + } + + return closest + } + + get empty (): boolean { + return this.targets.length === 0 + } + + clear (): void { + this.targets.length = 0 + } + + removeTarget (target: Collectable): void { + const index = this.targets.indexOf(target) + if (index < 0) return + this.targets.splice(index, 1) + } +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/TaskQueue.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/TaskQueue.ts new file mode 100644 index 0000000000000000000000000000000000000000..81fe3bc5ae05d9f15eedbc4e9f307176ed819040 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/TaskQueue.ts @@ -0,0 +1,77 @@ +import type { Callback } from './index' +export type Task = (cb: Callback) => void +export type SyncTask = () => void + +/** + * A simple utility class for queuing up a series of async tasks to execute. + */ +export class TaskQueue { + private tasks: Task[] = [] + + /** + * If true, the task list will stop executing if one of the tasks throws an error. + */ + readonly stopOnError: boolean = true + + /** + * Adds a new async task to this queue. The provided callback should be executed when + * the async task is complete. + * + * @param task - The async task to add. + */ + add (task: Task): void { + this.tasks.push(task) + } + + /** + * Adds a synchronous task to this queue. + * + * @param task - The sync task to add. + */ + addSync (task: SyncTask): void { + this.add((cb) => { + try { + task() + cb() + } catch (err: any) { + cb(err) + } + }) + } + + /** + * Runs all tasks currently in this queue and empties the queue. + * + * @param cb - The optional callback to be executed when all tasks in this queue have + * finished executing. + */ + runAll (cb?: Callback): void { + const taskList = this.tasks + this.tasks = [] + + let index = -1 + const runNext: () => void = () => { + index++ + if (index >= taskList.length) { + if (cb !== undefined) cb() + return + } + + try { + taskList[index]((err) => { + if (err !== undefined) { + if (cb !== undefined) cb(err) + + if (this.stopOnError) return + } + + runNext() + }) + } catch (err: any) { + if (cb !== undefined) cb(err) + } + } + + runNext() + } +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts new file mode 100644 index 0000000000000000000000000000000000000000..3f14a607da52bc42332ee1f6ba0999c2db76a679 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/TemporarySubscriber.ts @@ -0,0 +1,34 @@ +import { Bot } from 'mineflayer' + +class Subscription { + constructor (readonly eventName: string, readonly callback: Function) {} +} + +export class TemporarySubscriber { + private readonly subscriptions: Subscription[] = [] + + constructor (readonly bot: Bot) {} + + /** + * Adds a new temporary event listener to the bot. + * + * @param event - The event to subscribe to. + * @param callback - The function to execute.
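+ *
+ * Usage sketch (illustrative; `onTick` stands in for a handler defined
+ * elsewhere, and 'physicsTick' is one of mineflayer's bot events):
+ *
+ * @example
+ * const sub = new TemporarySubscriber(bot)
+ * sub.subscribeTo('physicsTick', onTick)
+ * // ...once the surrounding task finishes or fails:
+ * sub.cleanup() // removes every listener added through this subscriber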
+ */ + subscribeTo (event: string, callback: Function): void { + this.subscriptions.push(new Subscription(event, callback)) + + // @ts-expect-error + this.bot.on(event, callback) + } + + /** + * Removes all attached event listeners from the bot. + */ + cleanup (): void { + for (const sub of this.subscriptions) { + // @ts-expect-error + this.bot.removeListener(sub.eventName, sub.callback) + } + } +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Util.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Util.ts new file mode 100644 index 0000000000000000000000000000000000000000..ee0f29e0cb1034e1dd96593b73119382050b722b --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/Util.ts @@ -0,0 +1,13 @@ +/** + * Creates a new error object with the given type and message. + * + * @param type - The error type. + * @param message - The error message. + * + * @returns The error object. + */ +export function error (type: string, message: string): Error { + const e = new Error(message) + e.name = type + return e +} diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/index.ts b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/index.ts new file mode 100644 index 0000000000000000000000000000000000000000..45c9a85087f56b5bd771477a6fe5b1a02d986b9f --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/src/index.ts @@ -0,0 +1,25 @@ +import { Bot } from 'mineflayer' +import { CollectBlock } from './CollectBlock' +import { pathfinder as pathfinderPlugin } from 'mineflayer-pathfinder' +import { plugin as toolPlugin } from 'mineflayer-tool' + +export function plugin (bot: Bot): void { + // @ts-expect-error + bot.collectBlock = new CollectBlock(bot) + + // Load plugins if not loaded manually. + setTimeout(() => loadPathfinderPlugin(bot), 0) + setTimeout(() => loadToolPlugin(bot), 0) +} + +function loadPathfinderPlugin (bot: Bot): void { + if (bot.pathfinder != null) return + bot.loadPlugin(pathfinderPlugin) +} + +function loadToolPlugin (bot: Bot): void { + if (bot.tool != null) return + bot.loadPlugin(toolPlugin) +} + +export { CollectBlock, Callback, CollectOptions } from './CollectBlock' diff --git a/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/tsconfig.json b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/tsconfig.json new file mode 100644 index 0000000000000000000000000000000000000000..a6076bc0c72a5ed65fd375450a97b2feefb28045 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/mineflayer-collectblock/tsconfig.json @@ -0,0 +1,69 @@ +{ + "compilerOptions": { + /* Visit https://aka.ms/tsconfig.json to read more about this file */ + /* Basic Options */ + // "incremental": true, /* Enable incremental compilation */ + "target": "ES2015", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019', 'ES2020', or 'ESNEXT'. */ + "module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', 'es2020', or 'ESNext'. */ + // "lib": [], /* Specify library files to be included in the compilation. */ + "allowJs": true, /* Allow javascript files to be compiled. */ + "checkJs": true, /* Report errors in .js files. */ + // "jsx": "preserve", /* Specify JSX code generation: 'preserve', 'react-native', or 'react'. 
*/ + "declaration": true, + // "declarationMap": true, /* Generates a sourcemap for each corresponding '.d.ts' file. */ + // "sourceMap": true, /* Generates corresponding '.map' file. */ + // "outFile": "./", /* Concatenate and emit output to single file. */ + "outDir": "./lib", + // "rootDir": "./", /* Specify the root directory of input files. Use to control the output directory structure with --outDir. */ + // "composite": true, /* Enable project compilation */ + // "tsBuildInfoFile": "./", /* Specify file to store incremental compilation information */ + // "removeComments": true, /* Do not emit comments to output. */ + // "noEmit": true, /* Do not emit outputs. */ + // "importHelpers": true, /* Import emit helpers from 'tslib'. */ + // "downlevelIteration": true, /* Provide full support for iterables in 'for-of', spread, and destructuring when targeting 'ES5' or 'ES3'. */ + // "isolatedModules": true, /* Transpile each file as a separate module (similar to 'ts.transpileModule'). */ + /* Strict Type-Checking Options */ + "strict": true, /* Enable all strict type-checking options. */ + // "noImplicitAny": true, /* Raise error on expressions and declarations with an implied 'any' type. */ + "strictNullChecks": true, /* Enable strict null checks. */ + // "strictFunctionTypes": true, /* Enable strict checking of function types. */ + // "strictBindCallApply": true, /* Enable strict 'bind', 'call', and 'apply' methods on functions. */ + // "strictPropertyInitialization": true, /* Enable strict checking of property initialization in classes. */ + // "noImplicitThis": true, /* Raise error on 'this' expressions with an implied 'any' type. */ + "alwaysStrict": true, /* Parse in strict mode and emit "use strict" for each source file. */ + /* Additional Checks */ + "noUnusedLocals": true, /* Report errors on unused locals. */ + // "noUnusedParameters": true, /* Report errors on unused parameters. */ + "noImplicitReturns": true, /* Report error when not all code paths in function return a value. */ + // "noFallthroughCasesInSwitch": true, /* Report errors for fallthrough cases in switch statement. */ + /* Module Resolution Options */ + // "moduleResolution": "node", /* Specify module resolution strategy: 'node' (Node.js) or 'classic' (TypeScript pre-1.6). */ + // "baseUrl": "./", /* Base directory to resolve non-absolute module names. */ + // "paths": {}, /* A series of entries which re-map imports to lookup locations relative to the 'baseUrl'. */ + // "rootDirs": [], /* List of root folders whose combined content represents the structure of the project at runtime. */ + // "typeRoots": [], /* List of folders to include type definitions from. */ + // "types": [], /* Type declaration files to be included in compilation. */ + // "allowSyntheticDefaultImports": true, /* Allow default imports from modules with no default export. This does not affect code emit, just typechecking. */ + "esModuleInterop": true, /* Enables emit interoperability between CommonJS and ES Modules via creation of namespace objects for all imports. Implies 'allowSyntheticDefaultImports'. */ + // "preserveSymlinks": true, /* Do not resolve the real path of symlinks. */ + // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */ + /* Source Map Options */ + // "sourceRoot": "", /* Specify the location where debugger should locate TypeScript files instead of source locations. */ + // "mapRoot": "", /* Specify the location where debugger should locate map files instead of generated locations. 
*/ + // "inlineSourceMap": true, /* Emit a single file with source maps instead of having a separate file. */ + // "inlineSources": true, /* Emit the source alongside the sourcemaps within a single file; requires '--inlineSourceMap' or '--sourceMap' to be set. */ + /* Experimental Options */ + // "experimentalDecorators": true, /* Enables experimental support for ES7 decorators. */ + // "emitDecoratorMetadata": true, /* Enables experimental support for emitting type metadata for decorators. */ + /* Advanced Options */ + "skipLibCheck": true, /* Skip type checking of declaration files. */ + "forceConsistentCasingInFileNames": true /* Disallow inconsistently-cased references to the same file. */ + }, + "include": [ + "src" + ], + "exclude": [ + "node_modules", + "**/__tests__/*" + ] +} \ No newline at end of file diff --git a/metagpt/environment/minecraft/mineflayer/package.json b/metagpt/environment/minecraft/mineflayer/package.json new file mode 100644 index 0000000000000000000000000000000000000000..9e389d268c3e7e09d3ad36a9668fbac7ac587397 --- /dev/null +++ b/metagpt/environment/minecraft/mineflayer/package.json @@ -0,0 +1,38 @@ +{ + "name": "voyager", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "keywords": [], + "author": "", + "license": "ISC", + "dependencies": { + "body-parser": "^1.20.2", + "express": "^4.18.2", + "magic-string": "^0.30.0", + "minecraft-data": "^3.31.0", + "minecrafthawkeye": "^1.3.6", + "mineflayer": "^4.8.1", + "mineflayer-collectblock": "file:mineflayer-collectblock", + "mineflayer-pathfinder": "^2.4.2", + "mineflayer-pvp": "^1.3.2", + "mineflayer-tool": "^1.2.0", + "mocha": "^10.2.0", + "prismarine-biome": "^1.3.0", + "prismarine-block": "=1.16.3", + "prismarine-entity": "^2.2.0", + "prismarine-item": "^1.12.1", + "prismarine-nbt": "^2.2.1", + "prismarine-recipe": "^1.3.1", + "prismarine-viewer": "^1.24.0", + "typescript": "^4.9.5", + "vec3": "^0.1.8", + "graceful-fs": "^4.2.11" + }, + "devDependencies": { + "prettier": "2.8.5" + } +} diff --git a/metagpt/environment/minecraft/process_monitor.py b/metagpt/environment/minecraft/process_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..b62aa60050d0b20ca01c1ab558531279e2dee904 --- /dev/null +++ b/metagpt/environment/minecraft/process_monitor.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# refs to `voyager process_monitor.py` + +import re +import subprocess +import threading +import warnings +from typing import List + +import psutil + +from metagpt.logs import define_log_level + + +class SubprocessMonitor: + def __init__( + self, + commands: List[str], + name: str, + ready_match: str = r".*", + callback_match: str = r"^(?!x)x$", # regex that will never match + callback: callable = None, + finished_callback: callable = None, + ): + self.commands = commands + self.name = name + self.logger = define_log_level(name=name) + self.process = None + self.ready_match = ready_match + self.ready_event = None + self.ready_line = None + self.callback_match = callback_match + self.callback = callback + self.finished_callback = finished_callback + self.thread = None + + def _start(self): + self.logger.info(f"Starting subprocess with commands: {self.commands}") + + self.process = psutil.Popen( + self.commands, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + ) + self.logger.info(f"Subprocess {self.name} started with PID {self.process.pid}.") + for line 
in iter(self.process.stdout.readline, ""): + self.logger.info(line.strip()) + if re.search(self.ready_match, line): + self.ready_line = line + self.logger.info("Subprocess is ready.") + self.ready_event.set() + if re.search(self.callback_match, line): + self.callback() + if not self.ready_event.is_set(): + self.ready_event.set() + warnings.warn(f"Subprocess {self.name} failed to start.") + if self.finished_callback: + self.finished_callback() + + def run(self): + self.ready_event = threading.Event() + self.ready_line = None + self.thread = threading.Thread(target=self._start) + self.thread.start() + self.ready_event.wait() + + def stop(self): + self.logger.info("Stopping subprocess.") + if self.process and self.process.is_running(): + self.process.terminate() + self.process.wait() + + @property + def is_running(self): + if self.process is None: + return False + return self.process.is_running() diff --git a/metagpt/environment/software/__init__.py b/metagpt/environment/software/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/software/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/software/software_env.py b/metagpt/environment/software/software_env.py new file mode 100644 index 0000000000000000000000000000000000000000..94bc116590c2142aabdd1dd2c898d0df288c661e --- /dev/null +++ b/metagpt/environment/software/software_env.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Software Env + + +from metagpt.environment.base_env import Environment + + +class SoftwareEnv(Environment): + """A specific alias of Environment for the software-development scenario.""" + + pass diff --git a/metagpt/environment/stanford_town/__init__.py b/metagpt/environment/stanford_town/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/stanford_town/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/stanford_town/env_space.py b/metagpt/environment/stanford_town/env_space.py new file mode 100644 index 0000000000000000000000000000000000000000..1741cccfe833dabb2c275936856f98159014515f --- /dev/null +++ b/metagpt/environment/stanford_town/env_space.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any, Optional, Union + +import numpy as np +import numpy.typing as npt +from gymnasium import spaces +from pydantic import ConfigDict, Field, field_validator + +from metagpt.base.base_env_space import ( + BaseEnvAction, + BaseEnvActionType, + BaseEnvObsParams, + BaseEnvObsType, +) + + +class EnvActionType(BaseEnvActionType): + NONE = 0 # no action to run, just get observation + + ADD_TILE_EVENT = 1 # Add an event triple to a tile + RM_TILE_EVENT = 2 # Remove an event triple from a tile + TURN_TILE_EVENT_IDLE = 3 # Turn an event triple from a tile into idle + RM_TITLE_SUB_EVENT = 4 # Remove an event triple that has the input subject from a tile + + +class EnvAction(BaseEnvAction): + """Env action type and the related params of its action functions/APIs.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_type: int = Field(default=EnvActionType.NONE, description="action type") + coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate" + ) + subject: str =
Field(default="", description="subject name of first element in event") + event: tuple[str, Optional[str], Optional[str], Optional[str]] = Field( + default=("", None, None, None), description="tile event" + ) + + @field_validator("coord", mode="before") + @classmethod + def check_coord(cls, coord) -> npt.NDArray[np.int64]: + if not isinstance(coord, np.ndarray): + return np.array(coord) + return coord + + +class EnvObsType(BaseEnvObsType): + """get part observation with specific params""" + + NONE = 0 # get whole observation from env + + GET_TITLE = 1 # get the tile detail dictionary with given tile coord + TILE_PATH = 2 # get the tile address with given tile coord + TILE_NBR = 3 # get the neighbors of given tile coord and its vision radius + + +class EnvObsParams(BaseEnvObsParams): + """observation params for different EnvObsType""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + obs_type: int = Field(default=EnvObsType.NONE, description="observation type") + coord: npt.NDArray[np.int64] = Field( + default_factory=lambda: np.zeros(2, dtype=np.int64), description="tile coordinate" + ) + level: str = Field(default="", description="the level of the tile address (world/sector/arena/game_object)") + vision_radius: int = Field(default=0, description="the vision radius of current tile") + + @field_validator("coord", mode="before") + @classmethod + def check_coord(cls, coord) -> npt.NDArray[np.int64]: + if not isinstance(coord, np.ndarray): + return np.array(coord) + return coord + + +EnvObsValType = Union[list[list[str]], dict[str, set[tuple[int, int]]], list[list[dict[str, Any]]]] + + +def get_observation_space() -> spaces.Dict: + # placeholder spaces; the real observation values are nested structures (see EnvObsValType) + space = spaces.Dict( + {"collision_maze": spaces.Discrete(2), "tiles": spaces.Discrete(2), "address_tiles": spaces.Discrete(2)} + ) + + return space + + +def get_action_space(maze_shape: tuple[int, int]) -> spaces.Dict: + """The fields defined by the space correspond to the input parameters of the action except `action_type`""" + space = spaces.Dict( + { + "action_type": spaces.Discrete(len(EnvActionType)), + "coord": spaces.Box( + np.array([0, 0], dtype=np.int64), np.array([maze_shape[0], maze_shape[1]], dtype=np.int64) + ), # coord of the tile + "subject": spaces.Text(256), # the first element of a tile event + "event": spaces.Tuple( + (spaces.Text(256), spaces.Text(256), spaces.Text(256), spaces.Text(256)) + ), # event is a tuple of four str + } + ) + return space diff --git a/metagpt/environment/stanford_town/stanford_town_env.py b/metagpt/environment/stanford_town/stanford_town_env.py new file mode 100644 index 0000000000000000000000000000000000000000..af8a882b2df27f90b8255f6addbde51bed26878a --- /dev/null +++ b/metagpt/environment/stanford_town/stanford_town_env.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG StanfordTown Env + +from metagpt.environment.base_env import Environment +from metagpt.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv + + +class StanfordTownEnv(StanfordTownExtEnv, Environment): + pass diff --git a/metagpt/environment/stanford_town/stanford_town_ext_env.py b/metagpt/environment/stanford_town/stanford_town_ext_env.py new file mode 100644 index 0000000000000000000000000000000000000000..30a02d4dbed922e45d0d201e59b1c59550c320c9 --- /dev/null +++ b/metagpt/environment/stanford_town/stanford_town_ext_env.py @@ -0,0 +1,451 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The StanfordTown external environment to interact with the web interface +# refs to `generative_agents maze.py` + +import math +from pathlib
import Path +from typing import Any, Optional + +from pydantic import ConfigDict, Field, model_validator + +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.stanford_town.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, + EnvObsValType, + get_action_space, + get_observation_space, +) +from metagpt.utils.common import read_csv_to_list, read_json_file + + +class StanfordTownExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + maze_asset_path: Optional[Path] = Field(default=None, description="the path to store maze assets") + maze_width: int = Field(default=140, description="maze map width") + maze_height: int = Field(default=100, description="maze map height") + sq_tile_size: int = Field(default=32, description="the pixel height/width of a tile") + special_constraint: str = Field( + default="", description="a string description of any relevant special constraints " "the world might have" + ) + tiles: list[list[dict]] = Field(default=[]) + address_tiles: dict[str, set] = Field(default=dict()) + collision_maze: list[list] = Field(default=[]) + + @model_validator(mode="before") + @classmethod + def _init_maze(cls, values): + maze_asset_path = values["maze_asset_path"] + assert maze_asset_path + maze_asset_path = Path(maze_asset_path) + + maze_matrix_path = maze_asset_path.joinpath("matrix") + meta_info = read_json_file(maze_matrix_path.joinpath("maze_meta_info.json")) + + maze_width = int(meta_info["maze_width"]) + maze_height = int(meta_info["maze_height"]) + values["maze_width"] = maze_width + values["maze_height"] = maze_height + values["sq_tile_size"] = int(meta_info["sq_tile_size"]) + values["special_constraint"] = meta_info["special_constraint"] + + # READING IN SPECIAL BLOCKS + # Special blocks are those that are colored in the Tiled map. + # Here is an example row for the arena block file: + # e.g., "25331, Double Studio, Studio, Bedroom 2, Painting" + + blocks_folder = maze_matrix_path.joinpath("special_blocks") + + _wb = blocks_folder.joinpath("world_blocks.csv") + wb_rows = read_csv_to_list(_wb, header=False) + wb = wb_rows[0][-1] + + _sb = blocks_folder.joinpath("sector_blocks.csv") + sb_rows = read_csv_to_list(_sb, header=False) + sb_dict = dict() + for i in sb_rows: + sb_dict[i[0]] = i[-1] + + _ab = blocks_folder.joinpath("arena_blocks.csv") + ab_rows = read_csv_to_list(_ab, header=False) + ab_dict = dict() + for i in ab_rows: + ab_dict[i[0]] = i[-1] + + _gob = blocks_folder.joinpath("game_object_blocks.csv") + gob_rows = read_csv_to_list(_gob, header=False) + gob_dict = dict() + for i in gob_rows: + gob_dict[i[0]] = i[-1] + + _slb = blocks_folder.joinpath("spawning_location_blocks.csv") + slb_rows = read_csv_to_list(_slb, header=False) + slb_dict = dict() + for i in slb_rows: + slb_dict[i[0]] = i[-1] + + # [SECTION 3] Reading in the matrices + # These are typical two-dimensional matrices, made up of 0s and + # the numbers that represent the color blocks from the blocks folder.
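+ # (Concretely: with the default 140x100 maze, each csv read below is one
+ # flat row of 140 * 100 = 14000 cells, re-folded later into 100 rows of
+ # 140 entries each.)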
+ maze_folder = maze_matrix_path.joinpath("maze") + + _cm = maze_folder.joinpath("collision_maze.csv") + collision_maze_raw = read_csv_to_list(_cm, header=False)[0] + _sm = maze_folder.joinpath("sector_maze.csv") + sector_maze_raw = read_csv_to_list(_sm, header=False)[0] + _am = maze_folder.joinpath("arena_maze.csv") + arena_maze_raw = read_csv_to_list(_am, header=False)[0] + _gom = maze_folder.joinpath("game_object_maze.csv") + game_object_maze_raw = read_csv_to_list(_gom, header=False)[0] + _slm = maze_folder.joinpath("spawning_location_maze.csv") + spawning_location_maze_raw = read_csv_to_list(_slm, header=False)[0] + + # Loading the maze. The mazes are taken directly from the json exports of + # Tiled maps. They should be in csv format. + # Importantly, they are "not" in a 2-d matrix format -- they are single + # row matrices with the length of width x height of the maze. So we need + # to convert here. + # example format: [['0', '0', ... '25309', '0',...], ['0',...]...] + # 25309 is the collision bar number right now. + collision_maze = [] + sector_maze = [] + arena_maze = [] + game_object_maze = [] + spawning_location_maze = [] + for i in range(0, len(collision_maze_raw), maze_width): + tw = maze_width + collision_maze += [collision_maze_raw[i : i + tw]] + sector_maze += [sector_maze_raw[i : i + tw]] + arena_maze += [arena_maze_raw[i : i + tw]] + game_object_maze += [game_object_maze_raw[i : i + tw]] + spawning_location_maze += [spawning_location_maze_raw[i : i + tw]] + values["collision_maze"] = collision_maze + + tiles = [] + for i in range(maze_height): + row = [] + for j in range(maze_width): + tile_details = dict() + tile_details["world"] = wb + + tile_details["sector"] = "" + if sector_maze[i][j] in sb_dict: + tile_details["sector"] = sb_dict[sector_maze[i][j]] + + tile_details["arena"] = "" + if arena_maze[i][j] in ab_dict: + tile_details["arena"] = ab_dict[arena_maze[i][j]] + + tile_details["game_object"] = "" + if game_object_maze[i][j] in gob_dict: + tile_details["game_object"] = gob_dict[game_object_maze[i][j]] + + tile_details["spawning_location"] = "" + if spawning_location_maze[i][j] in slb_dict: + tile_details["spawning_location"] = slb_dict[spawning_location_maze[i][j]] + + tile_details["collision"] = False + if collision_maze[i][j] != "0": + tile_details["collision"] = True + + tile_details["events"] = set() + + row += [tile_details] + tiles += [row] + values["tiles"] = tiles + + # Each game object occupies an event in the tile. We are setting up the + # default event value here. + for i in range(maze_height): + for j in range(maze_width): + if tiles[i][j]["game_object"]: + object_name = ":".join( + [tiles[i][j]["world"], tiles[i][j]["sector"], tiles[i][j]["arena"], tiles[i][j]["game_object"]] + ) + go_event = (object_name, None, None, None) + tiles[i][j]["events"].add(go_event) + + # Reverse tile access. + # -- given a string address, we return a set of all + # tile coordinates belonging to that address (this is opposite of + # tiles that give you the string address given a coordinate). This is + # an optimization component for finding paths for the personas' movement. 
+ # address_tiles['bedroom-2-a'] == {(58, 9)} + # address_tiles['double studio:recreation:pool table'] + # == {(29, 14), (31, 11), (30, 14), (32, 11), ...}, + address_tiles = dict() + for i in range(maze_height): + for j in range(maze_width): + addresses = [] + if tiles[i][j]["sector"]: + add = f'{tiles[i][j]["world"]}:' + add += f'{tiles[i][j]["sector"]}' + addresses += [add] + if tiles[i][j]["arena"]: + add = f'{tiles[i][j]["world"]}:' + add += f'{tiles[i][j]["sector"]}:' + add += f'{tiles[i][j]["arena"]}' + addresses += [add] + if tiles[i][j]["game_object"]: + add = f'{tiles[i][j]["world"]}:' + add += f'{tiles[i][j]["sector"]}:' + add += f'{tiles[i][j]["arena"]}:' + add += f'{tiles[i][j]["game_object"]}' + addresses += [add] + if tiles[i][j]["spawning_location"]: + add = f'{tiles[i][j]["spawning_location"]}' + addresses += [add] + + for add in addresses: + if add in address_tiles: + address_tiles[add].add((j, i)) + else: + address_tiles[add] = set([(j, i)]) + values["address_tiles"] = address_tiles + + values["action_space"] = get_action_space((maze_width, maze_height)) + values["observation_space"] = get_observation_space() + return values + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, EnvObsValType], dict[str, Any]]: + """reset env and get the init observation + Return results corresponding to `observation, info` + """ + super().reset(seed=seed, options=options) + + obs = self._get_obs() + + return obs, {} + + def _get_obs(self) -> dict[str, EnvObsValType]: + """Get observation""" + return { + "collision_maze": self.get_collision_maze(), + "tiles": self.tiles, + "address_tiles": self.get_address_tiles(), + } + + def observe(self, obs_params: Optional[EnvObsParams] = None) -> Any: + """Get partial or full observation from the env""" + obs_type = obs_params.obs_type if obs_params else EnvObsType.NONE + if obs_type == EnvObsType.NONE: + obs = self._get_obs() + elif obs_type == EnvObsType.GET_TITLE: + obs = self.access_tile(tile=obs_params.coord) + elif obs_type == EnvObsType.TILE_PATH: + obs = self.get_tile_path(tile=obs_params.coord, level=obs_params.level) + elif obs_type == EnvObsType.TILE_NBR: + obs = self.get_nearby_tiles(tile=obs_params.coord, vision_r=obs_params.vision_radius) + return obs + + def step(self, action: EnvAction) -> tuple[dict[str, EnvObsValType], float, bool, bool, dict[str, Any]]: + """Execute action and then return observation + Return results corresponding to `observation, reward, terminated, truncated, info` + """ + terminated = False + try: + self._execute_env_action(action) + except Exception: + terminated = True + + obs = self._get_obs() + + ret = (obs, 1.0, terminated, False, {}) + return ret + + def _execute_env_action(self, action: EnvAction): + action_type = action.action_type + if action_type == EnvActionType.NONE: + pass + elif action_type == EnvActionType.ADD_TILE_EVENT: + self.add_event_from_tile(curr_event=action.event, tile=action.coord) + elif action_type == EnvActionType.RM_TILE_EVENT: + self.remove_event_from_tile(curr_event=action.event, tile=action.coord) + elif action_type == EnvActionType.TURN_TILE_EVENT_IDLE: + self.turn_event_from_tile_idle(curr_event=action.event, tile=action.coord) + elif action_type == EnvActionType.RM_TITLE_SUB_EVENT: + self.remove_subject_events_from_tile(subject=action.subject, tile=action.coord) + + def turn_coordinate_to_tile(self, px_coordinate: tuple[int, int]) -> tuple[int, int]: + """ + Turns a pixel coordinate to a tile coordinate. 
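+
+ For example, with the default sq_tile_size of 32, the pixel coordinate
+ (64, 96) maps to tile (2, 3), since ceil(64 / 32) = 2 and ceil(96 / 32) = 3.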
+ """ + x = math.ceil(px_coordinate[0] / self.sq_tile_size) + y = math.ceil(px_coordinate[1] / self.sq_tile_size) + return x, y + + @mark_as_readable + def get_collision_maze(self) -> list: + return self.collision_maze + + @mark_as_readable + def get_address_tiles(self) -> dict: + return self.address_tiles + + @mark_as_readable + def access_tile(self, tile: tuple[int, int]) -> dict: + """ + Returns the tiles details dictionary that is stored in self.tiles of the + designated x, y location. + + INPUT + tile: The tile coordinate of our interest in (x, y) form. + OUTPUT + The tile detail dictionary for the designated tile. + EXAMPLE OUTPUT + Given (58, 9), + self.tiles[9][58] = {'world': 'double studio', + 'sector': 'double studio', 'arena': 'bedroom 2', + 'game_object': 'bed', 'spawning_location': 'bedroom-2-a', + 'collision': False, + 'events': {('double studio:double studio:bedroom 2:bed', + None, None)}} + """ + x = tile[0] + y = tile[1] + return self.tiles[y][x] + + @mark_as_readable + def get_tile_path(self, tile: tuple[int, int], level: str) -> str: + """ + Get the tile string address given its coordinate. You designate the level + by giving it a string level description. + + INPUT: + tile: The tile coordinate of our interest in (x, y) form. + level: world, sector, arena, or game object + OUTPUT + The string address for the tile. + EXAMPLE OUTPUT + Given tile=(58, 9), and level=arena, + "double studio:double studio:bedroom 2" + """ + x = tile[0] + y = tile[1] + tile = self.tiles[y][x] + + path = f"{tile['world']}" + if level == "world": + return path + else: + path += f":{tile['sector']}" + + if level == "sector": + return path + else: + path += f":{tile['arena']}" + + if level == "arena": + return path + else: + path += f":{tile['game_object']}" + + return path + + @mark_as_readable + def get_nearby_tiles(self, tile: tuple[int, int], vision_r: int) -> list[tuple[int, int]]: + """ + Given the current tile and vision_r, return a list of tiles that are + within the radius. Note that this implementation looks at a square + boundary when determining what is within the radius. + i.e., for vision_r, returns x's. + x x x x x + x x x x x + x x P x x + x x x x x + x x x x x + + INPUT: + tile: The tile coordinate of our interest in (x, y) form. + vision_r: The radius of the persona's vision. + OUTPUT: + nearby_tiles: a list of tiles that are within the radius. + """ + left_end = 0 + if tile[0] - vision_r > left_end: + left_end = tile[0] - vision_r + + right_end = self.maze_width - 1 + if tile[0] + vision_r + 1 < right_end: + right_end = tile[0] + vision_r + 1 + + bottom_end = self.maze_height - 1 + if tile[1] + vision_r + 1 < bottom_end: + bottom_end = tile[1] + vision_r + 1 + + top_end = 0 + if tile[1] - vision_r > top_end: + top_end = tile[1] - vision_r + + nearby_tiles = [] + for i in range(left_end, right_end): + for j in range(top_end, bottom_end): + nearby_tiles += [(i, j)] + return nearby_tiles + + @mark_as_writeable + def add_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """ + Add an event triple to a tile. + + INPUT: + curr_event: Current event triple. + e.g., ('double studio:double studio:bedroom 2:bed', None, + None) + tile: The tile coordinate of our interest in (x, y) form. + OUPUT: + None + """ + self.tiles[tile[1]][tile[0]]["events"].add(curr_event) + + @mark_as_writeable + def remove_event_from_tile(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + """dswaq + Remove an event triple from a tile. 
+ + INPUT: + curr_event: Current event triple. + e.g., ('double studio:double studio:bedroom 2:bed', None, + None, None) + tile: The tile coordinate of our interest in (x, y) form. + OUTPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event == curr_event: + self.tiles[tile[1]][tile[0]]["events"].remove(event) + + @mark_as_writeable + def turn_event_from_tile_idle(self, curr_event: tuple[str], tile: tuple[int, int]) -> None: + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event == curr_event: + self.tiles[tile[1]][tile[0]]["events"].remove(event) + new_event = (event[0], None, None, None) + self.tiles[tile[1]][tile[0]]["events"].add(new_event) + + @mark_as_writeable + def remove_subject_events_from_tile(self, subject: str, tile: tuple[int, int]) -> None: + """ + Remove an event triple that has the input subject from a tile. + + INPUT: + subject: "Isabella Rodriguez" + tile: The tile coordinate of our interest in (x, y) form. + OUTPUT: + None + """ + curr_tile_ev_cp = self.tiles[tile[1]][tile[0]]["events"].copy() + for event in curr_tile_ev_cp: + if event[0] == subject: + self.tiles[tile[1]][tile[0]]["events"].remove(event) diff --git a/metagpt/environment/werewolf/__init__.py b/metagpt/environment/werewolf/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/environment/werewolf/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/environment/werewolf/const.py b/metagpt/environment/werewolf/const.py new file mode 100644 index 0000000000000000000000000000000000000000..7f810389da323406512097b2a82193f97502e254 --- /dev/null +++ b/metagpt/environment/werewolf/const.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from enum import Enum + +from metagpt.const import MESSAGE_ROUTE_TO_ALL + + +class RoleType(Enum): + VILLAGER = "Villager" + WEREWOLF = "Werewolf" + GUARD = "Guard" + SEER = "Seer" + WITCH = "Witch" + MODERATOR = "Moderator" + + +class RoleState(Enum): + ALIVE = "alive" # the role is alive + DEAD = "dead" # killed or poisoned + KILLED = "killed" # killed by werewolf or voting + POISONED = "poisoned" # killed by poison + SAVED = "saved" # saved by antidote + PROTECTED = "protected" # protected by guard + + +class RoleActionRes(Enum): + SAVE = "save" + PASS = "pass" # ignore current action output + + +empty_set = set() + +# the ordered instructions the moderator announces to everyone at each step +STEP_INSTRUCTIONS = { + 0: { + "content": "It’s dark, everyone close your eyes. I will talk with you/your team secretly at night.", + "send_to": {RoleType.MODERATOR.value}, # for moderator to continue speaking + "restricted_to": empty_set, + }, + 1: { + "content": "Guard, please open your eyes!", + "send_to": {RoleType.MODERATOR.value}, # for moderator to continue speaking + "restricted_to": empty_set, + }, + 2: { + "content": """Guard, now tell me who you protect tonight? +You only choose one from the following living options please: {living_players}. +Or you can pass.
For example: Protect ...""", + "send_to": {RoleType.GUARD.value}, + "restricted_to": {RoleType.MODERATOR.value, RoleType.GUARD.value}, + }, + 3: {"content": "Guard, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set}, + 4: { + "content": "Werewolves, please open your eyes!", + "send_to": {RoleType.MODERATOR.value}, + "restricted_to": empty_set, + }, + 5: { + "content": """Werewolves, I secretly tell you that {werewolf_players} are +all of the {werewolf_num} werewolves! Keep in mind you are teammates. The rest of the players are not werewolves. +Choose one from the following living options please: +{living_players}. For example: Kill ...""", + "send_to": {RoleType.WEREWOLF.value}, + "restricted_to": {RoleType.MODERATOR.value, RoleType.WEREWOLF.value}, + }, + 6: {"content": "Werewolves, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set}, + 7: {"content": "Witch, please open your eyes!", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set}, + 8: { + "content": """Witch, tonight {player_hunted} has been killed by the werewolves. +You have a bottle of antidote, would you like to save him/her? If so, say "Save", else, say "Pass".""", + "send_to": {RoleType.WITCH.value}, + "restricted_to": {RoleType.MODERATOR.value, RoleType.WITCH.value}, + }, # first check whether the witch still has the antidote, then ask whether she wants to use it to save the hunted player + 9: { + "content": """Witch, you also have a bottle of poison, would you like to use it to kill one of the living players? +Choose one from the following living options: {living_players}. +If so, say ONLY "Poison PlayerX", replace PlayerX with the actual player name, else, say "Pass".""", + "send_to": {RoleType.WITCH.value}, + "restricted_to": {RoleType.MODERATOR.value, RoleType.WITCH.value}, + }, + 10: {"content": "Witch, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set}, + 11: {"content": "Seer, please open your eyes!", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set}, + 12: { + "content": """Seer, you can check one player's identity. Whose identity are you going to verify tonight? +Choose only one from the following living options: {living_players}.""", + "send_to": {RoleType.SEER.value}, + "restricted_to": {RoleType.MODERATOR.value, RoleType.SEER.value}, + }, + 13: {"content": "Seer, close your eyes", "send_to": {RoleType.MODERATOR.value}, "restricted_to": empty_set}, + # The 1st daytime + 14: { + "content": """It's daytime. Everyone woke up except those who had been killed.""", + "send_to": {RoleType.MODERATOR.value}, + "restricted_to": empty_set, + }, + 15: { + "content": "{player_current_dead} was killed last night!", + "send_to": {RoleType.MODERATOR.value}, + "restricted_to": empty_set, + }, + 16: { + "content": """Living players: {living_players}, now freely talk about the current situation based on your observation and +reflection with a few sentences. Decide whether to reveal your identity based on your reflection.""", + "send_to": {MESSAGE_ROUTE_TO_ALL}, # send to all to speak in daytime + "restricted_to": empty_set, + }, + 17: { + "content": """Now vote and tell me who you think is the werewolf. Don’t mention your role. +You only choose one from the following living options please: +{living_players}.
Say ONLY: I vote to eliminate ...""", + "send_to": {MESSAGE_ROUTE_TO_ALL}, + "restricted_to": empty_set, + }, + 18: { + "content": """{player_current_dead} was eliminated.""", + "send_to": {RoleType.MODERATOR.value}, + "restricted_to": empty_set, + }, +} diff --git a/metagpt/environment/werewolf/env_space.py b/metagpt/environment/werewolf/env_space.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6ceeabed944f17bc24a4108a9840367eb9a496 --- /dev/null +++ b/metagpt/environment/werewolf/env_space.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : werewolf observation/action space and its action definition + +from gymnasium import spaces +from pydantic import ConfigDict, Field + +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvActionType +from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS + + +class EnvActionType(BaseEnvActionType): + NONE = 0 # no action to run, just get observation + WOLF_KILL = 1 # wolf kill someone + VOTE_KILL = 2 # vote kill someone + WITCH_POISON = 3 # witch poison someone + WITCH_SAVE = 4 # witch save someone + GUARD_PROTECT = 5 # guard protect someone + PROGRESS_STEP = 6 # step increment + + +class EnvAction(BaseEnvAction): + model_config = ConfigDict(arbitrary_types_allowed=True) + + action_type: int = Field(default=EnvActionType.NONE, description="action type") + player_name: str = Field(default="", description="the name of the player performing the action") + target_player_name: str = Field(default="", description="the name of the player targeted by the action") + + +def get_observation_space() -> spaces.Dict: + space = spaces.Dict( + { + "game_setup": spaces.Text(256), + "step_idx": spaces.Discrete(len(STEP_INSTRUCTIONS)), + "living_players": spaces.Tuple( + (spaces.Text(16), spaces.Text(16)) + ), # TODO should be tuple of variable length + "werewolf_players": spaces.Tuple( + (spaces.Text(16), spaces.Text(16)) + ), # TODO should be tuple of variable length + "player_hunted": spaces.Text(16), + "player_current_dead": spaces.Tuple( + (spaces.Text(16), spaces.Text(16)) + ), # TODO should be tuple of variable length + "witch_poison_left": spaces.Discrete(2), + "witch_antidote_left": spaces.Discrete(2), + "winner": spaces.Text(16), + "win_reason": spaces.Text(64), + } + ) + return space + + +def get_action_space() -> spaces.Dict: + space = spaces.Dict( + { + "action_type": spaces.Discrete(len(EnvActionType)), + "player_name": spaces.Text(16), # the player performing the action + "target_player_name": spaces.Text(16), # the player targeted by the action + } + ) + return space diff --git a/metagpt/environment/werewolf/werewolf_env.py b/metagpt/environment/werewolf/werewolf_env.py new file mode 100644 index 0000000000000000000000000000000000000000..999ff63a1cf6cd77b931d787266504a5fef312c0 --- /dev/null +++ b/metagpt/environment/werewolf/werewolf_env.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : MG Werewolf Env + +from typing import Iterable + +from pydantic import Field + +from metagpt.environment.base_env import Environment +from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv +from metagpt.schema import Message + + +class WerewolfEnv(WerewolfExtEnv, Environment): + round_cnt: int = Field(default=0) + + def add_roles(self, roles: Iterable["Role"]): + """Add a batch of roles to the current environment.""" + for role in roles: + self.roles[role.name] = role # use name as key here, since multiple players can share the same profile + + for
role in roles: # setup system message with roles + role.context = self.context + role.set_env(self) + + def publish_message(self, message: Message, add_timestamp: bool = True): + """Post information to the current environment""" + if add_timestamp: + # Message content may repeat across rounds (e.g. the same player is killed on two different nights), + # so a unique round_cnt prefix is added to keep identical messages from being deduplicated in memory. + message.content = f"{self.round_cnt} | " + message.content + super().publish_message(message) + + async def run(self, k=1): + """Process all Role runs by order""" + for _ in range(k): + for role in self.roles.values(): + await role.run() + self.round_cnt += 1 diff --git a/metagpt/environment/werewolf/werewolf_ext_env.py b/metagpt/environment/werewolf/werewolf_ext_env.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ecb8b345563be31628fbd7c0edda4ca12a913d --- /dev/null +++ b/metagpt/environment/werewolf/werewolf_ext_env.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : The werewolf game external environment + +import random +from collections import Counter +from typing import Any, Callable, Optional + +from pydantic import ConfigDict, Field + +from metagpt.base.base_env_space import BaseEnvObsParams +from metagpt.environment.base_env import ExtEnv, mark_as_readable, mark_as_writeable +from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS, RoleState, RoleType +from metagpt.environment.werewolf.env_space import EnvAction, EnvActionType +from metagpt.logs import logger + + +class WerewolfExtEnv(ExtEnv): + model_config = ConfigDict(arbitrary_types_allowed=True) + + players_state: dict[str, tuple[str, RoleState]] = Field( + default_factory=dict, description="the player's role type and state by player_name" + ) + + round_idx: int = Field(default=0) # the current round + step_idx: int = Field(default=0) # the current step of current round + eval_step_idx: list[int] = Field(default=[]) + per_round_steps: int = Field(default=len(STEP_INSTRUCTIONS)) + + # game global states + game_setup: str = Field(default="", description="game setup including each role and its count") + special_role_players: list[str] = Field(default=[]) + winner: Optional[str] = Field(default=None) + win_reason: Optional[str] = Field(default=None) + witch_poison_left: int = Field(default=1, description="should be 1 or 0") + witch_antidote_left: int = Field(default=1, description="should be 1 or 0") + + # states of the current round; a round runs from one nightfall (eyes closed) to the next + round_hunts: dict[str, str] = Field(default_factory=dict, description="nighttime wolf hunt result") + round_votes: dict[str, str] = Field( + default_factory=dict, description="daytime all players vote result, key=voter, value=voted one" + ) + player_hunted: Optional[str] = Field(default=None) + player_protected: Optional[str] = Field(default=None) + is_hunted_player_saved: bool = Field(default=False) + player_poisoned: Optional[str] = Field(default=None) + player_current_dead: list[str] = Field(default=[]) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict[str, Any]] = None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """currently unused""" + pass + + def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any: + """currently unused""" + pass + + def _get_obs(self): + return { + "game_setup": self.game_setup,
"step_idx": self.step_idx, + "living_players": self.living_players, + "werewolf_players": self.werewolf_players, # currently, lack observation isolation + "player_hunted": self.player_hunted, + "player_current_dead": self.player_current_dead, + "witch_poison_left": self.witch_poison_left, + "witch_antidote_left": self.witch_antidote_left, + "winner": self.winner, + "win_reason": self.win_reason, + } + + def step(self, action: EnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + action_type = action.action_type + player_name = action.player_name + target_player_name = action.target_player_name + if action_type == EnvActionType.WOLF_KILL: + self.wolf_kill_someone(wolf_name=player_name, player_name=target_player_name) + elif action_type == EnvActionType.VOTE_KILL: + self.vote_kill_someone(voter_name=player_name, player_name=target_player_name) + elif action_type == EnvActionType.WITCH_POISON: + self.witch_poison_someone(witch_name=player_name, player_name=target_player_name) + elif action_type == EnvActionType.WITCH_SAVE: + self.witch_save_someone(witch_name=player_name, player_name=target_player_name) + elif action_type == EnvActionType.GUARD_PROTECT: + self.guard_protect_someone(guard_name=player_name, player_name=target_player_name) + elif action_type == EnvActionType.PROGRESS_STEP: + self.progress_step() + elif action_type == EnvActionType.NONE: + pass + else: + raise ValueError(f"not supported action_type: {action_type}") + + self.update_game_states() + terminated = self._check_game_finish() + obs = self._get_obs() + return obs, 1.0, terminated, False, {} + + def _check_game_finish(self) -> bool: + """return True if game finished else False""" + # game's termination condition + terminated = False + living_werewolf = [p for p in self.werewolf_players if p in self.living_players] + living_villagers = [p for p in self.villager_players if p in self.living_players] + living_special_roles = [p for p in self.special_role_players if p in self.living_players] + if not living_werewolf: + self.winner = "good guys" + self.win_reason = "werewolves all dead" + terminated = True + elif not living_villagers or not living_special_roles: + self.winner = "werewolf" + self.win_reason = "villagers all dead" if not living_villagers else "special roles all dead" + terminated = True + return terminated + + @property + def living_players(self) -> list[str]: + player_names = [] + for name, roletype_state in self.players_state.items(): + if roletype_state[1] in [RoleState.ALIVE, RoleState.SAVED]: + player_names.append(name) + return player_names + + def _role_type_players(self, role_type: str) -> list[str]: + """return player name of particular role type""" + player_names = [] + for name, roletype_state in self.players_state.items(): + if role_type in roletype_state[0]: + player_names.append(name) + return player_names + + @property + def werewolf_players(self) -> list[str]: + player_names = self._role_type_players(role_type=RoleType.WEREWOLF.value) + return player_names + + @property + def villager_players(self) -> list[str]: + player_names = self._role_type_players(role_type=RoleType.VILLAGER.value) + return player_names + + def _init_players_state(self, players: list["Role"]): + for play in players: + self.players_state[play.name] = (play.profile, RoleState.ALIVE) + + self.special_role_players = [ + p for p in self.living_players if p not in self.werewolf_players + self.villager_players + ] + + def init_game_setup( + self, + role_uniq_objs: list[object], + num_villager: int = 2, + 
num_werewolf: int = 2, + shuffle=True, + add_human=False, + use_reflection=True, + use_experience=False, + use_memory_selection=False, + new_experience_version="", + prepare_human_player: Optional[Callable] = None, + ) -> tuple[str, list]: + """Initialize players according to the number of each role.""" + role_objs = [] + for role_obj in role_uniq_objs: + if RoleType.VILLAGER.value in str(role_obj): + role_objs.extend([role_obj] * num_villager) + elif RoleType.WEREWOLF.value in str(role_obj): + role_objs.extend([role_obj] * num_werewolf) + else: + role_objs.append(role_obj) + if shuffle: + random.shuffle(role_objs) + if add_human: + assigned_role_idx = random.randint(0, len(role_objs) - 1) + assigned_role = role_objs[assigned_role_idx] + role_objs[assigned_role_idx] = prepare_human_player(assigned_role) # TODO + + players = [ + role( + name=f"Player{i + 1}", + use_reflection=use_reflection, + use_experience=use_experience, + use_memory_selection=use_memory_selection, + new_experience_version=new_experience_version, + ) + for i, role in enumerate(role_objs) + ] + + if add_human: + logger.info(f"You are assigned {players[assigned_role_idx].name}({players[assigned_role_idx].profile})") + + game_setup = ["Game setup:"] + [f"{player.name}: {player.profile}," for player in players] + self.game_setup = "\n".join(game_setup) + + self._init_players_state(players) # init players state + + return self.game_setup, players + + def _update_players_state(self, player_names: list[str], state: RoleState = RoleState.KILLED): + for player_name in player_names: + if player_name in self.players_state: + roletype_state = self.players_state[player_name] + self.players_state[player_name] = (roletype_state[0], state) + + def _check_valid_role(self, player_name: str, role_type: str) -> bool: + roletype_state = self.players_state.get(player_name) + return True if roletype_state and role_type in roletype_state[0] else False + + def _check_player_continue(self, player_name: str, particular_step: int = -1) -> bool: + """check whether the operation can be applied to the player""" + step_idx = self.step_idx % self.per_round_steps + if particular_step > 0 and step_idx != particular_step: # step no + # particular_step = 18, not daytime vote time, ignore + # particular_step = 15, not nighttime hunt time, ignore + return False + if player_name not in self.living_players: + return False + return True + + @mark_as_readable + def curr_step_instruction(self) -> dict: + step_idx = self.step_idx % len(STEP_INSTRUCTIONS) + instruction = STEP_INSTRUCTIONS[step_idx] + self.step_idx += 1 + return instruction + + @mark_as_writeable + def progress_step(self): + self.step_idx += 1 + + @mark_as_readable + def get_players_state(self, player_names: list[str]) -> dict[str, RoleState]: + players_state = { + player_name: self.players_state[player_name][1] # only return role state + for player_name in player_names + if player_name in self.players_state + } + return players_state + + @mark_as_writeable + def vote_kill_someone(self, voter_name: str, player_name: str = None): + """Record a player's vote at daytime. + player_name: if None, the voter abstains. + """ + if not self._check_player_continue(voter_name, particular_step=18): # 18=step no + return + + self.round_votes[voter_name] = player_name + # check if all living players finish voting, then get the dead one + if list(self.round_votes.keys()) == self.living_players: + voted_all = list(self.round_votes.values()) # TODO in case of tie vote, check who was voted first + voted_all = [item for item in voted_all if item] +
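# most_common()[0][0] picks the most-voted player; on a tie, Counter keeps
+ # first-insertion order, so the earlier-voted name wins (see TODO above).
+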
self.player_current_dead = [Counter(voted_all).most_common()[0][0]] + self._update_players_state(self.player_current_dead) + + @mark_as_writeable + def wolf_kill_someone(self, wolf_name: str, player_name: str): + if not self._check_valid_role(wolf_name, RoleType.WEREWOLF.value): + return + if not self._check_player_continue(wolf_name, particular_step=6): # 6=step no + return + + self.round_hunts[wolf_name] = player_name + # living_werewolf = [p for p in self.werewolf_players if p in self.living_players] + # check if all living wolves finish hunting, then get the hunted one + # if list(self.round_hunts.keys()) == living_werewolf: + # hunted_all = list(self.round_hunts.values()) + # self.player_hunted = Counter(hunted_all).most_common()[0][0] + self.player_hunted = player_name + + def _witch_poison_or_save_someone( + self, witch_name: str, player_name: str = None, state: RoleState = RoleState.POISONED + ): + if not self._check_valid_role(witch_name, RoleType.WITCH.value): + return + if not self._check_player_continue(player_name): + return + + assert state in [RoleState.POISONED, RoleState.SAVED] + self._update_players_state([player_name], state) + if state == RoleState.POISONED: + self.player_poisoned = player_name + self.witch_poison_left -= 1 + else: + # self.player_protected = player_name + self.is_hunted_player_saved = True + self.witch_antidote_left -= 1 + + @mark_as_writeable + def witch_poison_someone(self, witch_name: str, player_name: str = None): + self._witch_poison_or_save_someone(witch_name, player_name, RoleState.POISONED) + + @mark_as_writeable + def witch_save_someone(self, witch_name: str, player_name: str = None): + self._witch_poison_or_save_someone(witch_name, player_name, RoleState.SAVED) + + @mark_as_writeable + def guard_protect_someone(self, guard_name: str, player_name: str = None): + if not self._check_valid_role(guard_name, RoleType.GUARD.value): + return + if not self._check_player_continue(player_name): + return + self.player_protected = player_name + + @mark_as_writeable + def update_game_states(self): + step_idx = self.step_idx % self.per_round_steps + if step_idx not in [15, 18] or self.step_idx in self.eval_step_idx: + return + else: + self.eval_step_idx.append(self.step_idx) # record evaluation, avoid repetitive evaluation at the same step + + if step_idx == 15: # step no + # night ends: after all special roles acted, process the whole night + self.player_current_dead = [] # reset + + if self.player_hunted != self.player_protected and not self.is_hunted_player_saved: + self.player_current_dead.append(self.player_hunted) + if self.player_poisoned: + self.player_current_dead.append(self.player_poisoned) + + self._update_players_state(self.player_current_dead) + # reset + self.player_hunted = None + self.player_protected = None + self.is_hunted_player_saved = False + self.player_poisoned = None + elif step_idx == 18: + # already handled by vote_kill_someone + pass diff --git a/metagpt/exp_pool/.DS_Store b/metagpt/exp_pool/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8cf43084c2765bf4ab919f31322bdc369c3ee1c8 Binary files /dev/null and b/metagpt/exp_pool/.DS_Store differ diff --git a/metagpt/exp_pool/__init__.py b/metagpt/exp_pool/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97d45a278b5fe08cb79c28fc9c5fdf1546087494 --- /dev/null +++ b/metagpt/exp_pool/__init__.py @@ -0,0 +1,6 @@ +"""Experience pool init.""" + +from metagpt.exp_pool.manager import get_exp_manager +from metagpt.exp_pool.decorator import 
exp_cache + +__all__ = ["get_exp_manager", "exp_cache"] diff --git a/metagpt/exp_pool/__pycache__/__init__.cpython-310.pyc b/metagpt/exp_pool/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33f994dce76920c3c03d4c11c11efd70132a3d6 Binary files /dev/null and b/metagpt/exp_pool/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/exp_pool/__pycache__/__init__.cpython-39.pyc b/metagpt/exp_pool/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa75fd23c184191fe990c5585624230b32c76fe5 Binary files /dev/null and b/metagpt/exp_pool/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/exp_pool/__pycache__/decorator.cpython-310.pyc b/metagpt/exp_pool/__pycache__/decorator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0442d3138a9776e9f421d107bb93e5efdda34dbd Binary files /dev/null and b/metagpt/exp_pool/__pycache__/decorator.cpython-310.pyc differ diff --git a/metagpt/exp_pool/__pycache__/decorator.cpython-39.pyc b/metagpt/exp_pool/__pycache__/decorator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beed446550a1156c6e622f4998f7ff35d64af0ab Binary files /dev/null and b/metagpt/exp_pool/__pycache__/decorator.cpython-39.pyc differ diff --git a/metagpt/exp_pool/__pycache__/manager.cpython-310.pyc b/metagpt/exp_pool/__pycache__/manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f4022025c6db975b5ad8b5f94f73513503b901b Binary files /dev/null and b/metagpt/exp_pool/__pycache__/manager.cpython-310.pyc differ diff --git a/metagpt/exp_pool/__pycache__/manager.cpython-39.pyc b/metagpt/exp_pool/__pycache__/manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f48f01354afd92f2538ac65b502056fdffed738 Binary files /dev/null and b/metagpt/exp_pool/__pycache__/manager.cpython-39.pyc differ diff --git a/metagpt/exp_pool/__pycache__/schema.cpython-310.pyc b/metagpt/exp_pool/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaf743a55c47499877cc90aa3986877ece49cf10 Binary files /dev/null and b/metagpt/exp_pool/__pycache__/schema.cpython-310.pyc differ diff --git a/metagpt/exp_pool/__pycache__/schema.cpython-39.pyc b/metagpt/exp_pool/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33e35c9a6b6e397db629548496b45497c236345a Binary files /dev/null and b/metagpt/exp_pool/__pycache__/schema.cpython-39.pyc differ diff --git a/metagpt/exp_pool/context_builders/__init__.py b/metagpt/exp_pool/context_builders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..047558be036f43ed120a7534e8cb9bd39bbc6ecf --- /dev/null +++ b/metagpt/exp_pool/context_builders/__init__.py @@ -0,0 +1,7 @@ +"""Context builders init.""" + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder +from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder + +__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"] diff --git a/metagpt/exp_pool/context_builders/__pycache__/__init__.cpython-310.pyc b/metagpt/exp_pool/context_builders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1acde12772f099a1462ac85f318ea7b0fbf18b6d Binary files /dev/null and 
b/metagpt/exp_pool/context_builders/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/__init__.cpython-39.pyc b/metagpt/exp_pool/context_builders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dafd6869ca2c866b11af6091d7c6b67e57a5bda Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/base.cpython-310.pyc b/metagpt/exp_pool/context_builders/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fc5a7e9bc014d4291fb63e9de0942a35c400e6c Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/base.cpython-310.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/base.cpython-39.pyc b/metagpt/exp_pool/context_builders/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71f4cf13447d0e24cdd8ad0c19d588b606ffc0fc Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/base.cpython-39.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/role_zero.cpython-310.pyc b/metagpt/exp_pool/context_builders/__pycache__/role_zero.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41fe3d54caa5e5f4557d8b0702bf83d31b09f986 Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/role_zero.cpython-310.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/role_zero.cpython-39.pyc b/metagpt/exp_pool/context_builders/__pycache__/role_zero.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2cbc962a8c87776a2e0dd663b1381c9456363b8 Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/role_zero.cpython-39.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/simple.cpython-310.pyc b/metagpt/exp_pool/context_builders/__pycache__/simple.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..975bb58016976eeed5082be4434b54dddb02a5df Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/simple.cpython-310.pyc differ diff --git a/metagpt/exp_pool/context_builders/__pycache__/simple.cpython-39.pyc b/metagpt/exp_pool/context_builders/__pycache__/simple.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecbe494729627777249d81dd1ff81daf15f5f966 Binary files /dev/null and b/metagpt/exp_pool/context_builders/__pycache__/simple.cpython-39.pyc differ diff --git a/metagpt/exp_pool/context_builders/action_node.py b/metagpt/exp_pool/context_builders/action_node.py new file mode 100644 index 0000000000000000000000000000000000000000..891b898be35bb5b08305565ce70e3995770e1027 --- /dev/null +++ b/metagpt/exp_pool/context_builders/action_node.py @@ -0,0 +1,30 @@ +"""Action Node context builder.""" + +from typing import Any + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + +ACTION_NODE_CONTEXT_TEMPLATE = """ +{req} + +### Experiences +----- +{exps} +----- + +## Instruction +Consider **Experiences** to generate a better answer. +""" + + +class ActionNodeContextBuilder(BaseContextBuilder): + async def build(self, req: Any) -> str: + """Builds the action node context string. + + If there are no experiences, returns the original `req`; + otherwise returns context with `req` and formatted experiences. 
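To see what the builder above produces, here is a minimal usage sketch (not part of the patch; it assumes an installed `metagpt` and uses invented request and score values):

```python
import asyncio

from metagpt.exp_pool.context_builders.action_node import ActionNodeContextBuilder
from metagpt.exp_pool.schema import Experience, Metric, Score

async def main():
    # One stored experience; values here are invented for the demo.
    exp = Experience(req="2+2?", resp="4", metric=Metric(score=Score(val=9, reason="correct")))
    builder = ActionNodeContextBuilder(exps=[exp])
    # With experiences, the request is wrapped in ACTION_NODE_CONTEXT_TEMPLATE;
    # with builder.exps == [], build() would return "3+3?" unchanged.
    print(await builder.build("3+3?"))

asyncio.run(main())
```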
+ """ + + exps = self.format_exps() + + return ACTION_NODE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req diff --git a/metagpt/exp_pool/context_builders/base.py b/metagpt/exp_pool/context_builders/base.py new file mode 100644 index 0000000000000000000000000000000000000000..691d51c8c515de40e0e6f92ff6bbbd4fc8912d71 --- /dev/null +++ b/metagpt/exp_pool/context_builders/base.py @@ -0,0 +1,41 @@ +"""Base context builder.""" + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Experience + +EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, which scored: {score}.""" + + +class BaseContextBuilder(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + exps: list[Experience] = [] + + @abstractmethod + async def build(self, req: Any) -> Any: + """Build context from req. + + Do not modify `req`. If modification is necessary, use copy.deepcopy to create a copy first. + """ + + def format_exps(self) -> str: + """Format experiences into a numbered list of strings. + + Example: + 1. Given the request: req1, We can get the response: resp1, which scored: 8. + 2. Given the request: req2, We can get the response: resp2, which scored: 9. + + Returns: + str: The formatted experiences as a string. + """ + + result = [] + for i, exp in enumerate(self.exps, start=1): + score_val = exp.metric.score.val if exp.metric and exp.metric.score else "N/A" + result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=score_val)) + + return "\n".join(result) diff --git a/metagpt/exp_pool/context_builders/role_zero.py b/metagpt/exp_pool/context_builders/role_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..cbda72fc58f0e213150363a2c3ff0f48a35100b9 --- /dev/null +++ b/metagpt/exp_pool/context_builders/role_zero.py @@ -0,0 +1,39 @@ +"""RoleZero context builder.""" + +import copy +from typing import Any + +from metagpt.const import EXPERIENCE_MASK +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + + +class RoleZeroContextBuilder(BaseContextBuilder): + async def build(self, req: Any) -> list[dict]: + """Builds the role zero context string. + + Note: + 1. The expected format for `req`, e.g., [{...}, {"role": "user", "content": "context"}]. + 2. Returns the original `req` if it is empty. + 3. Creates a copy of req and replaces the example content in the copied req with actual experiences. 
+ """ + + if not req: + return req + + exps = self.format_exps() + if not exps: + return req + + req_copy = copy.deepcopy(req) + + req_copy[-1]["content"] = self.replace_example_content(req_copy[-1].get("content", ""), exps) + + return req_copy + + def replace_example_content(self, text: str, new_example_content: str) -> str: + return self.fill_experience(text, new_example_content) + + @staticmethod + def fill_experience(text: str, new_example_content: str) -> str: + replaced_text = text.replace(EXPERIENCE_MASK, new_example_content) + return replaced_text diff --git a/metagpt/exp_pool/context_builders/simple.py b/metagpt/exp_pool/context_builders/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..d7b8d0be9afa8c1c8c8b2d8c58941440f2dc55c6 --- /dev/null +++ b/metagpt/exp_pool/context_builders/simple.py @@ -0,0 +1,26 @@ +"""Simple context builder.""" + + +from typing import Any + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + +SIMPLE_CONTEXT_TEMPLATE = """ +## Context + +### Experiences +----- +{exps} +----- + +## User Requirement +{req} + +## Instruction +Consider **Experiences** to generate a better answer. +""" + + +class SimpleContextBuilder(BaseContextBuilder): + async def build(self, req: Any) -> str: + return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=self.format_exps()) diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..bb285d31cd4327609f2bfd7bfdefe9a35196d51c --- /dev/null +++ b/metagpt/exp_pool/decorator.py @@ -0,0 +1,227 @@ +"""Experience Decorator.""" + +import asyncio +import functools +from typing import Any, Callable, Optional, TypeVar + +from pydantic import BaseModel, ConfigDict, model_validator + +from metagpt.config2 import config +from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder +from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager +from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge +from metagpt.exp_pool.schema import ( + LOG_NEW_EXPERIENCE_PREFIX, + Experience, + Metric, + QueryType, + Score, +) +from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer +from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer +from metagpt.logs import logger +from metagpt.utils.async_helper import NestAsyncio +from metagpt.utils.exceptions import handle_exception + +ReturnType = TypeVar("ReturnType") + + +def exp_cache( + _func: Optional[Callable[..., ReturnType]] = None, + query_type: QueryType = QueryType.SEMANTIC, + manager: Optional[ExperienceManager] = None, + scorer: Optional[BaseScorer] = None, + perfect_judge: Optional[BasePerfectJudge] = None, + context_builder: Optional[BaseContextBuilder] = None, + serializer: Optional[BaseSerializer] = None, + tag: Optional[str] = None, +): + """Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience. + + Note: + 1. This can be applied to both synchronous and asynchronous functions. + 2. The function must have a `req` parameter, and it must be provided as a keyword argument. + 3. If `config.exp_pool.enabled` is False, the decorator will just directly execute the function. + 4. If `config.exp_pool.enable_write` is False, the decorator will skip evaluating and saving the experience. + 5. If `config.exp_pool.enable_read` is False, the decorator will skip reading from the experience pool. 
+ + + Args: + _func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache(). + query_type: The type of query to be used when fetching experiences. + manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`. + scorer: Evaluate experience. Default to `SimpleScorer()`. + perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`. + context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`. + serializer: Serializes the request and the function's return value for storage, deserializes the stored response back to the function's return value. Defaults to `SimpleSerializer()`. + tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`. + """ + + def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + @functools.wraps(func) + async def get_or_create(args: Any, kwargs: Any) -> ReturnType: + if not config.exp_pool.enabled: + rsp = func(*args, **kwargs) + return await rsp if asyncio.iscoroutine(rsp) else rsp + + handler = ExpCacheHandler( + func=func, + args=args, + kwargs=kwargs, + query_type=query_type, + exp_manager=manager, + exp_scorer=scorer, + exp_perfect_judge=perfect_judge, + context_builder=context_builder, + serializer=serializer, + tag=tag, + ) + + await handler.fetch_experiences() + + if exp := await handler.get_one_perfect_exp(): + return exp + + await handler.execute_function() + + if config.exp_pool.enable_write: + await handler.process_experience() + + return handler._raw_resp + + return ExpCacheHandler.choose_wrapper(func, get_or_create) + + return decorator(_func) if _func else decorator + + +class ExpCacheHandler(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + func: Callable + args: Any + kwargs: Any + query_type: QueryType = QueryType.SEMANTIC + exp_manager: Optional[ExperienceManager] = None + exp_scorer: Optional[BaseScorer] = None + exp_perfect_judge: Optional[BasePerfectJudge] = None + context_builder: Optional[BaseContextBuilder] = None + serializer: Optional[BaseSerializer] = None + tag: Optional[str] = None + + _exps: list[Experience] = None + _req: str = "" + _resp: str = "" + _raw_resp: Any = None + _score: Score = None + + @model_validator(mode="after") + def initialize(self): + """Initialize default values for optional parameters if they are None. + + This is necessary because the decorator might pass None, which would override the default values set by Field. 
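For orientation, a minimal hypothetical use of the `exp_cache` decorator (function name and request are invented; assumes the package is installed and `config.exp_pool.enabled` is on):

```python
import asyncio

from metagpt.exp_pool import exp_cache

@exp_cache  # also usable with arguments, e.g. @exp_cache(tag="...", query_type=...)
async def answer(req: str) -> str:
    # Stand-in for an expensive LLM call.
    return f"echo: {req}"

# `req` must be passed by keyword; with the pool enabled, a positional
# call fails _validate_params and raises ValueError.
print(asyncio.run(answer(req="What is MetaGPT?")))
```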
+ """ + + self._validate_params() + + self.exp_manager = self.exp_manager or get_exp_manager() + self.exp_scorer = self.exp_scorer or SimpleScorer() + self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge() + self.context_builder = self.context_builder or SimpleContextBuilder() + self.serializer = self.serializer or SimpleSerializer() + self.tag = self.tag or self._generate_tag() + + self._req = self.serializer.serialize_req(**self.kwargs) + + return self + + async def fetch_experiences(self): + """Fetch experiences by query_type.""" + + self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag) + logger.info(f"Found {len(self._exps)} experiences for tag '{self.tag}'") + + async def get_one_perfect_exp(self) -> Optional[Any]: + """Get a potentially perfect experience, and resolve resp.""" + + for exp in self._exps: + if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs): + logger.info(f"Got one perfect experience for req '{exp.req[:20]}...'") + return self.serializer.deserialize_resp(exp.resp) + + return None + + async def execute_function(self): + """Execute the function, and save resp.""" + + self._raw_resp = await self._execute_function() + self._resp = self.serializer.serialize_resp(self._raw_resp) + + @handle_exception + async def process_experience(self): + """Process experience. + + Evaluates and saves experience. + Use `handle_exception` to ensure robustness, do not stop subsequent operations. + """ + + await self.evaluate_experience() + self.save_experience() + + async def evaluate_experience(self): + """Evaluate the experience, and save the score.""" + + self._score = await self.exp_scorer.evaluate(self._req, self._resp) + + def save_experience(self): + """Save the new experience.""" + + exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score)) + self.exp_manager.create_exp(exp) + self._log_exp(exp) + + @staticmethod + def choose_wrapper(func, wrapped_func): + """Choose how to run wrapped_func based on whether the function is asynchronous.""" + + async def async_wrapper(*args, **kwargs): + return await wrapped_func(args, kwargs) + + def sync_wrapper(*args, **kwargs): + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(wrapped_func(args, kwargs)) + + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + + def _validate_params(self): + if "req" not in self.kwargs: + raise ValueError("`req` must be provided as a keyword argument.") + + def _generate_tag(self) -> str: + """Generates a tag for the self.func. + + "ClassName.method_name" if the first argument is a class instance, otherwise just "function_name". 
+ """ + + if self.args and hasattr(self.args[0], "__class__"): + cls_name = type(self.args[0]).__name__ + return f"{cls_name}.{self.func.__name__}" + + return self.func.__name__ + + async def _build_context(self) -> str: + self.context_builder.exps = self._exps + + return await self.context_builder.build(self.kwargs["req"]) + + async def _execute_function(self): + self.kwargs["req"] = await self._build_context() + + if asyncio.iscoroutinefunction(self.func): + return await self.func(*self.args, **self.kwargs) + + return self.func(*self.args, **self.kwargs) + + def _log_exp(self, exp: Experience): + log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"}) + + logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}") diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py new file mode 100644 index 0000000000000000000000000000000000000000..35de17079bc09d7d457a91072544c03a4b3ff514 --- /dev/null +++ b/metagpt/exp_pool/manager.py @@ -0,0 +1,242 @@ +"""Experience Manager.""" + +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from pydantic import BaseModel, ConfigDict, Field + +from metagpt.config2 import Config +from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType +from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, Experience, QueryType +from metagpt.logs import logger +from metagpt.utils.exceptions import handle_exception + +if TYPE_CHECKING: + from metagpt.rag.engines import SimpleEngine + + +class ExperienceManager(BaseModel): + """ExperienceManager manages the lifecycle of experiences, including CRUD and optimization. + + Args: + config (Config): Configuration for managing experiences. + _storage (SimpleEngine): Engine to handle the storage and retrieval of experiences. + _vector_store (ChromaVectorStore): The actual place where vectors are stored. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + config: Config = Field(default_factory=Config.default) + + _storage: Any = None + + @property + def storage(self) -> "SimpleEngine": + if self._storage is None: + logger.info(f"exp_pool config: {self.config.exp_pool}") + + self._storage = self._resolve_storage() + + return self._storage + + @storage.setter + def storage(self, value): + self._storage = value + + @property + def is_readable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_read + + @is_readable.setter + def is_readable(self, value: bool): + self.config.exp_pool.enable_read = value + + # If set to True, ensure that enabled is also True. + if value: + self.config.exp_pool.enabled = True + + @property + def is_writable(self) -> bool: + return self.config.exp_pool.enabled and self.config.exp_pool.enable_write + + @is_writable.setter + def is_writable(self, value: bool): + self.config.exp_pool.enable_write = value + + # If set to True, ensure that enabled is also True. + if value: + self.config.exp_pool.enabled = True + + @handle_exception + def create_exp(self, exp: Experience): + """Adds an experience to the storage if writing is enabled. + + Args: + exp (Experience): The experience to add. + """ + + self.create_exps([exp]) + + @handle_exception + def create_exps(self, exps: list[Experience]): + """Adds multiple experiences to the storage if writing is enabled. + + Args: + exps (list[Experience]): A list of experiences to add. 
+ """ + if not self.is_writable: + return + + self.storage.add_objs(exps) + self.storage.persist(self.config.exp_pool.persist_path) + + @handle_exception(default_return=[]) + async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]: + """Retrieves and filters experiences. + + Args: + req (str): The query string to retrieve experiences. + tag (str): Optional tag to filter the experiences by. + query_type (QueryType): Default semantic to vector matching. exact to same matching. + + Returns: + list[Experience]: A list of experiences that match the args. + """ + + if not self.is_readable: + return [] + + nodes = await self.storage.aretrieve(req) + exps: list[Experience] = [node.metadata["obj"] for node in nodes] + + # TODO: filter by metadata + if tag: + exps = [exp for exp in exps if exp.tag == tag] + + if query_type == QueryType.EXACT: + exps = [exp for exp in exps if exp.req == req] + + return exps + + @handle_exception + def delete_all_exps(self): + """Delete the all experiences.""" + + if not self.is_writable: + return + + self.storage.clear(persist_dir=self.config.exp_pool.persist_path) + + def get_exps_count(self) -> int: + """Get the total number of experiences.""" + + return self.storage.count() + + def _resolve_storage(self) -> "SimpleEngine": + """Selects the appropriate storage creation method based on the configured retrieval type.""" + + storage_creators = { + ExperiencePoolRetrievalType.BM25: self._create_bm25_storage, + ExperiencePoolRetrievalType.CHROMA: self._create_chroma_storage, + } + + return storage_creators[self.config.exp_pool.retrieval_type]() + + def _create_bm25_storage(self) -> "SimpleEngine": + """Creates or loads BM25 storage. + + This function attempts to create a new BM25 storage if the specified + document store path does not exist. If the path exists, it loads the + existing BM25 storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with BM25 storage. + + Raises: + ImportError: If required modules are not installed. + """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import BM25IndexConfig, BM25RetrieverConfig + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + + persist_path = Path(self.config.exp_pool.persist_path) + docstore_path = persist_path / "docstore.json" + + ranker_configs = self._get_ranker_configs() + + if not docstore_path.exists(): + logger.debug(f"Path `{docstore_path}` not exists, try to create a new bm25 storage.") + exps = [Experience(req="req", resp="resp")] + + retriever_configs = [BM25RetrieverConfig(create_index=True, similarity_top_k=DEFAULT_SIMILARITY_TOP_K)] + + storage = SimpleEngine.from_objs( + objs=exps, retriever_configs=retriever_configs, ranker_configs=ranker_configs + ) + return storage + + logger.debug(f"Path `{docstore_path}` exists, try to load bm25 storage.") + retriever_configs = [BM25RetrieverConfig(similarity_top_k=DEFAULT_SIMILARITY_TOP_K)] + storage = SimpleEngine.from_index( + BM25IndexConfig(persist_path=persist_path), + retriever_configs=retriever_configs, + ranker_configs=ranker_configs, + ) + + return storage + + def _create_chroma_storage(self) -> "SimpleEngine": + """Creates Chroma storage. + + Returns: + SimpleEngine: An instance of SimpleEngine configured with Chroma storage. + + Raises: + ImportError: If required modules are not installed. 
+ """ + + try: + from metagpt.rag.engines import SimpleEngine + from metagpt.rag.schema import ChromaRetrieverConfig + except ImportError: + raise ImportError("To use the experience pool, you need to install the rag module.") + + retriever_configs = [ + ChromaRetrieverConfig( + persist_path=self.config.exp_pool.persist_path, + collection_name=self.config.exp_pool.collection_name, + similarity_top_k=DEFAULT_SIMILARITY_TOP_K, + ) + ] + ranker_configs = self._get_ranker_configs() + + storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + return storage + + def _get_ranker_configs(self): + """Returns ranker configurations based on the configuration. + + If `use_llm_ranker` is True, returns a list with one `LLMRankerConfig` + instance. Otherwise, returns an empty list. + + Returns: + list: A list of `LLMRankerConfig` instances or an empty list. + """ + + from metagpt.rag.schema import LLMRankerConfig + + return [LLMRankerConfig(top_n=DEFAULT_SIMILARITY_TOP_K)] if self.config.exp_pool.use_llm_ranker else [] + + +_exp_manager = None + + +def get_exp_manager() -> ExperienceManager: + global _exp_manager + if _exp_manager is None: + _exp_manager = ExperienceManager() + return _exp_manager diff --git a/metagpt/exp_pool/perfect_judges/__init__.py b/metagpt/exp_pool/perfect_judges/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d8796c7c8591762796c0d45ee567a0d046316f9d --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/__init__.py @@ -0,0 +1,6 @@ +"""Perfect judges init.""" + +from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge +from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge + +__all__ = ["BasePerfectJudge", "SimplePerfectJudge"] diff --git a/metagpt/exp_pool/perfect_judges/__pycache__/__init__.cpython-310.pyc b/metagpt/exp_pool/perfect_judges/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..818ace858723912ee105e67d96fec7412e3cbd2e Binary files /dev/null and b/metagpt/exp_pool/perfect_judges/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/exp_pool/perfect_judges/__pycache__/__init__.cpython-39.pyc b/metagpt/exp_pool/perfect_judges/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55412172e01c0c462b67ccb87d48a7dd39f14450 Binary files /dev/null and b/metagpt/exp_pool/perfect_judges/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/exp_pool/perfect_judges/__pycache__/base.cpython-310.pyc b/metagpt/exp_pool/perfect_judges/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99385b7457122cf14bdad857181e922838f54ee5 Binary files /dev/null and b/metagpt/exp_pool/perfect_judges/__pycache__/base.cpython-310.pyc differ diff --git a/metagpt/exp_pool/perfect_judges/__pycache__/base.cpython-39.pyc b/metagpt/exp_pool/perfect_judges/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0a579bb71b6d3cb6b68f4449d934c4b9cc7d482 Binary files /dev/null and b/metagpt/exp_pool/perfect_judges/__pycache__/base.cpython-39.pyc differ diff --git a/metagpt/exp_pool/perfect_judges/__pycache__/simple.cpython-310.pyc b/metagpt/exp_pool/perfect_judges/__pycache__/simple.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffc9f03f33507412a419cdadff064d0fb3e96060 Binary files /dev/null and 
b/metagpt/exp_pool/perfect_judges/__pycache__/simple.cpython-310.pyc differ diff --git a/metagpt/exp_pool/perfect_judges/__pycache__/simple.cpython-39.pyc b/metagpt/exp_pool/perfect_judges/__pycache__/simple.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8454686bc032e7c13a20cb577d2169824f4c17f Binary files /dev/null and b/metagpt/exp_pool/perfect_judges/__pycache__/simple.cpython-39.pyc differ diff --git a/metagpt/exp_pool/perfect_judges/base.py b/metagpt/exp_pool/perfect_judges/base.py new file mode 100644 index 0000000000000000000000000000000000000000..2935229931f74f9c21b224b1ff2f4bb538e6f802 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/base.py @@ -0,0 +1,20 @@ +"""Base perfect judge.""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Experience + + +class BasePerfectJudge(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool: + """Determine whether the experience is perfect. + + Args: + exp (Experience): The experience to evaluate. + serialized_req (str): The serialized request to compare against the experience's request. + """ diff --git a/metagpt/exp_pool/perfect_judges/simple.py b/metagpt/exp_pool/perfect_judges/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..37ede95c395c73c6c7aeb16858166cab19540c4b --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/simple.py @@ -0,0 +1,27 @@ +"""Simple perfect judge.""" + + +from pydantic import ConfigDict + +from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge +from metagpt.exp_pool.schema import MAX_SCORE, Experience + + +class SimplePerfectJudge(BasePerfectJudge): + model_config = ConfigDict(arbitrary_types_allowed=True) + + async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool: + """Determine whether the experience is perfect. + + Args: + exp (Experience): The experience to evaluate. + serialized_req (str): The serialized request to compare against the experience's request. + + Returns: + bool: True if the serialized request matches the experience's request and the experience's score is perfect, False otherwise. 
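Concretely, the judge accepts only an exact request match carrying the maximal score; a self-contained sketch with invented values:

```python
import asyncio

from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score

async def main():
    judge = SimplePerfectJudge()
    exp = Experience(req="q", resp="a", metric=Metric(score=Score(val=MAX_SCORE)))
    print(await judge.is_perfect_exp(exp, "q"))        # True: same req, score == MAX_SCORE
    print(await judge.is_perfect_exp(exp, "other q"))  # False: request differs

asyncio.run(main())
```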
+ """ + + if not exp.metric or not exp.metric.score: + return False + + return serialized_req == exp.req and exp.metric.score.val == MAX_SCORE diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..fea48a7f7d189afde8f9f430caa262a9171131f5 --- /dev/null +++ b/metagpt/exp_pool/schema.py @@ -0,0 +1,76 @@ +"""Experience schema.""" +import time +from enum import Enum +from typing import Optional +from uuid import UUID, uuid4 + +from pydantic import BaseModel, Field + +MAX_SCORE = 10 + +DEFAULT_SIMILARITY_TOP_K = 2 + +LOG_NEW_EXPERIENCE_PREFIX = "New experience: " + + +class QueryType(str, Enum): + """Type of query experiences.""" + + EXACT = "exact" + SEMANTIC = "semantic" + + +class ExperienceType(str, Enum): + """Experience Type.""" + + SUCCESS = "success" + FAILURE = "failure" + INSIGHT = "insight" + + +class EntryType(Enum): + """Experience Entry Type.""" + + AUTOMATIC = "Automatic" + MANUAL = "Manual" + + +class Score(BaseModel): + """Score in Metric.""" + + val: int = Field(default=1, description="Value of the score, Between 1 and 10, higher is better.") + reason: str = Field(default="", description="Reason for the value.") + + +class Metric(BaseModel): + """Experience Metric.""" + + time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.") + money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.") + score: Score = Field(default=None, description="Score, with value and reason.") + + +class Trajectory(BaseModel): + """Experience Trajectory.""" + + plan: str = Field(default="", description="The plan.") + action: str = Field(default="", description="Action for the plan.") + observation: str = Field(default="", description="Output of the action.") + reward: int = Field(default=0, description="Measure the action.") + + +class Experience(BaseModel): + """Experience.""" + + req: str = Field(..., description="") + resp: str = Field(..., description="The type is string/json/code.") + metric: Optional[Metric] = Field(default=None, description="Metric.") + exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.") + entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.") + tag: str = Field(default="", description="Tagging experience.") + traj: Optional[Trajectory] = Field(default=None, description="Trajectory.") + timestamp: Optional[float] = Field(default_factory=time.time) + uuid: Optional[UUID] = Field(default_factory=uuid4) + + def rag_key(self): + return self.req diff --git a/metagpt/exp_pool/scorers/__init__.py b/metagpt/exp_pool/scorers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..caa845b143f0ec91ec73c1cdfd23f07a331dd3aa --- /dev/null +++ b/metagpt/exp_pool/scorers/__init__.py @@ -0,0 +1,6 @@ +"""Scorers init.""" + +from metagpt.exp_pool.scorers.base import BaseScorer +from metagpt.exp_pool.scorers.simple import SimpleScorer + +__all__ = ["BaseScorer", "SimpleScorer"] diff --git a/metagpt/exp_pool/scorers/__pycache__/__init__.cpython-310.pyc b/metagpt/exp_pool/scorers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89e415aafddea91db01d31b1f12de5a6e19e11ff Binary files /dev/null and b/metagpt/exp_pool/scorers/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/exp_pool/scorers/__pycache__/__init__.cpython-39.pyc 
b/metagpt/exp_pool/scorers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17373034559fcffd986a9e2a82df410188019f97 Binary files /dev/null and b/metagpt/exp_pool/scorers/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/exp_pool/scorers/__pycache__/base.cpython-310.pyc b/metagpt/exp_pool/scorers/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf8f6b8a2e9a52fe4cde8ba9464bd48b7326e011 Binary files /dev/null and b/metagpt/exp_pool/scorers/__pycache__/base.cpython-310.pyc differ diff --git a/metagpt/exp_pool/scorers/__pycache__/base.cpython-39.pyc b/metagpt/exp_pool/scorers/__pycache__/base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..367b1164d291dacfbe32adb91efc6524ca6e2f9e Binary files /dev/null and b/metagpt/exp_pool/scorers/__pycache__/base.cpython-39.pyc differ diff --git a/metagpt/exp_pool/scorers/__pycache__/simple.cpython-310.pyc b/metagpt/exp_pool/scorers/__pycache__/simple.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..823e3567b59ca6843f5d2a52a5e95ea65bbcd36f Binary files /dev/null and b/metagpt/exp_pool/scorers/__pycache__/simple.cpython-310.pyc differ diff --git a/metagpt/exp_pool/scorers/__pycache__/simple.cpython-39.pyc b/metagpt/exp_pool/scorers/__pycache__/simple.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc92bf751fc621eb5f3cec270b211cf4e7230b8f Binary files /dev/null and b/metagpt/exp_pool/scorers/__pycache__/simple.cpython-39.pyc differ diff --git a/metagpt/exp_pool/scorers/base.py b/metagpt/exp_pool/scorers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..97cac49925fb81b7d417e10e8056187678cdc598 --- /dev/null +++ b/metagpt/exp_pool/scorers/base.py @@ -0,0 +1,15 @@ +"""Base scorer.""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Score + + +class BaseScorer(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def evaluate(self, req: str, resp: str) -> Score: + """Evaluates the quality of a response relative to a given request.""" diff --git a/metagpt/exp_pool/scorers/simple.py b/metagpt/exp_pool/scorers/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..4b060aac4f0e60c4d06272404c5afe633812c8de --- /dev/null +++ b/metagpt/exp_pool/scorers/simple.py @@ -0,0 +1,65 @@ +"""Simple scorer.""" + +import json + +from pydantic import Field + +from metagpt.exp_pool.schema import Score +from metagpt.exp_pool.scorers.base import BaseScorer +from metagpt.llm import LLM +from metagpt.provider.base_llm import BaseLLM +from metagpt.utils.common import CodeParser + +SIMPLE_SCORER_TEMPLATE = """ +Role: You are a highly efficient assistant, tasked with evaluating a response to a given request. The response is generated by a large language model (LLM). + +I will provide you with a request and a corresponding response. Your task is to assess this response and provide a score from a human perspective. + +## Context +### Request +{req} + +### Response +{resp} + +## Format Example +```json +{{ + "val": "the value of the score, int from 1 to 10, higher is better.", + "reason": "an explanation supporting the score." +}} +``` + +## Instructions +- Understand the request and response given by the user. +- Evaluate the response based on its quality relative to the given request. 
+- Provide a score from 1 to 10, where 10 is the best. +- Provide a reason supporting your score. + +## Constraint +Format: Just print the result in json format like **Format Example**. + +## Action +Follow instructions, generate output and make sure it follows the **Constraint**. +""" + + +class SimpleScorer(BaseScorer): + llm: BaseLLM = Field(default_factory=LLM) + + async def evaluate(self, req: str, resp: str) -> Score: + """Evaluates the quality of a response relative to a given request, as scored by an LLM. + + Args: + req (str): The request. + resp (str): The response. + + Returns: + Score: An object containing the score (1-10) and the reasoning. + """ + + prompt = SIMPLE_SCORER_TEMPLATE.format(req=req, resp=resp) + resp = await self.llm.aask(prompt) + resp_json = json.loads(CodeParser.parse_code(resp, lang="json")) + + return Score(**resp_json) diff --git a/metagpt/exp_pool/serializers/__init__.py b/metagpt/exp_pool/serializers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1045588ef021c4de94a62cc51a342f108a6266 --- /dev/null +++ b/metagpt/exp_pool/serializers/__init__.py @@ -0,0 +1,9 @@ +"""Serializers init.""" + +from metagpt.exp_pool.serializers.base import BaseSerializer +from metagpt.exp_pool.serializers.simple import SimpleSerializer +from metagpt.exp_pool.serializers.action_node import ActionNodeSerializer +from metagpt.exp_pool.serializers.role_zero import RoleZeroSerializer + + +__all__ = ["BaseSerializer", "SimpleSerializer", "ActionNodeSerializer", "RoleZeroSerializer"] diff --git a/metagpt/exp_pool/serializers/__pycache__/__init__.cpython-310.pyc b/metagpt/exp_pool/serializers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c961da4e8c04102daa7263ca94359acd643551 Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/__init__.cpython-39.pyc b/metagpt/exp_pool/serializers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ebebc832cbd87e1c2f3db4351796e9505fa9994 Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/action_node.cpython-310.pyc b/metagpt/exp_pool/serializers/__pycache__/action_node.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dd7e4e50e45ebddf659fdb4be8805ecae78b55b Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/action_node.cpython-310.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/action_node.cpython-39.pyc b/metagpt/exp_pool/serializers/__pycache__/action_node.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3139201e89769ca85507b384a39cd0f12523e8c0 Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/action_node.cpython-39.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/base.cpython-310.pyc b/metagpt/exp_pool/serializers/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bc67db3a23c4851080cfa7156d7cf61329e6853 Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/base.cpython-310.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/base.cpython-39.pyc b/metagpt/exp_pool/serializers/__pycache__/base.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..970fc6e369585f94583e6b4fa55320f0c70f225b Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/base.cpython-39.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/role_zero.cpython-310.pyc b/metagpt/exp_pool/serializers/__pycache__/role_zero.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34ba87d13d9f9f59f0d1decbc7a8c9935b88a77e Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/role_zero.cpython-310.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/role_zero.cpython-39.pyc b/metagpt/exp_pool/serializers/__pycache__/role_zero.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..536f18fccce627b6882c19cd958eb134c28cb58e Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/role_zero.cpython-39.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/simple.cpython-310.pyc b/metagpt/exp_pool/serializers/__pycache__/simple.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d83ec06b3d27e8d4dbad38de9ee0a40e7f51d6ad Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/simple.cpython-310.pyc differ diff --git a/metagpt/exp_pool/serializers/__pycache__/simple.cpython-39.pyc b/metagpt/exp_pool/serializers/__pycache__/simple.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bfc777dae95f498829cc61018a5c4f8928a18de Binary files /dev/null and b/metagpt/exp_pool/serializers/__pycache__/simple.cpython-39.pyc differ diff --git a/metagpt/exp_pool/serializers/action_node.py b/metagpt/exp_pool/serializers/action_node.py new file mode 100644 index 0000000000000000000000000000000000000000..7746d6be47da531d4fc6f0f760377941e0730e55 --- /dev/null +++ b/metagpt/exp_pool/serializers/action_node.py @@ -0,0 +1,36 @@ +"""ActionNode Serializer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Type + +# Import ActionNode only for type checking to avoid circular imports +if TYPE_CHECKING: + from metagpt.actions.action_node import ActionNode + +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class ActionNodeSerializer(SimpleSerializer): + def serialize_resp(self, resp: ActionNode) -> str: + return resp.instruct_content.model_dump_json() + + def deserialize_resp(self, resp: str) -> ActionNode: + """Customized deserialization, it will be triggered when a perfect experience is found. + + ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'. 
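On a perfect-experience hit, the stored JSON is therefore wrapped in a stub whose only job is `model_dump_json()`; a sketch of that path with an invented payload:

```python
from metagpt.exp_pool.serializers import ActionNodeSerializer

serializer = ActionNodeSerializer()
node = serializer.deserialize_resp('{"answer": "42"}')
# The stub InstructContent just echoes the stored JSON back to callers.
print(node.instruct_content.model_dump_json())  # {"answer": "42"}
```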
+ """ + + class InstructContent: + def __init__(self, json_data): + self.json_data = json_data + + def model_dump_json(self): + return self.json_data + + from metagpt.actions.action_node import ActionNode + + action_node = ActionNode(key="", expected_type=Type[str], instruction="", example="") + action_node.instruct_content = InstructContent(resp) + + return action_node diff --git a/metagpt/exp_pool/serializers/base.py b/metagpt/exp_pool/serializers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c09488e1211554ec9c7d8e09a512ef0f5b69bbfc --- /dev/null +++ b/metagpt/exp_pool/serializers/base.py @@ -0,0 +1,29 @@ +"""Base serializer.""" + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class BaseSerializer(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + def serialize_req(self, **kwargs) -> str: + """Serializes the request for storage. + + Do not modify kwargs. If modification is necessary, use copy.deepcopy to create a copy first. + Note that copy.deepcopy may raise errors, such as TypeError: cannot pickle '_thread.RLock' object. + """ + + @abstractmethod + def serialize_resp(self, resp: Any) -> str: + """Serializes the function's return value for storage. + + Do not modify resp. The rest is the same as `serialize_req`. + """ + + @abstractmethod + def deserialize_resp(self, resp: str) -> Any: + """Deserializes the stored response back to the function's return value""" diff --git a/metagpt/exp_pool/serializers/role_zero.py b/metagpt/exp_pool/serializers/role_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..89dd73f391c0969a334c5ebdc1f637604dc04ea2 --- /dev/null +++ b/metagpt/exp_pool/serializers/role_zero.py @@ -0,0 +1,58 @@ +"""RoleZero Serializer.""" + +import copy +import json + +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class RoleZeroSerializer(SimpleSerializer): + def serialize_req(self, **kwargs) -> str: + """Serialize the request for database storage, ensuring it is a string. + + Only extracts the necessary content from `req` because `req` may be very lengthy and could cause embedding errors. + + Args: + req (list[dict]): The request to be serialized. Example: + [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."}, + {"role": "user", "content": "context"}, + ] + + Returns: + str: The serialized request as a JSON string. + """ + req = kwargs.get("req", []) + + if not req: + return "" + + filtered_req = self._filter_req(req) + + if state_data := kwargs.get("state_data"): + filtered_req.append({"role": "user", "content": state_data}) + + return json.dumps(filtered_req) + + def _filter_req(self, req: list[dict]) -> list[dict]: + """Filter the `req` to include only necessary items. + + Args: + req (list[dict]): The original request. + + Returns: + list[dict]: The filtered request. 
+ """ + + filtered_req = [copy.deepcopy(item) for item in req if self._is_useful_content(item["content"])] + + return filtered_req + + def _is_useful_content(self, content: str) -> bool: + """Currently, only the content of the file is considered, and more judgments can be added later.""" + + if "Command Editor.read executed: file_path" in content: + return True + + return False diff --git a/metagpt/exp_pool/serializers/simple.py b/metagpt/exp_pool/serializers/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd06e0e0cceb580f2432b919add0668e2960b78 --- /dev/null +++ b/metagpt/exp_pool/serializers/simple.py @@ -0,0 +1,22 @@ +"""Simple Serializer.""" + +from typing import Any + +from metagpt.exp_pool.serializers.base import BaseSerializer + + +class SimpleSerializer(BaseSerializer): + def serialize_req(self, **kwargs) -> str: + """Just use `str` to convert the request object into a string.""" + + return str(kwargs.get("req", "")) + + def serialize_resp(self, resp: Any) -> str: + """Just use `str` to convert the response object into a string.""" + + return str(resp) + + def deserialize_resp(self, resp: str) -> Any: + """Just return the string response as it is.""" + + return resp diff --git a/metagpt/ext/.DS_Store b/metagpt/ext/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..7284b0d449934b20405c0cfa618f44c06b340f82 Binary files /dev/null and b/metagpt/ext/.DS_Store differ diff --git a/metagpt/ext/__init__.py b/metagpt/ext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/__pycache__/__init__.cpython-310.pyc b/metagpt/ext/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb0fce30d89877a234020237f204cb77f58e9e0 Binary files /dev/null and b/metagpt/ext/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/ext/__pycache__/__init__.cpython-39.pyc b/metagpt/ext/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dc916cb9d12d80892ea3d8253e1b15abdd8dc92 Binary files /dev/null and b/metagpt/ext/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/ext/aflow/.DS_Store b/metagpt/ext/aflow/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8f4cb5693d8da66f2d28ad67b42e07f307fc618d Binary files /dev/null and b/metagpt/ext/aflow/.DS_Store differ diff --git a/metagpt/ext/aflow/benchmark/README.md b/metagpt/ext/aflow/benchmark/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4a2464fd1214cfe7b3072ff6b8004b3e1969e774 --- /dev/null +++ b/metagpt/ext/aflow/benchmark/README.md @@ -0,0 +1,29 @@ +# Custom Evaluation Function via Benchmark Class + +## How to Use + +To create a benchmark for a new dataset, follow these steps: + +1. Create a new Python file, e.g., `my_dataset_benchmark.py` +2. Import the base class: + ```python + from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark + ``` +3. Create a new class that inherits from `BaseBenchmark`: + ```python + class MyDatasetBenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + ``` +4. 
Implement the required abstract methods: + - `evaluate_problem`: Evaluate a single problem + - `calculate_score`: Calculate the score for a prediction + - `get_result_columns`: Define column names for the results CSV file + +5. Override other methods as needed, such as `load_data` or `save_results_to_csv` + +## Example + +Refer to the `DROPBenchmark` class in the `drop.py` file for an example of how to implement a benchmark for a specific dataset. + +By following these guidelines, you can easily create custom benchmark evaluations for new datasets. diff --git a/metagpt/ext/aflow/benchmark/benchmark.py b/metagpt/ext/aflow/benchmark/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..b5692f01e6a4f5a517ee58074f9fcb55c88c9ddb --- /dev/null +++ b/metagpt/ext/aflow/benchmark/benchmark.py @@ -0,0 +1,104 @@ +import asyncio +import json +import os +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, List, Tuple + +import aiofiles +import pandas as pd +from tqdm.asyncio import tqdm_asyncio + +from metagpt.logs import logger +from metagpt.utils.common import write_json_file + + +class BaseBenchmark(ABC): + def __init__(self, name: str, file_path: str, log_path: str): + self.name = name + self.file_path = file_path + self.log_path = log_path + + PASS = "PASS" + FAIL = "FAIL" + + async def load_data(self, specific_indices: List[int] = None) -> List[dict]: + data = [] + async with aiofiles.open(self.file_path, mode="r", encoding="utf-8") as file: + async for line in file: + data.append(json.loads(line)) + if specific_indices is not None: + filtered_data = [data[i] for i in specific_indices if i < len(data)] + return filtered_data + return data + + def save_results_to_csv(self, results: List[Tuple[Any, ...]], columns: List[str]): + df = pd.DataFrame(results, columns=columns) + avg_score = df["score"].mean() + t_cost = df["cost"].max() + a_cost = t_cost / len(df) if len(df) > 0 else 0 + current_time = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{avg_score:.5f}_{current_time}.csv" + output_file = os.path.join(self.log_path, filename) + df.to_csv(output_file, index=False) + logger.info(f"Results saved to {output_file}") + return avg_score, a_cost, t_cost + + def log_mismatch( + self, + problem: str, + expected_output: Any, + prediction: str, + extracted_output: Any, + extract_answer_code: str = "None", + ): + log_data = { + "question": problem, + "right_answer": expected_output, + "model_output": prediction, + "extracted_output": extracted_output, + "extract_answer_code": extract_answer_code, + } + log_file = Path(self.log_path) / "log.json" + if log_file.exists(): + with log_file.open("r", encoding="utf-8") as f: + try: + data = json.load(f) + except json.JSONDecodeError: + data = [] + else: + data = [] + data.append(log_data) + write_json_file(log_file, data, encoding="utf-8", indent=4) + + @abstractmethod + async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[Any, ...]: + pass + + @abstractmethod + def calculate_score(self, expected_output: Any, prediction: Any) -> Tuple[float, Any]: + pass + + @abstractmethod + def get_result_columns(self) -> List[str]: + pass + + async def evaluate_all_problems(self, data: List[dict], graph: Callable, max_concurrent_tasks: int = 50): + semaphore = asyncio.Semaphore(max_concurrent_tasks) + + async def sem_evaluate(problem): + async with semaphore: + return await self.evaluate_problem(problem, graph) + + tasks = [sem_evaluate(problem) 
for problem in data] + return await tqdm_asyncio.gather(*tasks, desc=f"Evaluating {self.name} problems", total=len(data)) + + async def run_evaluation(self, graph: Callable, va_list: List[int], max_concurrent_tasks: int = 50): + data = await self.load_data(va_list) + results = await self.evaluate_all_problems(data, graph, max_concurrent_tasks) + columns = self.get_result_columns() + average_score, average_cost, total_cost = self.save_results_to_csv(results, columns) + logger.info(f"Average score on {self.name} dataset: {average_score:.5f}") + logger.info(f"Total Cost: {total_cost:.5f}") + return average_score, average_cost, total_cost diff --git a/metagpt/ext/aflow/benchmark/drop.py b/metagpt/ext/aflow/benchmark/drop.py new file mode 100644 index 0000000000000000000000000000000000000000..3cec5795fb4a1c8f4a31020fe3930ac39433db78 --- /dev/null +++ b/metagpt/ext/aflow/benchmark/drop.py @@ -0,0 +1,83 @@ +import re +import string +from collections import Counter +from typing import Callable, List, Tuple + +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger + + +class DROPBenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + + def normalize_answer(self, s: str) -> str: + """ + Normalize answers for evaluation. + """ + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]: + """ + Compute the F1 score between prediction and ground truth answers. + """ + prediction_tokens = self.normalize_answer(prediction).split() + ground_truth_tokens = self.normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0, prediction + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1, prediction + + @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True) + async def _generate_output(self, graph, input_text): + return await graph(input_text) + + async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, float, float]: + input_text = problem["context"] + expected_output = problem["ref_text"] + answers = expected_output.split("|") + + try: + output, cost = await self._generate_output(graph, input_text) + f1_scores = [] + + for answer in answers: + if answer.strip() != "": + output_parts = output.split("|") + for output_part in output_parts: + f1_score, _ = self.calculate_score(answer, output_part) + f1_scores.append(f1_score) + + uni_score = max(f1_scores) + + if uni_score < 0.3: + self.log_mismatch(input_text, expected_output, output, output) + + return input_text, output, expected_output, uni_score, cost + + except Exception as e: + logger.info(f"Maximum retries reached. Skipping this sample. 
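The token-level F1 that `DROPBenchmark.calculate_score` computes can be checked by hand. A standalone example (not part of the patch), using already-normalized token lists:

```python
from collections import Counter

prediction_tokens = "cat sat down today".split()
ground_truth_tokens = "cat sat".split()

# Multiset intersection: each shared token counts at most min(pred_count, gt_count) times.
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())  # 2 ("cat", "sat")

precision = num_same / len(prediction_tokens)       # 2/4 = 0.5
recall = num_same / len(ground_truth_tokens)        # 2/2 = 1.0
f1 = 2 * precision * recall / (precision + recall)  # = 2/3
print(round(f1, 3))  # 0.667
```

Because `ref_text` may contain several `|`-separated reference answers and the model may emit several `|`-separated parts, `evaluate_problem` scores every answer/part pair and keeps the maximum F1.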
Error: {e}") + return input_text, str(e), expected_output, 0.0, 0.0 + + def get_result_columns(self) -> List[str]: + return ["inputs", "prediction", "expected_output", "score", "cost"] diff --git a/metagpt/ext/aflow/benchmark/gsm8k.py b/metagpt/ext/aflow/benchmark/gsm8k.py new file mode 100644 index 0000000000000000000000000000000000000000..51979c0c578626fa48629c9b515f849f5e92fdd7 --- /dev/null +++ b/metagpt/ext/aflow/benchmark/gsm8k.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Date : +# @Author : all +# @Desc : test on gsm8k +import re +from typing import Callable, List, Optional, Tuple + +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger + + +class GSM8KBenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + + def extract_number(self, text: str) -> Optional[float]: + matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text)) + if matches: + last_number = matches[-1].replace(",", "") + try: + return float(last_number) + except ValueError: + return None + else: + return None + + def calculate_score(self, expected_output: float, prediction: float) -> Tuple[float, float]: + if prediction is None: + return 0.0, prediction + return 1.0 if abs(expected_output - prediction) <= 1e-6 else 0.0, prediction + + @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True) + async def _generate_output(self, graph, input_text): + return await graph(input_text) + + async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, float, float, float]: + input_text = problem["question"] + expected_output = self.extract_number(problem["answer"]) + + try: + output, cost = await self._generate_output(graph, input_text) + predicted_number = self.extract_number(output) + score, extracted_output = self.calculate_score(expected_output, predicted_number) + + if score == 0: + self.log_mismatch(input_text, expected_output, output, extracted_output) + + return input_text, output, expected_output, score, cost + + except Exception as e: + logger.info(f"Maximum retries reached. Skipping this sample. 
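A quick illustration of what `GSM8KBenchmark.extract_number` pulls out of a model response (a standalone check, not part of the patch): it takes the last number in the text and strips thousands separators; `calculate_score` then requires the result to match the expected value within an absolute tolerance of `1e-6`:

```python
import re

pattern = r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+"
text = "She makes 12 * 100 = $1,200.50 per week. The answer is 1,200.50."

matches = re.findall(pattern, text)
last = matches[-1].replace(",", "")  # last match wins: "1,200.50" -> "1200.50"
print(float(last))  # 1200.5
```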
Error: {e}") + return input_text, str(e), expected_output, 0.0, 0.0 + + def get_result_columns(self) -> List[str]: + return ["question", "prediction", "expected_output", "score", "cost"] diff --git a/metagpt/ext/aflow/benchmark/hotpotqa.py b/metagpt/ext/aflow/benchmark/hotpotqa.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bafe22b89278836c07a68a584b088d9068f775 --- /dev/null +++ b/metagpt/ext/aflow/benchmark/hotpotqa.py @@ -0,0 +1,71 @@ +import re +import string +from collections import Counter +from typing import Callable, List, Tuple + +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger + + +class HotpotQABenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + + def normalize_answer(self, s: str) -> str: + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]: + prediction_tokens = self.normalize_answer(prediction).split() + ground_truth_tokens = self.normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0, prediction + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1, prediction + + @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True) + async def _generate_output(self, graph, input_text): + return await graph(input_text) + + async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, str, float, float]: + input_text = problem["question"] + expected_output = problem["answer"] + paragraphs = [item[1] for item in problem["context"] if isinstance(item[1], list)] + context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs) + inputs = f"Context: {context_str}\n\nQuestion: {input_text}\n\nAnswer:" + + try: + output, cost = await self._generate_output(graph, inputs) + score, extracted_output = self.calculate_score(expected_output, output) + + if ( + score < 0.3 + ): # We set the threshold for collecting incorrect questions to 0.3, as F1 Score cannot be simply judged using 0-1 + self.log_mismatch(input_text, expected_output, output, extracted_output) + + return input_text, context_str, output, expected_output, score, cost + + except Exception as e: + logger.info(f"Maximum retries reached. Skipping this sample. 
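For reference, HotpotQA stores `context` as `[title, sentence_list]` pairs; the list comprehension above keeps the sentence lists and flattens them into a single prompt string. A toy example in that format (standalone, not part of the patch):

```python
problem = {
    "question": "Which magazine was started first, Arthur's Magazine or First for Women?",
    "context": [
        ["Arthur's Magazine", ["Arthur's Magazine was an American literary periodical.", "It was published in Philadelphia."]],
        ["First for Women", ["First for Women is a woman's magazine published by Bauer Media Group."]],
    ],
}

paragraphs = [item[1] for item in problem["context"] if isinstance(item[1], list)]
context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)
inputs = f"Context: {context_str}\n\nQuestion: {problem['question']}\n\nAnswer:"
print(inputs)
```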
Error: {e}") + return input_text, context_str, str(e), expected_output, 0.0, 0.0 + + def get_result_columns(self) -> List[str]: + return ["question", "context", "prediction", "expected_output", "score", "cost"] diff --git a/metagpt/ext/aflow/benchmark/humaneval.py b/metagpt/ext/aflow/benchmark/humaneval.py new file mode 100644 index 0000000000000000000000000000000000000000..b54add260f1784371a4ffc8f2e06e6589124fb40 --- /dev/null +++ b/metagpt/ext/aflow/benchmark/humaneval.py @@ -0,0 +1,151 @@ +import asyncio +import threading +import time +from typing import Any, Callable, Dict, List, Optional, Tuple + +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger +from metagpt.utils.sanitize import sanitize + + +class HumanEvalBenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + + class TimeoutError(Exception): + pass + + def run_with_timeout(self, func, args, timeout): + result = [] + stop_event = threading.Event() + + def target(): + try: + result.append(func(*args)) + except Exception as e: + result.append(e) + finally: + stop_event.set() + + thread = threading.Thread(target=target) + thread.start() + is_timeout = not stop_event.wait(timeout) + + if is_timeout: + raise self.TimeoutError("Function execution timed out") + + if not result: + return None + if isinstance(result[0], Exception): + raise result[0] + return result[0] + + def check_solution(self, solution, test, entry_point): + solution = sanitize(code=solution, entrypoint=entry_point) + try: + global_dict = { + "math": __import__("math"), + "hashlib": __import__("hashlib"), + "re": __import__("re"), + "List": List, + "Dict": Dict, + "Tuple": Tuple, + "Optional": Optional, + "Any": Any, + } + + # Add handling for special cases + if entry_point == "decode_cyclic": + solution = ( + '\n\ndef encode_cyclic(s: str):\n """\n returns encoded string by cycling groups of three characters.\n """\n # split string to groups. Each of length 3.\n groups = [s[(3 * i):min((3 * i + 3), len(s))] for i in range((len(s) + 2) // 3)]\n # cycle elements in each group. Unless group has fewer elements than 3.\n groups = [(group[1:] + group[0]) if len(group) == 3 else group for group in groups]\n return "".join(groups)' + + "\n\n" + + solution + ) + elif entry_point == "decode_shift": + solution = ( + '\n\ndef encode_shift(s: str):\n """\n returns encoded string by shifting every character by 5 in the alphabet.\n """\n return "".join([chr(((ord(ch) + 5 - ord("a")) % 26) + ord("a")) for ch in s])\n\n\n' + + solution + ) + elif entry_point == "find_zero": + solution = ( + "\n\ndef poly(xs: list, x: float):\n return sum(coeff * (x ** i) for i, coeff in enumerate(xs))\n\n" + + solution + ) + + exec(solution, global_dict) + + if entry_point not in global_dict: + raise ValueError(f"Function {entry_point} is not defined in the solution.") + + exec(test, global_dict) + + check = global_dict["check"] + + result = self.run_with_timeout(check, (global_dict[entry_point],), 15) + + if result is None: + result = (self.PASS, "The solution passed all test cases.") + + except self.TimeoutError: + result = ( + self.FAIL, + "Execution timed out. 
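The `run_with_timeout` helper above executes the generated `check` function on a worker thread and treats a missed `stop_event` as a timeout. The same pattern in isolation (a sketch; `daemon=True` is added here only so the demo process can always exit):

```python
import threading


def run_with_timeout(func, args, timeout):
    result = []
    stop_event = threading.Event()

    def target():
        try:
            result.append(func(*args))
        except Exception as e:  # surface worker exceptions to the caller
            result.append(e)
        finally:
            stop_event.set()

    threading.Thread(target=target, daemon=True).start()
    if not stop_event.wait(timeout):  # Event.wait returns False on timeout
        raise TimeoutError("Function execution timed out")
    if result and isinstance(result[0], Exception):
        raise result[0]
    return result[0] if result else None


print(run_with_timeout(sum, ([1, 2, 3],), timeout=1))  # 6
```

Note the limitation this shares with the benchmark code: a timed-out thread is abandoned rather than killed, so a runaway test keeps consuming CPU in the background.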
Please check if your solution contains infinite loops or overly time-consuming operations.", + ) + except Exception as e: + error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}" + result = (self.FAIL, error_message) + + with open("error.log", "a", encoding="utf-8") as log_file: + log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n") + + return result + + @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True) + async def _generate_output(self, graph, prompt, entry_point): + # Generate output with a timeout of 60 seconds + return await asyncio.wait_for(graph(prompt, entry_point), timeout=60) + + async def evaluate_problem(self, data: dict, graph: Callable) -> Tuple[str, str, str, float, float]: + input_text = data["prompt"] + expected_output = ( + "\nCorrect Solution:\ndef " + + data["entry_point"] + + "(params you should put here):" + + "\n\n" + + data["canonical_solution"] + ) + + try: + # Generate prediction using the graph function + prediction, cost = await self._generate_output(graph, input_text, data["entry_point"]) + + # Check the solution + ret = self.check_solution(prediction, data["test"], data["entry_point"]) + test_case_details = ret[1] + expected_output = test_case_details + expected_output + + # Calculate score based on the check result + score = 1.0 if ret[0] == self.PASS else 0.0 + + # Log mismatch if the score is 0 + if score == 0: + self.log_mismatch(input_text, expected_output, prediction, score) + + return input_text, prediction, expected_output, score, cost + + except asyncio.TimeoutError: + logger.info("Timeout error. Skipping this sample.") + return input_text, "Timeout", expected_output, 0.0, 0.0 + + except Exception as e: + logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}") + return input_text, str(e), expected_output, 0.0, 0.0 + + def calculate_score(self, expected_output: str, prediction: str) -> Tuple[float, str]: + # The scoring logic for HumanEval is already implemented in evaluate_problem, this is just to conform to the interface + return 0.0, prediction + + def get_result_columns(self) -> List[str]: + return ["inputs", "prediction", "expected_output", "score", "cost"] diff --git a/metagpt/ext/aflow/benchmark/math.py b/metagpt/ext/aflow/benchmark/math.py new file mode 100644 index 0000000000000000000000000000000000000000..07b0612d06050e495d9fa385410fd7902994cbfb --- /dev/null +++ b/metagpt/ext/aflow/benchmark/math.py @@ -0,0 +1,137 @@ +import inspect +import re +from math import isclose +from typing import Any, Callable, List, Tuple + +import regex +from sympy import N, simplify +from sympy.parsing.latex import parse_latex +from sympy.parsing.sympy_parser import parse_expr +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger + + +class MATHBenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + + def extract_model_answer(self, text: str) -> str: + pattern = r"\\boxed{((?:[^{}]|{[^{}]*})*)}" + boxed_matches = re.findall(pattern, text, re.DOTALL) + if boxed_matches: + return boxed_matches[-1].strip() + + sentence_end_pattern = r"(? 
<!\d)[.!?]\s+" + sentences = re.split(sentence_end_pattern, text) + sentences = [s.strip() for s in sentences if s.strip()] + return sentences[-1] if sentences else "" + + def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]: + expected_answer = self.extract_model_answer(expected_output) + predicted_answer = self.extract_model_answer(prediction) + + if self.math_equal(predicted_answer, expected_answer): + return 1, predicted_answer + else: + return 0, predicted_answer + + def math_equal(self, prediction: Any, reference: Any) -> bool: + if str(prediction) == str(reference): + return True + + try: + if self.is_digit(prediction) and self.is_digit(reference): + prediction = self.parse_digits(prediction) + reference = self.parse_digits(reference) + return isclose(prediction, reference, abs_tol=1e-3) + except: + pass + + try: + return self.symbolic_equal(prediction, reference) + except: + pass + + return False + + def is_digit(self, num): + return self.parse_digits(num) is not None + + def parse_digits(self, num): + num = regex.sub(",", "", str(num)) + try: + return float(num) + except: + if num.endswith("%"): + num = num[:-1] + if num.endswith("\\"): + num = num[:-1] + try: + return float(num) / 100 + except: + pass + return None + + def symbolic_equal(self, a, b): + def _parse(s): + for f in [parse_latex, parse_expr]: + try: + return f(s) + except: + pass + return s + + a = _parse(a) + b = _parse(b) + + try: + if simplify(a - b) == 0: + return True + except: + pass + + try: + if isclose(N(a), N(b), abs_tol=1e-3): + return True + except: + pass + return False + + def get_function_code(self, func): + try: + source_code = inspect.getsource(func) + return source_code + except OSError: + return "no code" + + @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True) + async def _generate_output(self, graph, input_text): + return await graph(input_text) + + async def evaluate_problem(self, problem: dict, graph: Callable) -> Tuple[str, str, str, int, float]: + input_text = problem["problem"] + expected_output = problem["solution"] + + try: + output, cost = await self._generate_output(graph, input_text) + uni_score, extracted_output = self.calculate_score(expected_output, output) + + if uni_score == 0: + self.log_mismatch( + input_text, + expected_output, + output, + extracted_output, + extract_answer_code=self.get_function_code(self.extract_model_answer), + ) + + return input_text, output, expected_output, uni_score, cost + + except Exception as e: + logger.info(f"Maximum retries reached. Skipping this sample. 
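`math_equal` above accepts a prediction when any of three increasingly permissive comparisons succeeds: exact string equality, numeric closeness after `parse_digits` (which also handles `%` suffixes), and symbolic equivalence through sympy. A standalone illustration of the two non-trivial routes (note that `parse_latex` needs sympy's optional antlr4 dependency, just as the benchmark itself does):

```python
from math import isclose
from sympy import simplify
from sympy.parsing.latex import parse_latex

# Numeric route: parse_digits("50%") -> 0.5, so "50%" matches "0.5" within abs_tol=1e-3.
print(isclose(0.5, 0.5, abs_tol=1e-3))  # True

# Symbolic route: \frac{1}{2} parses to the rational 1/2, and simplify(1/2 - 0.5) == 0.
expr = parse_latex(r"\frac{1}{2}")
print(simplify(expr - 0.5) == 0)  # True
```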
Error: {e}") + return input_text, str(e), expected_output, 0.0, 0.0 + + def get_result_columns(self) -> List[str]: + return ["question", "prediction", "expected_output", "score", "cost"] diff --git a/metagpt/ext/aflow/benchmark/mbpp.py b/metagpt/ext/aflow/benchmark/mbpp.py new file mode 100644 index 0000000000000000000000000000000000000000..c3628b02401d5e4d6043a58abc40dd4f0908307d --- /dev/null +++ b/metagpt/ext/aflow/benchmark/mbpp.py @@ -0,0 +1,121 @@ +import threading +import time +from typing import Any, Callable, Dict, List, Optional, Tuple + +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.logs import logger +from metagpt.utils.sanitize import sanitize + + +class MBPPBenchmark(BaseBenchmark): + def __init__(self, name: str, file_path: str, log_path: str): + super().__init__(name, file_path, log_path) + + class TimeoutError(Exception): + pass + + def run_with_timeout(self, func, timeout): + result = [] + stop_event = threading.Event() + + def target(): + try: + result.append(func()) + except Exception as e: + result.append(e) + finally: + stop_event.set() + + thread = threading.Thread(target=target) + thread.start() + is_timeout = not stop_event.wait(timeout) + + if is_timeout: + raise self.TimeoutError("Function execution timed out") + + if not result: + return None + if isinstance(result[0], Exception): + raise result[0] + return result[0] + + def check_solution(self, solution, test, entry_point): + solution = sanitize(code=solution, entrypoint=entry_point) + try: + global_dict = { + "math": __import__("math"), + "hashlib": __import__("hashlib"), + "re": __import__("re"), + "List": List, + "Dict": Dict, + "Tuple": Tuple, + "Optional": Optional, + "Any": Any, + } + + exec(solution, global_dict) + + if entry_point not in global_dict: + raise ValueError(f"Function {entry_point} is not defined in the solution.") + + exec(test, global_dict) + + check = global_dict["check"] + + result = self.run_with_timeout(check, 15) + + if result is None: + result = (self.PASS, "The solution passed all test cases.") + + except self.TimeoutError: + result = ( + self.FAIL, + "Execution timed out. 
Please check if your solution contains infinite loops or overly time-consuming operations.", + ) + except Exception as e: + error_message = f"Error: {str(e)}.\n Solution: {solution}.\n Test: {test}" + result = (self.FAIL, error_message) + + with open("error.log", "a", encoding="utf-8") as log_file: + log_file.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {error_message}\n") + + return result + + @retry(stop=stop_after_attempt(5), wait=wait_fixed(1), retry=retry_if_exception_type(Exception), reraise=True) + async def _generate_output(self, graph, prompt, entry_point): + return await graph(prompt, entry_point) + + async def evaluate_problem(self, data: dict, graph: Callable) -> Tuple[str, str, str, float, float]: + input_text = data["prompt"] + expected_output = "\nCorrect Solution:\ndef " + data["code"] + + try: + # Generate prediction using the graph function + prediction, cost = await self._generate_output(graph, input_text, data["entry_point"]) + + # Check the solution + ret = self.check_solution(prediction, data["test"], data["entry_point"]) + test_case_details = ret[1] + expected_output = test_case_details + "\nCorrect Solution:" + data["code"] + + # Calculate score based on the check result + score = 1.0 if ret[0] == self.PASS else 0.0 + + # Log mismatch if the score is 0 + if score == 0: + self.log_mismatch(input_text, expected_output, prediction, score) + + return input_text, prediction, expected_output, score, cost + + except Exception as e: + logger.info(f"Maximum retries reached. Skipping this sample. Error: {e}") + return input_text, str(e), expected_output, 0.0, 0.0 + + def calculate_score(self, expected_output: str, prediction: str) -> Tuple[float, str]: + # The scoring logic for MBPP is already implemented in evaluate_problem, this is just to conform to the interface + return 0.0, prediction + + def get_result_columns(self) -> List[str]: + return ["inputs", "prediction", "expected_output", "score", "cost"] diff --git a/metagpt/ext/aflow/benchmark/utils.py b/metagpt/ext/aflow/benchmark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..846101bc0cafab8a18fc020df7403519eee9f410 --- /dev/null +++ b/metagpt/ext/aflow/benchmark/utils.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/7/24 16:37 +@Author : didi +@File : utils.py +""" + +import json +import os + +import numpy as np + +from metagpt.utils.common import read_json_file, write_json_file + + +def generate_random_indices(n, n_samples, test=False): + """ + Generate random indices + """ + + def _set_seed(seed=42): + np.random.seed(seed) + + _set_seed() + indices = np.arange(n) + np.random.shuffle(indices) + if test: + return indices[n_samples:] + else: + return indices[:n_samples] + + +def split_data_set(file_path, samples, test=False): + data = [] + + with open(file_path, "r") as file: + for line in file: + data.append(json.loads(line)) + random_indices = generate_random_indices(len(data), samples, test) + data = [data[i] for i in random_indices] + return data + + +def log_mismatch(problem, expected_output, prediction, predicted_number, path): + log_data = { + "question": problem, + "right_answer": expected_output, + "model_output": prediction, + "extracted_output": predicted_number, + } + + log_file = os.path.join(path, "log.json") + + # Check if the log file already exists + if os.path.exists(log_file): + # If it exists, load the existing log data + data = read_json_file(log_file) + else: + # If it does not exist, create a new log list + data = [] + + # Add 
the new log entry + data.append(log_data) + + # Write the data back to log.json file + write_json_file(log_file, data, encoding="utf-8", indent=4) diff --git a/metagpt/ext/aflow/data/download_data.py b/metagpt/ext/aflow/data/download_data.py new file mode 100644 index 0000000000000000000000000000000000000000..a3aa2774ca89ae4eec424a6a0436b7659e5b68a7 --- /dev/null +++ b/metagpt/ext/aflow/data/download_data.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# @Date : 2024-10-20 +# @Author : MoshiQAQ & didi +# @Desc : Download and extract dataset files + +import os +import tarfile +from typing import Dict + +import requests +from tqdm import tqdm + +from metagpt.logs import logger + + +def download_file(url: str, filename: str) -> None: + """Download a file from the given URL and show progress.""" + response = requests.get(url, stream=True) + total_size = int(response.headers.get("content-length", 0)) + block_size = 1024 + progress_bar = tqdm(total=total_size, unit="iB", unit_scale=True) + + with open(filename, "wb") as file: + for data in response.iter_content(block_size): + size = file.write(data) + progress_bar.update(size) + progress_bar.close() + + +def extract_tar_gz(filename: str, extract_path: str) -> None: + """Extract a tar.gz file to the specified path.""" + with tarfile.open(filename, "r:gz") as tar: + tar.extractall(path=extract_path) + + +def process_dataset(url: str, filename: str, extract_path: str) -> None: + """Download, extract, and clean up a dataset.""" + logger.info(f"Downloading {filename}...") + download_file(url, filename) + + logger.info(f"Extracting {filename}...") + extract_tar_gz(filename, extract_path) + + logger.info(f"{filename} download and extraction completed.") + + os.remove(filename) + logger.info(f"Removed {filename}") + + +# Define the datasets to be downloaded +# Users can modify this list to choose which datasets to download +datasets_to_download: Dict[str, Dict[str, str]] = { + "datasets": { + "url": "https://drive.google.com/uc?export=download&id=1DNoegtZiUhWtvkd2xoIuElmIi4ah7k8e", + "filename": "aflow_data.tar.gz", + "extract_path": "metagpt/ext/aflow/data", + }, + "results": { + "url": "https://drive.google.com/uc?export=download&id=1Sr5wjgKf3bN8OC7G6cO3ynzJqD4w6_Dv", + "filename": "result.tar.gz", + "extract_path": "metagpt/ext/aflow/data/results", + }, + "initial_rounds": { + "url": "https://drive.google.com/uc?export=download&id=1UBoW4WBWjX2gs4I_jq3ALdXeLdwDJMdP", + "filename": "initial_rounds.tar.gz", + "extract_path": "metagpt/ext/aflow/scripts/optimized", + }, +} + + +def download(required_datasets, if_first_download: bool = True): + """Main function to process all selected datasets""" + if if_first_download: + for dataset_name in required_datasets: + dataset = datasets_to_download[dataset_name] + extract_path = dataset["extract_path"] + process_dataset(dataset["url"], dataset["filename"], extract_path) + else: + logger.info("Skip downloading datasets") diff --git a/metagpt/ext/aflow/scripts/evaluator.py b/metagpt/ext/aflow/scripts/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..34bdcd9fc1c5866c8e1ea5702ca949097fb037ff --- /dev/null +++ b/metagpt/ext/aflow/scripts/evaluator.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# @Date : 8/23/2024 10:00 AM +# @Author : all +# @Desc : Evaluation for different datasets + +from typing import Dict, Literal, Tuple + +from metagpt.ext.aflow.benchmark.benchmark import BaseBenchmark +from metagpt.ext.aflow.benchmark.drop import DROPBenchmark +from metagpt.ext.aflow.benchmark.gsm8k 
import GSM8KBenchmark +from metagpt.ext.aflow.benchmark.hotpotqa import HotpotQABenchmark +from metagpt.ext.aflow.benchmark.humaneval import HumanEvalBenchmark +from metagpt.ext.aflow.benchmark.math import MATHBenchmark +from metagpt.ext.aflow.benchmark.mbpp import MBPPBenchmark + +# If you want to customize tasks, add task types here and provide evaluation functions, just like the ones given above +DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"] + + +class Evaluator: + """ + Complete the evaluation for different datasets here + """ + + def __init__(self, eval_path: str): + self.eval_path = eval_path + self.dataset_configs: Dict[DatasetType, BaseBenchmark] = { + "GSM8K": GSM8KBenchmark, + "MATH": MATHBenchmark, + "HumanEval": HumanEvalBenchmark, + "HotpotQA": HotpotQABenchmark, + "MBPP": MBPPBenchmark, + "DROP": DROPBenchmark, + } + + async def graph_evaluate( + self, dataset: DatasetType, graph, params: dict, path: str, is_test: bool = False + ) -> Tuple[float, float, float]: + if dataset not in self.dataset_configs: + raise ValueError(f"Unsupported dataset: {dataset}") + + data_path = self._get_data_path(dataset, is_test) + benchmark_class = self.dataset_configs[dataset] + benchmark = benchmark_class(name=dataset, file_path=data_path, log_path=path) + + # Use params to configure the graph and benchmark + configured_graph = await self._configure_graph(dataset, graph, params) + if is_test: + va_list = None # For test data, generally use None to test all + else: + va_list = None # Use None to test all Validation data, or set va_list (e.g., [1, 2, 3]) to use partial data + return await benchmark.run_evaluation(configured_graph, va_list) + + async def _configure_graph(self, dataset, graph, params: dict): + # Here you can configure the graph based on params + # For example: set LLM configuration, dataset configuration, etc. 
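Putting the pieces together, a typical call into `Evaluator` looks roughly like this. A hedged sketch: `EchoWorkflow`, the `logs/gsm8k_eval` directory, and the `None` LLM config are placeholders, and it assumes the GSM8K validation file has already been fetched by `download_data.py`:

```python
import asyncio
import os

from metagpt.ext.aflow.scripts.evaluator import Evaluator


class EchoWorkflow:
    """Placeholder workflow: answers every question with a constant string at zero cost."""

    def __init__(self, name: str, llm_config=None, dataset=None):
        self.name = name

    async def __call__(self, question: str):
        return "42", 0.0  # (answer, cost) is the contract the benchmarks expect


async def main():
    eval_dir = "logs/gsm8k_eval"
    os.makedirs(eval_dir, exist_ok=True)
    evaluator = Evaluator(eval_path=eval_dir)
    params = {"dataset": "GSM8K", "llm_config": None}
    score, avg_cost, total_cost = await evaluator.graph_evaluate(
        "GSM8K", EchoWorkflow, params, path=eval_dir, is_test=False
    )
    print(score, avg_cost, total_cost)


asyncio.run(main())
```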
+ dataset_config = params.get("dataset", {}) + llm_config = params.get("llm_config", {}) + return graph(name=dataset, llm_config=llm_config, dataset=dataset_config) + + def _get_data_path(self, dataset: DatasetType, test: bool) -> str: + base_path = f"metagpt/ext/aflow/data/{dataset.lower()}" + return f"{base_path}_test.jsonl" if test else f"{base_path}_validate.jsonl" diff --git a/metagpt/ext/aflow/scripts/interface.py b/metagpt/ext/aflow/scripts/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..46cdbdabfad6aa252508352f043f6f941ffa56c1 --- /dev/null +++ b/metagpt/ext/aflow/scripts/interface.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# @Date : 2024-03-21 +# @Author : Your Name +# @Desc : Interface for AFLOW + +import asyncio +import importlib.util +import sys +from pathlib import Path +from typing import Optional, Tuple + +from metagpt.configs.models_config import ModelsConfig +from metagpt.ext.aflow.scripts.evaluator import DatasetType +from metagpt.ext.aflow.scripts.optimizer_utils.data_utils import DataUtils +from metagpt.logs import logger + + +def load_best_round(dataset: str, optimized_path: str = "metagpt/ext/aflow/scripts/optimized") -> int: + """Load the best-performing round""" + data_utils = DataUtils(f"{optimized_path}/{dataset}") + + # Use get_top_rounds to fetch the highest-scoring rounds + top_rounds = data_utils.get_top_rounds(sample=2, mode="Graph") + if not top_rounds[1]: + return 1 + + return top_rounds[1]["round"] + + +def load_workflow_class(graph_path: str): + """Dynamically load the workflow class""" + spec = importlib.util.spec_from_file_location("workflow_module", graph_path) + module = importlib.util.module_from_spec(spec) + sys.modules["workflow_module"] = module + spec.loader.exec_module(module) + return module.Workflow + + +async def aflow_inference( + dataset: DatasetType, + question: str, + entry_point: Optional[str] = None, + round: Optional[int] = None, + llm_name: str = "gpt-4o-mini", + optimized_path: str = "metagpt/ext/aflow/scripts/optimized", +) -> Tuple[str, float]: + """AFLOW inference interface + + Args: + dataset: dataset name + question: input question + entry_point: entry-point function name, required for code datasets + round: the round to use; if None, the best round is used + llm_name: name of the LLM model to use + optimized_path: path where optimization results are saved + + Returns: + A tuple of (answer, cost) + """ + # If no round is specified, use the best round + if round is None: + round = load_best_round(dataset, optimized_path) + + logger.info(f"Using round {round} for inference") + + # Build the workflow path and load it + graph_path = Path(optimized_path) / dataset / "workflows" / f"round_{round}" / "graph.py" + if not graph_path.exists(): + raise FileNotFoundError(f"Workflow file not found: {graph_path}") + + # Dynamically load the workflow class + WorkflowClass = load_workflow_class(str(graph_path)) + + # Create the workflow instance + llm_config = ModelsConfig.default().get(llm_name) + workflow = WorkflowClass( + name=f"{dataset}_workflow", + llm_config=llm_config, + dataset=dataset, + ) + + # Run inference + if dataset in ["MBPP", "HumanEval"]: + # Code tasks require an extra entry_point argument + answer, cost = await workflow(question, entry_point=entry_point) + else: + answer, cost = await workflow(question) + + return answer, cost + + +if __name__ == "__main__": + asyncio.run( + aflow_inference( + dataset="MBPP", + question="write a function named add_two_numbers to calculate the sum of two numbers", + entry_point="add_two_numbers", + ) + ) diff --git a/metagpt/ext/aflow/scripts/operator.py b/metagpt/ext/aflow/scripts/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..903a962e02e0552067fcae2f74129b099530d152 --- /dev/null +++ b/metagpt/ext/aflow/scripts/operator.py @@ -0,0 +1,360 @@ +# -*- coding: utf-8 -*- +# @Date : 6/27/2024 17:36 PM +# @Author : didi +# 
@Desc : operator demo of aflow +import asyncio +import concurrent.futures +import random +import sys +import traceback +from collections import Counter +from typing import Dict, List, Tuple + +from tenacity import retry, stop_after_attempt, wait_fixed + +from metagpt.actions.action_node import ActionNode +from metagpt.ext.aflow.scripts.operator_an import ( + AnswerGenerateOp, + CodeGenerateOp, + FormatOp, + GenerateOp, + MdEnsembleOp, + ReflectionTestOp, + ReviewOp, + ReviseOp, + ScEnsembleOp, +) +from metagpt.ext.aflow.scripts.prompts.prompt import ( + ANSWER_GENERATION_PROMPT, + FORMAT_PROMPT, + MD_ENSEMBLE_PROMPT, + PYTHON_CODE_VERIFIER_PROMPT, + REFLECTION_ON_PUBLIC_TEST_PROMPT, + REVIEW_PROMPT, + REVISE_PROMPT, + SC_ENSEMBLE_PROMPT, +) +from metagpt.ext.aflow.scripts.utils import ( + extract_test_cases_from_jsonl, + test_case_2_test_function, +) +from metagpt.llm import LLM +from metagpt.logs import logger + + +class Operator: + def __init__(self, llm: LLM, name: str): + self.name = name + self.llm = llm + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + async def _fill_node(self, op_class, prompt, mode=None, **extra_kwargs): + fill_kwargs = {"context": prompt, "llm": self.llm} + if mode: + fill_kwargs["mode"] = mode + fill_kwargs.update(extra_kwargs) + node = await ActionNode.from_pydantic(op_class).fill(**fill_kwargs) + return node.instruct_content.model_dump() + + +class Custom(Operator): + def __init__(self, llm: LLM, name: str = "Custom"): + super().__init__(llm, name) + + async def __call__(self, input, instruction): + prompt = instruction + input + response = await self._fill_node(GenerateOp, prompt, mode="single_fill") + return response + + +class AnswerGenerate(Operator): + def __init__(self, llm: LLM, name: str = "AnswerGenerate"): + super().__init__(llm, name) + + async def __call__(self, input: str, mode: str = None) -> Tuple[str, str]: + prompt = ANSWER_GENERATION_PROMPT.format(input=input) + response = await self._fill_node(AnswerGenerateOp, prompt, mode="xml_fill") + return response + + +class CustomCodeGenerate(Operator): + def __init__(self, llm: LLM, name: str = "CustomCodeGenerate"): + super().__init__(llm, name) + + async def __call__(self, problem, entry_point, instruction): + prompt = instruction + problem + response = await self._fill_node(GenerateOp, prompt, mode="code_fill", function_name=entry_point) + return response + + +class ScEnsemble(Operator): + """ + Paper: Self-Consistency Improves Chain of Thought Reasoning in Language Models + Link: https://arxiv.org/abs/2203.11171 + Paper: Universal Self-Consistency for Large Language Model Generation + Link: https://arxiv.org/abs/2311.17311 + """ + + def __init__(self, llm: LLM, name: str = "ScEnsemble"): + super().__init__(llm, name) + + async def __call__(self, solutions: List[str], problem: str): + answer_mapping = {} + solution_text = "" + for index, solution in enumerate(solutions): + answer_mapping[chr(65 + index)] = index + solution_text += f"{chr(65 + index)}: \n{str(solution)}\n\n\n" + + prompt = SC_ENSEMBLE_PROMPT.format(question=problem, solutions=solution_text) + response = await self._fill_node(ScEnsembleOp, prompt, mode="xml_fill") + + answer = response.get("solution_letter", "") + answer = answer.strip().upper() + + return {"response": solutions[answer_mapping[answer]]} + + +def run_code(code): + try: + # Create a new global namespace + global_namespace = {} + + disallowed_imports = [ + "os", + "sys", + "subprocess", + "multiprocessing", + "matplotlib", + "seaborn", + 
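The `Operator`/`_fill_node` pattern above is also the extension point for new operators: define a pydantic output model, format a prompt, and let `ActionNode` fill the model. A minimal sketch, where the `Summarize` operator, `SummarizeOp`, and `SUMMARIZE_PROMPT` are all hypothetical additions:

```python
from pydantic import BaseModel, Field

from metagpt.ext.aflow.scripts.operator import Operator
from metagpt.llm import LLM

SUMMARIZE_PROMPT = "Summarize the following solution in two sentences:\n{solution}"  # hypothetical


class SummarizeOp(BaseModel):
    summary: str = Field(default="", description="A two-sentence summary of the solution")


class Summarize(Operator):
    def __init__(self, llm: LLM, name: str = "Summarize"):
        super().__init__(llm, name)

    async def __call__(self, solution: str):
        prompt = SUMMARIZE_PROMPT.format(solution=solution)
        # _fill_node wraps ActionNode.from_pydantic(...).fill(...) and returns a dict
        return await self._fill_node(SummarizeOp, prompt, mode="xml_fill")
```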
"plotly", + "bokeh", + "ggplot", + "pylab", + "tkinter", + "PyQt5", + "wx", + "pyglet", + ] + + # Check for prohibited imports + for lib in disallowed_imports: + if f"import {lib}" in code or f"from {lib}" in code: + logger.info("Detected prohibited import: %s", lib) + return "Error", f"Prohibited import: {lib} and graphing functionalities" + + # Use exec to execute the code + exec(code, global_namespace) + # Assume the code defines a function named 'solve' + if "solve" in global_namespace and callable(global_namespace["solve"]): + result = global_namespace["solve"]() + return "Success", str(result) + else: + return "Error", "Function 'solve' not found" + except Exception as e: + exc_type, exc_value, exc_traceback = sys.exc_info() + tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback) + return "Error", f"Execution error: {str(e)}\n{''.join(tb_str)}" + + +class Programmer(Operator): + def __init__(self, llm: LLM, name: str = "Programmer"): + super().__init__(llm, name) + + async def exec_code(self, code, timeout=30): + """ + Asynchronously execute code and return an error if timeout occurs. + """ + loop = asyncio.get_running_loop() + with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: + try: + # Submit run_code task to the process pool + future = loop.run_in_executor(executor, run_code, code) + # Wait for the task to complete or timeout + result = await asyncio.wait_for(future, timeout=timeout) + return result + except asyncio.TimeoutError: + # Timeout, attempt to shut down the process pool + executor.shutdown(wait=False, cancel_futures=True) + return "Error", "Code execution timed out" + except Exception as e: + return "Error", f"Unknown error: {str(e)}" + + async def code_generate(self, problem, analysis, feedback, mode): + """ + Asynchronous method to generate code. + """ + prompt = PYTHON_CODE_VERIFIER_PROMPT.format(problem=problem, analysis=analysis, feedback=feedback) + response = await self._fill_node(CodeGenerateOp, prompt, mode, function_name="solve") + return response + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(2)) + async def __call__(self, problem: str, analysis: str = "None"): + """ + Call method, generate code and execute, retry up to 3 times. 
+ """ + code = None + output = None + feedback = "" + for i in range(3): + code_response = await self.code_generate(problem, analysis, feedback, mode="code_fill") + code = code_response.get("code") + if not code: + return {"code": code, "output": "No code generated"} + status, output = await self.exec_code(code) + if status == "Success": + return {"code": code, "output": output} + else: + logger.info(f"Execution error on attempt {i + 1}, error message: {output}") + feedback = ( + f"\nThe result of the error from the code you wrote in the previous round:\n" + f"Code: {code}\n\nStatus: {status}, {output}" + ) + return {"code": code, "output": output} + + +class Test(Operator): + def __init__(self, llm: LLM, name: str = "Test"): + super().__init__(llm, name) + + def exec_code(self, solution, entry_point): + test_cases = extract_test_cases_from_jsonl(entry_point) + + fail_cases = [] + for test_case in test_cases: + test_code = test_case_2_test_function(solution, test_case, entry_point) + try: + exec(test_code, globals()) + except AssertionError as e: + exc_type, exc_value, exc_traceback = sys.exc_info() + tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback) + with open("tester.txt", "a") as f: + f.write("test_error of " + entry_point + "\n") + error_information = { + "test_fail_case": { + "test_case": test_case, + "error_type": "AssertionError", + "error_message": str(e), + "traceback": tb_str, + } + } + fail_cases.append(error_information) + except Exception as e: + with open("tester.txt", "a") as f: + f.write(entry_point + " " + str(e) + "\n") + return {"exec_fail_case": str(e)} + if fail_cases != []: + return fail_cases + else: + return "no error" + + async def __call__(self, problem, solution, entry_point, test_loop: int = 3): + """ + "Test": { + "description": "Test the solution with test cases, if the solution is correct, return 'no error', if the solution is incorrect, return a reflection on the solution and the error information", + "interface": "test(problem: str, solution: str, entry_point: str) -> str" + } + """ + for _ in range(test_loop): + result = self.exec_code(solution, entry_point) + if result == "no error": + return {"result": True, "solution": solution} + elif "exec_fail_case" in result: + result = result["exec_fail_case"] + prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format( + problem=problem, + solution=solution, + exec_pass=f"executed unsuccessfully, error: \n {result}", + test_fail="executed unsuccessfully", + ) + response = await self._fill_node(ReflectionTestOp, prompt, mode="code_fill") + solution = response["reflection_and_solution"] + else: + prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format( + problem=problem, + solution=solution, + exec_pass="executed successfully", + test_fail=result, + ) + response = await self._fill_node(ReflectionTestOp, prompt, mode="code_fill") + solution = response["reflection_and_solution"] + + result = self.exec_code(solution, entry_point) + if result == "no error": + return {"result": True, "solution": solution} + else: + return {"result": False, "solution": solution} + + +class Format(Operator): + def __init__(self, llm: LLM, name: str = "Format"): + super().__init__(llm, name) + + async def __call__(self, problem, solution, mode: str = None): + prompt = FORMAT_PROMPT.format(problem_description=problem, solution=solution) + response = await self._fill_node(FormatOp, prompt, mode) + return response + + +class Review(Operator): + def __init__(self, llm: LLM, name: str = "Review"): + super().__init__(llm, name) + + async def 
__call__(self, problem, solution, mode: str = None): + prompt = REVIEW_PROMPT.format(problem=problem, solution=solution) + response = await self._fill_node(ReviewOp, prompt, mode="xml_fill") + return response + + +class Revise(Operator): + def __init__(self, llm: LLM, name: str = "Revise"): + super().__init__(llm, name) + + async def __call__(self, problem, solution, feedback, mode: str = None): + prompt = REVISE_PROMPT.format(problem=problem, solution=solution, feedback=feedback) + response = await self._fill_node(ReviseOp, prompt, mode="xml_fill") + return response + + +class MdEnsemble(Operator): + """ + Paper: Can Generalist Foundation Models Outcompete Special-Purpose Tuning? Case Study in Medicine + Link: https://arxiv.org/abs/2311.16452 + """ + + def __init__(self, llm: LLM, name: str = "MdEnsemble", vote_count: int = 5): + super().__init__(llm, name) + self.vote_count = vote_count + + @staticmethod + def shuffle_answers(solutions: List[str]) -> Tuple[List[str], Dict[str, str]]: + shuffled_solutions = solutions.copy() + random.shuffle(shuffled_solutions) + answer_mapping = {chr(65 + i): solutions.index(solution) for i, solution in enumerate(shuffled_solutions)} + return shuffled_solutions, answer_mapping + + async def __call__(self, solutions: List[str], problem: str, mode: str = None): + logger.info(f"solution count: {len(solutions)}") + all_responses = [] + + for _ in range(self.vote_count): + shuffled_solutions, answer_mapping = self.shuffle_answers(solutions) + + solution_text = "" + for index, solution in enumerate(shuffled_solutions): + solution_text += f"{chr(65 + index)}: \n{str(solution)}\n\n\n" + + prompt = MD_ENSEMBLE_PROMPT.format(solutions=solution_text, question=problem) + response = await self._fill_node(MdEnsembleOp, prompt, mode="xml_fill") + + answer = response.get("solution_letter", "A") + answer = answer.strip().upper() + + if answer in answer_mapping: + original_index = answer_mapping[answer] + all_responses.append(original_index) + + most_frequent_index = Counter(all_responses).most_common(1)[0][0] + final_answer = solutions[most_frequent_index] + return {"solution": final_answer} diff --git a/metagpt/ext/aflow/scripts/operator_an.py b/metagpt/ext/aflow/scripts/operator_an.py new file mode 100644 index 0000000000000000000000000000000000000000..d0201dea2e7abc68ced453b8f9eafae43ee8b920 --- /dev/null +++ b/metagpt/ext/aflow/scripts/operator_an.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# @Date : 6/27/2024 19:46 PM +# @Author : didi +# @Desc : action nodes for operator + +from pydantic import BaseModel, Field + + +class GenerateOp(BaseModel): + response: str = Field(default="", description="Your solution for this problem") + + +class CodeGenerateOp(BaseModel): + code: str = Field(default="", description="Your complete code solution for this problem") + + +class AnswerGenerateOp(BaseModel): + thought: str = Field(default="", description="The step by step thinking process") + answer: str = Field(default="", description="The final answer to the question") + + +class FormatOp(BaseModel): + solution: str = Field(default="", description="Your formatted answer for this problem") + + +class ScEnsembleOp(BaseModel): + thought: str = Field(default="", description="The thought of the most consistent solution.") + solution_letter: str = Field(default="", description="The letter of most consistent solution.") + + +class ReflectionTestOp(BaseModel): + reflection_and_solution: str = Field( + default="", description="Corrective solution for code execution errors or test case 
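The subtle part of `MdEnsemble` is the bookkeeping in `shuffle_answers`: each voting round relabels the shuffled solutions as A, B, C, ..., and `answer_mapping` records which original index each letter refers to, so votes cast on shuffled letters can be tallied against the original list. A standalone check of that mapping and the majority vote:

```python
import random
from collections import Counter

solutions = ["sol-0", "sol-1", "sol-2"]

random.seed(7)  # deterministic for the example
shuffled = solutions.copy()
random.shuffle(shuffled)
# Caveat shared with the patch: duplicate solutions all map to their first occurrence via .index()
answer_mapping = {chr(65 + i): solutions.index(s) for i, s in enumerate(shuffled)}
print(answer_mapping)  # letter -> original index, e.g. {'A': 2, 'B': 0, 'C': 1}

votes = ["A", "A", "B"]  # letters returned by the LLM across voting rounds
indices = [answer_mapping[v] for v in votes]
print(solutions[Counter(indices).most_common(1)[0][0]])  # the majority-vote winner
```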
failures" + ) + + +class MdEnsembleOp(BaseModel): + thought: str = Field(default="", description="Step-by-step analysis of the solutions to determine the best one.") + solution_letter: str = Field(default="", description="The letter of the chosen best solution (only one letter).") + + +class ReviewOp(BaseModel): + review_result: bool = Field( + default=False, + description="The Review Result (Bool). If you think this solution looks good for you, return 'true'; If not, return 'false'", + ) + feedback: str = Field( + default="", + description="Your FeedBack for this problem based on the criteria. If the review result is true, you can put it 'nothing here'.", + ) + + +class ReviseOp(BaseModel): + solution: str = Field(default="", description="Based on the feedback, revised solution for this problem") diff --git a/metagpt/ext/aflow/scripts/optimized/__init__.py b/metagpt/ext/aflow/scripts/optimized/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/aflow/scripts/optimizer.py b/metagpt/ext/aflow/scripts/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ac4827e7195ef11d5203a33a42df165a81d50c6 --- /dev/null +++ b/metagpt/ext/aflow/scripts/optimizer.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- +# @Date : 8/12/2024 22:00 PM +# @Author : issac +# @Desc : optimizer for graph + +import asyncio +import time +from typing import List, Literal + +from pydantic import BaseModel, Field + +from metagpt.actions.action_node import ActionNode +from metagpt.ext.aflow.scripts.evaluator import DatasetType +from metagpt.ext.aflow.scripts.optimizer_utils.convergence_utils import ConvergenceUtils +from metagpt.ext.aflow.scripts.optimizer_utils.data_utils import DataUtils +from metagpt.ext.aflow.scripts.optimizer_utils.evaluation_utils import EvaluationUtils +from metagpt.ext.aflow.scripts.optimizer_utils.experience_utils import ExperienceUtils +from metagpt.ext.aflow.scripts.optimizer_utils.graph_utils import GraphUtils +from metagpt.logs import logger +from metagpt.provider.llm_provider_registry import create_llm_instance + +QuestionType = Literal["math", "code", "qa"] +OptimizerType = Literal["Graph", "Test"] + + +class GraphOptimize(BaseModel): + modification: str = Field(default="", description="modification") + graph: str = Field(default="", description="graph") + prompt: str = Field(default="", description="prompt") + + +class Optimizer: + def __init__( + self, + dataset: DatasetType, + question_type: QuestionType, + opt_llm_config, + exec_llm_config, + operators: List, + sample: int, + check_convergence: bool = False, + optimized_path: str = None, + initial_round: int = 1, + max_rounds: int = 20, + validation_rounds: int = 5, + ) -> None: + self.optimize_llm_config = opt_llm_config + self.optimize_llm = create_llm_instance(self.optimize_llm_config) + self.execute_llm_config = exec_llm_config + + self.dataset = dataset + self.type = question_type + self.check_convergence = check_convergence + + self.graph = None + self.operators = operators + + self.root_path = f"{optimized_path}/{self.dataset}" + self.sample = sample + self.top_scores = [] + self.round = initial_round + self.max_rounds = max_rounds + self.validation_rounds = validation_rounds + + self.graph_utils = GraphUtils(self.root_path) + self.data_utils = DataUtils(self.root_path) + self.experience_utils = ExperienceUtils(self.root_path) + self.evaluation_utils = EvaluationUtils(self.root_path) + self.convergence_utils = 
ConvergenceUtils(self.root_path) + + def optimize(self, mode: OptimizerType = "Graph"): + if mode == "Test": + test_n = 3 # validation dataset's execution number + for i in range(test_n): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + score = loop.run_until_complete(self.test()) + return None + + for opt_round in range(self.max_rounds): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + retry_count = 0 + max_retries = 1 + + while retry_count < max_retries: + try: + score = loop.run_until_complete(self._optimize_graph()) + break + except Exception as e: + retry_count += 1 + logger.info(f"Error occurred: {e}. Retrying... (Attempt {retry_count}/{max_retries})") + if retry_count == max_retries: + logger.info("Max retries reached. Moving to next round.") + score = None + + wait_time = 5 * retry_count + time.sleep(wait_time) + + if retry_count < max_retries: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + self.round += 1 + logger.info(f"Score for round {self.round}: {score}") + + converged, convergence_round, final_round = self.convergence_utils.check_convergence(top_k=3) + + if converged and self.check_convergence: + logger.info( + f"Convergence detected, occurred in round {convergence_round}, final round is {final_round}" + ) + # Print average scores and standard deviations for each round + self.convergence_utils.print_results() + break + + time.sleep(5) + + async def _optimize_graph(self): + validation_n = self.validation_rounds # validation dataset's execution number + graph_path = f"{self.root_path}/workflows" + data = self.data_utils.load_results(graph_path) + + if self.round == 1: + directory = self.graph_utils.create_round_directory(graph_path, self.round) + # Load graph using graph_utils + self.graph = self.graph_utils.load_graph(self.round, graph_path) + avg_score = await self.evaluation_utils.evaluate_graph(self, directory, validation_n, data, initial=True) + + # Create a loop until the generated graph meets the check conditions + while True: + directory = self.graph_utils.create_round_directory(graph_path, self.round + 1) + + top_rounds = self.data_utils.get_top_rounds(self.sample) + sample = self.data_utils.select_round(top_rounds) + + prompt, graph_load = self.graph_utils.read_graph_files(sample["round"], graph_path) + graph = self.graph_utils.extract_solve_graph(graph_load) + + processed_experience = self.experience_utils.load_experience() + experience = self.experience_utils.format_experience(processed_experience, sample["round"]) + + operator_description = self.graph_utils.load_operators_description(self.operators) + log_data = self.data_utils.load_log(sample["round"]) + + graph_optimize_prompt = self.graph_utils.create_graph_optimize_prompt( + experience, sample["score"], graph[0], prompt, operator_description, self.type, log_data + ) + + graph_optimize_node = await ActionNode.from_pydantic(GraphOptimize).fill( + context=graph_optimize_prompt, mode="xml_fill", llm=self.optimize_llm + ) + + response = await self.graph_utils.get_graph_optimize_response(graph_optimize_node) + + # Check if the modification meets the conditions + check = self.experience_utils.check_modification( + processed_experience, response["modification"], sample["round"] + ) + + # If `check` is True, break the loop; otherwise, regenerate the graph + if check: + break + + # Save the graph and evaluate + self.graph_utils.write_graph_files(directory, response, self.round + 1, self.dataset) + + experience = 
self.experience_utils.create_experience_data(sample, response["modification"]) + + self.graph = self.graph_utils.load_graph(self.round + 1, graph_path) + + logger.info(directory) + + avg_score = await self.evaluation_utils.evaluate_graph(self, directory, validation_n, data, initial=False) + + self.experience_utils.update_experience(directory, experience, avg_score) + + return avg_score + + async def test(self): + rounds = [5] # You can choose the rounds you want to test here. + data = [] + + graph_path = f"{self.root_path}/workflows_test" + json_file_path = self.data_utils.get_results_file_path(graph_path) + + data = self.data_utils.load_results(graph_path) + + for round in rounds: + directory = self.graph_utils.create_round_directory(graph_path, round) + self.graph = self.graph_utils.load_graph(round, graph_path) + + score, avg_cost, total_cost = await self.evaluation_utils.evaluate_graph_test(self, directory, is_test=True) + + new_data = self.data_utils.create_result_data(round, score, avg_cost, total_cost) + data.append(new_data) + + self.data_utils.save_results(json_file_path, data) diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/convergence_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/convergence_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0e275f49657aed8d4b71913db92bef9967e1faec --- /dev/null +++ b/metagpt/ext/aflow/scripts/optimizer_utils/convergence_utils.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# @Date : 9/23/2024 10:00 AM +# @Author : Issac +# @Desc : + +import json +import os + +import numpy as np + +from metagpt.logs import logger + + +class ConvergenceUtils: + def __init__(self, root_path): + self.root_path = root_path + self.data = None + self.rounds = None + self.avg_scores, self.stds = None, None + + def load_data(self, root_path): + """ + Read JSON file, create a new file if it doesn't exist, then return the data. + """ + rounds_dir = os.path.join(root_path, "workflows") + result_file = os.path.join(rounds_dir, "results.json") + + # Ensure directory exists + os.makedirs(rounds_dir, exist_ok=True) + + # If file doesn't exist, create a new one with an empty list + if not os.path.exists(result_file): + with open(result_file, "w") as file: + json.dump([], file) + + # Read file and return data + with open(result_file, "r") as file: + return json.load(file) + + def process_rounds(self): + """ + Organize data by round, return a dictionary of scores by round. + """ + self.data = self.load_data(root_path=self.root_path) + rounds = {} + for entry in self.data: + round_number = entry["round"] + score = entry["score"] + if round_number not in rounds: + rounds[round_number] = [] + rounds[round_number].append(score) + return rounds + + def calculate_avg_and_std(self): + """ + Calculate average score and standard deviation for each round, return two lists: average scores and standard deviations. + """ + self.rounds = self.process_rounds() + + sorted_rounds = sorted(self.rounds.items(), key=lambda x: x[0]) + avg_scores = [] + stds = [] + for round_number, scores in sorted_rounds: + avg_scores.append(np.mean(scores)) + stds.append(np.std(scores)) + return avg_scores, stds + + def check_convergence(self, top_k=3, z=0, consecutive_rounds=5): + """ + Check for convergence. z is the z-score corresponding to the confidence level. + consecutive_rounds is the number of consecutive rounds that must meet the stop condition. 
+ """ + # Calculate average score and standard deviation for each round + self.avg_scores, self.stds = self.calculate_avg_and_std() + # If total rounds are not enough to calculate top_k+1 rounds, return not converged + if len(self.avg_scores) < top_k + 1: + return False, None, None + convergence_count = 0 # Convergence counter + previous_y = None # Y value of the previous round (average of top_k scores) + sigma_y_previous = None # Standard error of Y value from previous round + for i in range(len(self.avg_scores)): + # Dynamically select top_k from current round and all previous rounds + top_k_indices = np.argsort(self.avg_scores[: i + 1])[::-1][ + :top_k + ] # Select top k indices by descending average score + top_k_scores = [self.avg_scores[j] for j in top_k_indices] # Get list of top k scores + top_k_stds = [ + self.stds[j] for j in top_k_indices + ] # Get list of standard deviations corresponding to top k scores + # Calculate mean of top k scores for current round, i.e., y_current + y_current = np.mean(top_k_scores) + # Calculate standard error of y_current (sigma_y_current), representing score dispersion + sigma_y_current = np.sqrt(np.sum([s**2 for s in top_k_stds]) / (top_k**2)) + # If not the first round, calculate change in Y (Delta_Y) and corresponding standard error + if previous_y is not None: + # Calculate Y difference between current round and previous round + delta_y = y_current - previous_y + # Calculate standard error of Y difference (sigma_Delta_Y) + sigma_delta_y = np.sqrt(sigma_y_current**2 + sigma_y_previous**2) + # Check if Y change is within acceptable confidence interval, i.e., convergence condition + if abs(delta_y) <= z * sigma_delta_y: + convergence_count += 1 + # If consecutive converged rounds reach set value, return convergence information + if convergence_count >= consecutive_rounds: + return True, i - consecutive_rounds + 1, i + else: + # If change is large, reset convergence counter + convergence_count = 0 + # Update Y value and standard error for previous round + previous_y = y_current + sigma_y_previous = sigma_y_current + # If convergence condition not met, return not converged + return False, None, None + + def print_results(self): + """ + Print average score and standard deviation for all rounds. 
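Restated compactly, `check_convergence` treats the mean of the current top-k round averages as an estimate y with standard error sigma_y, and declares convergence once the round-over-round change in y stays within z standard errors for several consecutive rounds. The core statistics, distilled from the loop above (illustrative numbers):

```python
import numpy as np

top_k_scores = [0.81, 0.80, 0.79]  # best k round averages so far
top_k_stds = [0.02, 0.03, 0.02]    # their per-round standard deviations
k = len(top_k_scores)

y_current = np.mean(top_k_scores)
# Standard error of the mean of k (assumed independent) round averages:
sigma_y = np.sqrt(np.sum(np.square(top_k_stds)) / k**2)

y_prev, sigma_prev = 0.795, 0.015  # previous round's estimate
delta_y = y_current - y_prev
sigma_delta = np.sqrt(sigma_y**2 + sigma_prev**2)

z = 0  # the default in check_convergence
print(abs(delta_y) <= z * sigma_delta)  # False for these numbers
```

With the default `z = 0` the tolerance collapses to requiring `delta_y == 0`; passing a positive z (for example 1.96 for roughly 95% confidence) makes the stop rule statistical rather than exact.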
+ """ + self.avg_scores, self.stds = self.calculate_avg_and_std() + for i, (avg_score, std) in enumerate(zip(self.avg_scores, self.stds), 1): + logger.info(f"Round {i}: Average Score = {avg_score:.4f}, Standard Deviation = {std:.4f}") + + +if __name__ == "__main__": + # Use this class and specify top_k + checker = ConvergenceUtils("path") # For example, set top_k=5 + converged, convergence_round, final_round = checker.check_convergence() + + if converged: + logger.info(f"Convergence detected, occurred at round {convergence_round}, final round is {final_round}") + else: + logger.info("No convergence detected within all rounds") + + # Print average score and standard deviation for each round + checker.print_results() diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a09e0820153611baf432bfb21e055906ee444b8 --- /dev/null +++ b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py @@ -0,0 +1,149 @@ +import datetime +import json +import os +import random + +import numpy as np +import pandas as pd + +from metagpt.logs import logger +from metagpt.utils.common import read_json_file, write_json_file + + +class DataUtils: + def __init__(self, root_path: str): + self.root_path = root_path + self.top_scores = [] + + def load_results(self, path: str) -> list: + result_path = os.path.join(path, "results.json") + if os.path.exists(result_path): + with open(result_path, "r") as json_file: + try: + return json.load(json_file) + except json.JSONDecodeError: + return [] + return [] + + def get_top_rounds(self, sample: int, path=None, mode="Graph"): + self._load_scores(path, mode) + unique_rounds = set() + unique_top_scores = [] + + first_round = next((item for item in self.top_scores if item["round"] == 1), None) + if first_round: + unique_top_scores.append(first_round) + unique_rounds.add(1) + + for item in self.top_scores: + if item["round"] not in unique_rounds: + unique_top_scores.append(item) + unique_rounds.add(item["round"]) + + if len(unique_top_scores) >= sample: + break + + return unique_top_scores + + def select_round(self, items): + if not items: + raise ValueError("Item list is empty.") + + sorted_items = sorted(items, key=lambda x: x["score"], reverse=True) + scores = [item["score"] * 100 for item in sorted_items] + + probabilities = self._compute_probabilities(scores) + logger.info(f"\nMixed probability distribution: {probabilities}") + logger.info(f"\nSorted rounds: {sorted_items}") + + selected_index = np.random.choice(len(sorted_items), p=probabilities) + logger.info(f"\nSelected index: {selected_index}, Selected item: {sorted_items[selected_index]}") + + return sorted_items[selected_index] + + def _compute_probabilities(self, scores, alpha=0.2, lambda_=0.3): + scores = np.array(scores, dtype=np.float64) + n = len(scores) + + if n == 0: + raise ValueError("Score list is empty.") + + uniform_prob = np.full(n, 1.0 / n, dtype=np.float64) + + max_score = np.max(scores) + shifted_scores = scores - max_score + exp_weights = np.exp(alpha * shifted_scores) + + sum_exp_weights = np.sum(exp_weights) + if sum_exp_weights == 0: + raise ValueError("Sum of exponential weights is 0, cannot normalize.") + + score_prob = exp_weights / sum_exp_weights + + mixed_prob = lambda_ * uniform_prob + (1 - lambda_) * score_prob + + total_prob = np.sum(mixed_prob) + if not np.isclose(total_prob, 1.0): + mixed_prob = mixed_prob / total_prob + + return mixed_prob + + def 
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a09e0820153611baf432bfb21e055906ee444b8
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py
@@ -0,0 +1,149 @@
+import datetime
+import json
+import os
+import random
+
+import numpy as np
+import pandas as pd
+
+from metagpt.logs import logger
+from metagpt.utils.common import read_json_file, write_json_file
+
+
+class DataUtils:
+    def __init__(self, root_path: str):
+        self.root_path = root_path
+        self.top_scores = []
+
+    def load_results(self, path: str) -> list:
+        result_path = os.path.join(path, "results.json")
+        if os.path.exists(result_path):
+            with open(result_path, "r") as json_file:
+                try:
+                    return json.load(json_file)
+                except json.JSONDecodeError:
+                    return []
+        return []
+
+    def get_top_rounds(self, sample: int, path=None, mode="Graph"):
+        self._load_scores(path, mode)
+        unique_rounds = set()
+        unique_top_scores = []
+
+        first_round = next((item for item in self.top_scores if item["round"] == 1), None)
+        if first_round:
+            unique_top_scores.append(first_round)
+            unique_rounds.add(1)
+
+        for item in self.top_scores:
+            if item["round"] not in unique_rounds:
+                unique_top_scores.append(item)
+                unique_rounds.add(item["round"])
+
+                if len(unique_top_scores) >= sample:
+                    break
+
+        return unique_top_scores
+
+    def select_round(self, items):
+        if not items:
+            raise ValueError("Item list is empty.")
+
+        sorted_items = sorted(items, key=lambda x: x["score"], reverse=True)
+        scores = [item["score"] * 100 for item in sorted_items]
+
+        probabilities = self._compute_probabilities(scores)
+        logger.info(f"\nMixed probability distribution: {probabilities}")
+        logger.info(f"\nSorted rounds: {sorted_items}")
+
+        selected_index = np.random.choice(len(sorted_items), p=probabilities)
+        logger.info(f"\nSelected index: {selected_index}, Selected item: {sorted_items[selected_index]}")
+
+        return sorted_items[selected_index]
+
+    def _compute_probabilities(self, scores, alpha=0.2, lambda_=0.3):
+        scores = np.array(scores, dtype=np.float64)
+        n = len(scores)
+
+        if n == 0:
+            raise ValueError("Score list is empty.")
+
+        uniform_prob = np.full(n, 1.0 / n, dtype=np.float64)
+
+        max_score = np.max(scores)
+        shifted_scores = scores - max_score
+        exp_weights = np.exp(alpha * shifted_scores)
+
+        sum_exp_weights = np.sum(exp_weights)
+        if sum_exp_weights == 0:
+            raise ValueError("Sum of exponential weights is 0, cannot normalize.")
+
+        score_prob = exp_weights / sum_exp_weights
+
+        mixed_prob = lambda_ * uniform_prob + (1 - lambda_) * score_prob
+
+        total_prob = np.sum(mixed_prob)
+        if not np.isclose(total_prob, 1.0):
+            mixed_prob = mixed_prob / total_prob
+
+        return mixed_prob
+
+    def load_log(self, cur_round, path=None, mode: str = "Graph"):
+        if mode == "Graph":
+            log_dir = os.path.join(self.root_path, "workflows", f"round_{cur_round}", "log.json")
+        else:
+            log_dir = path
+
+        # Return an empty string if the log file does not exist
+        if not os.path.exists(log_dir):
+            return ""
+        logger.info(log_dir)
+        data = read_json_file(log_dir, encoding="utf-8")
+
+        if isinstance(data, dict):
+            data = [data]
+        elif not isinstance(data, list):
+            data = list(data)
+
+        if not data:
+            return ""
+
+        sample_size = min(3, len(data))
+        random_samples = random.sample(data, sample_size)
+
+        log = ""
+        for sample in random_samples:
+            log += json.dumps(sample, indent=4, ensure_ascii=False) + "\n\n"
+
+        return log
+
+    def get_results_file_path(self, graph_path: str) -> str:
+        return os.path.join(graph_path, "results.json")
+
+    def create_result_data(self, round: int, score: float, avg_cost: float, total_cost: float) -> dict:
+        now = datetime.datetime.now()
+        return {"round": round, "score": score, "avg_cost": avg_cost, "total_cost": total_cost, "time": now}
+
+    def save_results(self, json_file_path: str, data: list):
+        write_json_file(json_file_path, data, encoding="utf-8", indent=4)
+
+    def _load_scores(self, path=None, mode="Graph"):
+        if mode == "Graph":
+            rounds_dir = os.path.join(self.root_path, "workflows")
+        else:
+            rounds_dir = path
+
+        result_file = os.path.join(rounds_dir, "results.json")
+        self.top_scores = []
+
+        data = read_json_file(result_file, encoding="utf-8")
+        df = pd.DataFrame(data)
+
+        scores_per_round = df.groupby("round")["score"].mean().to_dict()
+
+        for round_number, average_score in scores_per_round.items():
+            self.top_scores.append({"round": round_number, "score": average_score})
+
+        self.top_scores.sort(key=lambda x: x["score"], reverse=True)
+
+        return self.top_scores
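select_round above samples the parent for the next round from a mixture of a uniform distribution (exploration) and a max-shifted softmax over scores (exploitation), with lambda_ setting the blend. The same arithmetic, standalone (the scores are illustrative):

```python
import numpy as np


def mixed_probabilities(scores, alpha=0.2, lambda_=0.3):
    # Mirrors DataUtils._compute_probabilities: uniform floor + softmax of shifted scores
    scores = np.asarray(scores, dtype=np.float64)
    uniform = np.full(len(scores), 1.0 / len(scores))
    weights = np.exp(alpha * (scores - scores.max()))  # shift by the max for numerical stability
    return lambda_ * uniform + (1 - lambda_) * weights / weights.sum()


print(mixed_probabilities([83.0, 80.5, 77.0]))  # the best round keeps the largest share
```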
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/evaluation_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/evaluation_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..77683017ee376acff88d98a8f3c5d909069758c4
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/evaluation_utils.py
@@ -0,0 +1,63 @@
+from metagpt.ext.aflow.scripts.evaluator import Evaluator
+
+
+class EvaluationUtils:
+    def __init__(self, root_path: str):
+        self.root_path = root_path
+
+    async def evaluate_initial_round(self, optimizer, graph_path, directory, validation_n, data):
+        # Load the graph through the optimizer's graph_utils
+        optimizer.graph = optimizer.graph_utils.load_graph(optimizer.round, graph_path)
+        evaluator = Evaluator(eval_path=directory)
+
+        for i in range(validation_n):
+            score, avg_cost, total_cost = await evaluator.graph_evaluate(
+                optimizer.dataset,
+                optimizer.graph,
+                {"dataset": optimizer.dataset, "llm_config": optimizer.execute_llm_config},
+                directory,
+                is_test=False,
+            )
+
+            new_data = optimizer.data_utils.create_result_data(optimizer.round, score, avg_cost, total_cost)
+            data.append(new_data)
+
+            result_path = optimizer.data_utils.get_results_file_path(graph_path)
+            optimizer.data_utils.save_results(result_path, data)
+
+        return data
+
+    async def evaluate_graph(self, optimizer, directory, validation_n, data, initial=False):
+        evaluator = Evaluator(eval_path=directory)
+        sum_score = 0
+
+        for i in range(validation_n):
+            score, avg_cost, total_cost = await evaluator.graph_evaluate(
+                optimizer.dataset,
+                optimizer.graph,
+                {"dataset": optimizer.dataset, "llm_config": optimizer.execute_llm_config},
+                directory,
+                is_test=False,
+            )
+
+            cur_round = optimizer.round + 1 if initial is False else optimizer.round
+
+            new_data = optimizer.data_utils.create_result_data(cur_round, score, avg_cost, total_cost)
+            data.append(new_data)
+
+            result_path = optimizer.data_utils.get_results_file_path(f"{optimizer.root_path}/workflows")
+            optimizer.data_utils.save_results(result_path, data)
+
+            sum_score += score
+
+        return sum_score / validation_n
+
+    async def evaluate_graph_test(self, optimizer, directory, is_test=True):
+        evaluator = Evaluator(eval_path=directory)
+        return await evaluator.graph_evaluate(
+            optimizer.dataset,
+            optimizer.graph,
+            {"dataset": optimizer.dataset, "llm_config": optimizer.execute_llm_config},
+            directory,
+            is_test=is_test,
+        )
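As a sanity check on the bookkeeping above: evaluate_graph runs validation_n independent evaluations, records each result under round + 1 (except on the initial pass), and returns the mean score. A condensed sketch with graph_evaluate stubbed out (the stub and numbers are illustrative, not the real Evaluator):

```python
import asyncio
import random


async def graph_evaluate():
    # Stand-in for Evaluator.graph_evaluate: returns (score, avg_cost, total_cost)
    return random.uniform(0.6, 0.8), 0.01, 0.05


async def evaluate(round_: int, validation_n: int = 3, initial: bool = False):
    data, sum_score = [], 0.0
    for _ in range(validation_n):
        score, avg_cost, total_cost = await graph_evaluate()
        cur_round = round_ if initial else round_ + 1  # new candidates are logged as the next round
        data.append({"round": cur_round, "score": score, "avg_cost": avg_cost, "total_cost": total_cost})
        sum_score += score
    return sum_score / validation_n, data


avg, rows = asyncio.run(evaluate(round_=1))
print(f"avg={avg:.3f}, results logged for round {rows[0]['round']}")
```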
diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/experience_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/experience_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..43f9eb1d5c7e9353736c865ca54d2d811fa96409
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/optimizer_utils/experience_utils.py
@@ -0,0 +1,96 @@
+import json
+import os
+from collections import defaultdict
+
+from metagpt.logs import logger
+from metagpt.utils.common import read_json_file, write_json_file
+
+
+class ExperienceUtils:
+    def __init__(self, root_path: str):
+        self.root_path = root_path
+
+    def load_experience(self, path=None, mode: str = "Graph"):
+        if mode == "Graph":
+            rounds_dir = os.path.join(self.root_path, "workflows")
+        else:
+            rounds_dir = path
+
+        experience_data = defaultdict(lambda: {"score": None, "success": {}, "failure": {}})
+
+        for round_dir in os.listdir(rounds_dir):
+            if os.path.isdir(os.path.join(rounds_dir, round_dir)) and round_dir.startswith("round_"):
+                round_path = os.path.join(rounds_dir, round_dir)
+                try:
+                    round_number = int(round_dir.split("_")[1])
+                    json_file_path = os.path.join(round_path, "experience.json")
+                    if os.path.exists(json_file_path):
+                        data = read_json_file(json_file_path, encoding="utf-8")
+                        father_node = data["father node"]
+
+                        if experience_data[father_node]["score"] is None:
+                            experience_data[father_node]["score"] = data["before"]
+
+                        if data["succeed"]:
+                            experience_data[father_node]["success"][round_number] = {
+                                "modification": data["modification"],
+                                "score": data["after"],
+                            }
+                        else:
+                            experience_data[father_node]["failure"][round_number] = {
+                                "modification": data["modification"],
+                                "score": data["after"],
+                            }
+                except Exception as e:
+                    logger.info(f"Error processing {round_dir}: {str(e)}")
+
+        experience_data = dict(experience_data)
+
+        output_path = os.path.join(rounds_dir, "processed_experience.json")
+        with open(output_path, "w", encoding="utf-8") as outfile:
+            json.dump(experience_data, outfile, indent=4, ensure_ascii=False)
+
+        logger.info(f"Processed experience data saved to {output_path}")
+        return experience_data
+
+    def format_experience(self, processed_experience, sample_round):
+        experience_data = processed_experience.get(sample_round)
+        if experience_data:
+            experience = f"Original Score: {experience_data['score']}\n"
+            experience += "These are some conclusions drawn from experience:\n\n"
+            for key, value in experience_data["failure"].items():
+                experience += f"-Absolutely prohibit {value['modification']} (Score: {value['score']})\n"
+            for key, value in experience_data["success"].items():
+                experience += f"-Absolutely prohibit {value['modification']} \n"
+            experience += "\n\nNote: Take into account past failures and avoid repeating the same mistakes, as these failures indicate that these approaches are ineffective. You must fundamentally change your way of thinking, rather than simply using more advanced Python syntax like for, if, else, etc., or modifying the prompt."
+        else:
+            experience = f"No experience data found for round {sample_round}."
+        return experience
+
+    def check_modification(self, processed_experience, modification, sample_round):
+        experience_data = processed_experience.get(sample_round)
+        if experience_data:
+            for key, value in experience_data["failure"].items():
+                if value["modification"] == modification:
+                    return False
+            for key, value in experience_data["success"].items():
+                if value["modification"] == modification:
+                    return False
+            return True
+        else:
+            return True  # If experience_data is empty, also return True
+
+    def create_experience_data(self, sample, modification):
+        return {
+            "father node": sample["round"],
+            "modification": modification,
+            "before": sample["score"],
+            "after": None,
+            "succeed": None,
+        }
+
+    def update_experience(self, directory, experience, avg_score):
+        experience["after"] = avg_score
+        experience["succeed"] = bool(avg_score > experience["before"])
+
+        write_json_file(os.path.join(directory, "experience.json"), experience, encoding="utf-8", indent=4)
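The experience records written above have a fixed shape: create_experience_data stamps the parent round and pre-modification score, and update_experience fills in the outcome once the candidate round is evaluated. A toy version of that round trip (field names follow experience_utils.py; the scores are invented):

```python
import json
import tempfile

# Shape produced by create_experience_data (values here are made up)
experience = {
    "father node": 3,
    "modification": "add a self-review step before the final answer",
    "before": 0.68,
    "after": None,
    "succeed": None,
}

# What update_experience does once the new round has been scored
avg_score = 0.71
experience["after"] = avg_score
experience["succeed"] = bool(avg_score > experience["before"])

with tempfile.TemporaryDirectory() as directory:
    with open(f"{directory}/experience.json", "w", encoding="utf-8") as f:
        json.dump(experience, f, indent=4)

print(experience["succeed"])  # True: the modification improved the score
```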
= "" + for id, operator in enumerate(operators): + operator_description = self._load_operator_description(id + 1, operator, path) + operators_description += f"{operator_description}\n" + return operators_description + + def _load_operator_description(self, id: int, operator_name: str, file_path: str) -> str: + with open(file_path, "r") as f: + operator_data = json.load(f) + matched_data = operator_data[operator_name] + desc = matched_data["description"] + interface = matched_data["interface"] + return f"{id}. {operator_name}: {desc}, with interface {interface})." + + def create_graph_optimize_prompt( + self, + experience: str, + score: float, + graph: str, + prompt: str, + operator_description: str, + type: str, + log_data: str, + ) -> str: + graph_input = WORKFLOW_INPUT.format( + experience=experience, + score=score, + graph=graph, + prompt=prompt, + operator_description=operator_description, + type=type, + log=log_data, + ) + graph_system = WORKFLOW_OPTIMIZE_PROMPT.format(type=type) + return graph_input + WORKFLOW_CUSTOM_USE + graph_system + + async def get_graph_optimize_response(self, graph_optimize_node): + max_retries = 5 + retries = 0 + + while retries < max_retries: + try: + response = graph_optimize_node.instruct_content.model_dump() + return response + except Exception as e: + retries += 1 + logger.info(f"Error generating prediction: {e}. Retrying... ({retries}/{max_retries})") + if retries == max_retries: + logger.info("Maximum retries reached. Skipping this sample.") + break + traceback.print_exc() + time.sleep(5) + return None + + def write_graph_files(self, directory: str, response: dict, round_number: int, dataset: str): + graph = WORKFLOW_TEMPLATE.format(graph=response["graph"], round=round_number, dataset=dataset) + + with open(os.path.join(directory, "graph.py"), "w", encoding="utf-8") as file: + file.write(graph) + + with open(os.path.join(directory, "prompt.py"), "w", encoding="utf-8") as file: + file.write(response["prompt"]) + + with open(os.path.join(directory, "__init__.py"), "w", encoding="utf-8") as file: + file.write("") diff --git a/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py b/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..a2e862ec2907ed8d37e642bb1d49e283a02a3bd3 --- /dev/null +++ b/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py @@ -0,0 +1,59 @@ +WORKFLOW_OPTIMIZE_PROMPT = """You are building a Graph and corresponding Prompt to jointly solve {type} problems. +Referring to the given graph and prompt, which forms a basic example of a {type} solution approach, +please reconstruct and optimize them. You can add, modify, or delete nodes, parameters, or prompts. Include your +single modification in XML tags in your reply. Ensure they are complete and correct to avoid runtime failures. When +optimizing, you can incorporate critical thinking methods like review, revise, ensemble (generating multiple answers through different/similar prompts, then voting/integrating/checking the majority to obtain a final answer), selfAsk, etc. Consider +Python's loops (for, while, list comprehensions), conditional statements (if-elif-else, ternary operators), +or machine learning techniques (e.g., linear regression, decision trees, neural networks, clustering). The graph +complexity should not exceed 10. 
diff --git a/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py b/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2e862ec2907ed8d37e642bb1d49e283a02a3bd3
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/prompts/optimize_prompt.py
@@ -0,0 +1,59 @@
+WORKFLOW_OPTIMIZE_PROMPT = """You are building a Graph and corresponding Prompt to jointly solve {type} problems.
+Referring to the given graph and prompt, which forms a basic example of a {type} solution approach,
+please reconstruct and optimize them. You can add, modify, or delete nodes, parameters, or prompts. Include your
+single modification in XML tags in your reply. Ensure they are complete and correct to avoid runtime failures. When
+optimizing, you can incorporate critical thinking methods like review, revise, ensemble (generating multiple answers through different/similar prompts, then voting/integrating/checking the majority to obtain a final answer), selfAsk, etc. Consider
+Python's loops (for, while, list comprehensions), conditional statements (if-elif-else, ternary operators),
+or machine learning techniques (e.g., linear regression, decision trees, neural networks, clustering). The graph
+complexity should not exceed 10. Use logical and control flow (IF-ELSE, loops) for a more enhanced graphical
+representation. Ensure that all the prompts required by the current graph from prompt_custom are included. Exclude any other prompts.
+Output the modified graph and all the necessary Prompts in prompt_custom (if needed).
+The prompt you need to generate is only the one used in `prompt_custom.XXX` within Custom. Other methods already have built-in prompts and are prohibited from being generated. Only generate those needed for use in `prompt_custom`; please remove any unused prompts in prompt_custom.
+The generated prompt must not contain any placeholders.
+Considering information loss, complex graphs may yield better results, but insufficient information transmission can omit the solution. It's crucial to include necessary context during the process."""
+
+
+WORKFLOW_INPUT = """
+Here is a graph and the corresponding prompt (prompt only related to the custom method) that performed excellently in a previous iteration (maximum score is 1). You must make further optimizations and improvements based on this graph. The modified graph must differ from the provided example, and the specific differences should be noted within the <modification>xxx</modification> section.\n
+<sample>
+    <experience>{experience}</experience>
+    <modification>(such as:add /delete /modify/ ...)</modification>
+    <score>{score}</score>
+    <graph>{graph}</graph>
+    <prompt>{prompt}</prompt>(only prompt_custom)
+    <operator_description>{operator_description}</operator_description>
+</sample>
+Below are the logs of some results with the aforementioned Graph that performed well but encountered errors, which can be used as references for optimization:
+{log}
+
+First, provide optimization ideas. **Only one detail point can be modified at a time**, and no more than 5 lines of code may be changed per modification—extensive modifications are strictly prohibited to maintain project focus!
+When introducing new functionalities in the graph, please make sure to import the necessary libraries or modules yourself, except for operator, prompt_custom, create_llm_instance, and CostManager, which have already been automatically imported.
+**Under no circumstances should Graph output None for any field.**
+Use custom methods to restrict your output format, rather than using code (outside of the code, the system will extract answers based on certain rules and score them).
+It is very important to format the Graph output answers, you can refer to the standard answer format in the log.
+"""

+WORKFLOW_CUSTOM_USE = """\nHere's an example of using the `custom` method in graph:
+```
+# You can write your own prompt in prompt_custom and then use it in the Custom method in the graph
+response = await self.custom(input=problem, instruction=prompt_custom.XXX_PROMPT)
+# You can also concatenate previously generated string results in the input to provide more comprehensive contextual information.
+# response = await self.custom(input=problem+f"xxx:{xxx}, xxx:{xxx}", instruction=prompt_custom.XXX_PROMPT)
+# The output from the Custom method can be placed anywhere you need it, as shown in the example below
+solution = await self.generate(problem=f"question:{problem}, xxx:{response['response']}")
+```
+Note: In custom, the input and instruction are directly concatenated (instruction+input), and placeholders are not supported. Please ensure to add comments and handle the concatenation externally.\n
+
+**Introducing multiple operators at appropriate points can enhance performance. If you find that some provided operators are not yet used in the graph, try incorporating them.**
+"""
+
+WORKFLOW_TEMPLATE = """from typing import Literal
+import metagpt.ext.aflow.scripts.optimized.{dataset}.workflows.template.operator as operator
+import metagpt.ext.aflow.scripts.optimized.{dataset}.workflows.round_{round}.prompt as prompt_custom
+from metagpt.provider.llm_provider_registry import create_llm_instance
+from metagpt.utils.cost_manager import CostManager
+
+DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"]
+
+{graph}
+"""
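Filled into WORKFLOW_TEMPLATE, a generated round might look like the sketch below. This is illustrative only: the dataset, round number, prompt name, and the assumption that `operator` exposes a `Custom` class instantiated with the LLM follow the conventions shown above, not actual optimizer output.

```python
from typing import Literal
import metagpt.ext.aflow.scripts.optimized.GSM8K.workflows.template.operator as operator
import metagpt.ext.aflow.scripts.optimized.GSM8K.workflows.round_2.prompt as prompt_custom
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager

DatasetType = Literal["HumanEval", "MBPP", "GSM8K", "MATH", "HotpotQA", "DROP"]


class Workflow:
    def __init__(self, name: str, llm_config, dataset: DatasetType) -> None:
        self.name = name
        self.dataset = dataset
        self.llm = create_llm_instance(llm_config)
        self.llm.cost_manager = CostManager()
        self.custom = operator.Custom(self.llm)  # assumed operator wiring

    async def __call__(self, problem: str):
        # A single Custom call driven by a prompt defined in round_2/prompt.py
        response = await self.custom(input=problem, instruction=prompt_custom.SOLVE_PROMPT)
        return response["response"], self.llm.cost_manager.total_cost
```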
diff --git a/metagpt/ext/aflow/scripts/prompts/prompt.py b/metagpt/ext/aflow/scripts/prompts/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bf78af87b727d8da0646c4655f04f24b806687
--- /dev/null
+++ b/metagpt/ext/aflow/scripts/prompts/prompt.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# @Date    : 6/26/2024 17:07 PM
+# @Author  : didi
+# @Desc    : prompts of operators
+
+ANSWER_GENERATION_PROMPT = """
+Think step by step and solve the problem.
+1. In the "thought" field, explain your thinking process in detail.
+2. In the "answer" field, provide the final answer concisely and clearly. The answer should be a direct response to the question, without including explanations or reasoning.
+Your task: {input}
+"""
+
+FORMAT_PROMPT = """
+For the question described as {problem_description},
+please extract a short and concise answer that contains only one word or a few words from the following solution: {solution}.
+Make sure there are no additional comments or explanations in your response.
+"""
+
+SC_ENSEMBLE_PROMPT = """
+Given the question described as follows: {question}
+Several solutions have been generated to address the given question. They are as follows:
+{solutions}
+
+Carefully evaluate these solutions and identify the answer that appears most frequently across them. This consistency in answers is crucial for determining the most reliable solution.
+
+In the "thought" field, provide a detailed explanation of your thought process. In the "solution_letter" field, output only the single letter ID (A, B, C, etc.) corresponding to the most consistent solution. Do not include any additional text or explanation in the "solution_letter" field.
+"""
+
+PYTHON_CODE_VERIFIER_PROMPT = """
+You are a professional Python programmer. Your task is to write complete, self-contained code based on a given mathematical problem and output the answer. The code should include all necessary imports and dependencies, and be ready to run without additional setup or environment configuration.
+
+Problem description: {problem}
+Other analysis: {analysis}
+{feedback}
+
+Your code should:
+1. Implement the calculation steps described in the problem.
+2. Define a function named `solve` that performs the calculation and returns the result. The `solve` function should not require any input parameters; instead, it should obtain all necessary inputs from within the function or from globally defined variables.
+3. The `solve` function should return the final calculation result.
+
+Please ensure your code is efficient, well-commented, and follows Python best practices. The output should be limited to basic data types such as strings, integers, and floats. It is prohibited to transmit images or other file formats. The code output is intended for a text-based language model.
+""" + + +REFLECTION_ON_PUBLIC_TEST_PROMPT = """ +Given a code problem and a python code solution which failed to pass test or execute, you need to analyze the reason for the failure and propose a better code solution.: +### problem +{problem} + +### Code Solution +{solution} + +### Execution Result +{exec_pass} + +#### Failed Test Case +{test_fail} + +Please provide a reflection on the failed test cases and code solution, followed by a better code solution without any additional text or test cases. +""" + +MD_ENSEMBLE_PROMPT = """ +Given the question described as follows: {question} +Several solutions have been generated to address the given question. They are as follows: +{solutions} + +Carefully evaluate these solutions and identify the solution that is more capable of solving the problem compared to other solutions, as this is crucial for problem-solving. + +In the "thought" field, provide a detailed explanation of your thought process. In the "solution_letter" field, output only the single letter ID (A, B, C, etc.) corresponding to the solution. Do not include any additional text or explanation in the "solution_letter" field. +""" + +REVIEW_PROMPT = """ +Given a problem and a thoughtful solution, your task is to using critical thinking (questioning) to review the solution's correctness and provide a review result in boolean format. + +problem: {problem} +solution: {solution} + +If you are more than 95 percent confident that the final answer is incorrect, please return False and give a feedback for the error. Otherwise, please return True and give a explanation for the correctness. +""" + +REVISE_PROMPT = """ +Given a problem and a thoughtful solution which is just reviewed as incorrect, your task is to revise the solution to solve the question and ensure the final code solution is wrapped with ```python```. + +problem: {problem} +solution: {solution} +feedback: {feedback} + +Ensure the output code is self-contained, and without any additional text or test cases. 
+""" diff --git a/metagpt/ext/aflow/scripts/utils.py b/metagpt/ext/aflow/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6222dc492c0a1d7f86a8e96ed367d45e212146 --- /dev/null +++ b/metagpt/ext/aflow/scripts/utils.py @@ -0,0 +1,125 @@ +""" +@Time : 2024/7/24 16:37 +@Author : didi +@File : utils.py +""" + +import json +import re +from enum import Enum +from typing import Any, List, Tuple + + +class CodeDataset(Enum): + HUMAN_EVAL = "HumanEval" + MBPP = "MBPP" + + +def extract_test_cases_from_jsonl(entry_point: str, dataset: CodeDataset = CodeDataset.HUMAN_EVAL): + if dataset == CodeDataset.HUMAN_EVAL.value: + file_path = "metagpt/ext/aflow/data/humaneval_public_test.jsonl" + # Retain the original hardcoded test cases + hardcoded_cases = { + "find_zero": "", + "decode_cyclic": "", + "decode_shift": "", + "by_length": "", + "add": "", + "triangle_area": "", + "correct_bracketing": "", + "solve": "", + "sum_squares": "", + "starts_one_ends": "", + } + elif dataset == CodeDataset.MBPP.value: + file_path = "metagpt/ext/aflow/data/mbpp_public_test.jsonl" + hardcoded_cases = { + "remove_odd": "", + "replace_spaces": "", + "snake_to_camel": "", + "Split": "", + "swap_List": "", + "square_Sum": "", + "sort_sublists": "", + "unique_sublists": "", + } + # Check if there are hardcoded test cases + if entry_point in hardcoded_cases: + return hardcoded_cases[entry_point] + + # If there are no hardcoded test cases, read from the file + with open(file_path, "r") as file: + for line in file: + data = json.loads(line) + if data.get("entry_point") == entry_point: + return data.get("test") + + return None + + +def extract_test_cases(docstring: str) -> List[Tuple[str, List[Any], Any]]: + # Use regular expressions to match test cases, now capturing function names and any output + pattern = r">>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)" + matches = re.findall(pattern, docstring, re.DOTALL) + + test_cases = [] + for match in matches: + func_name, input_str, expected_output = match + + # Process input + input_list = [] + for item in input_str.split(","): + item = item.strip() + try: + # Try to convert input to numeric type + if "." in item: + input_list.append(float(item)) + else: + input_list.append(int(item)) + except ValueError: + # If unable to convert to numeric, keep as string + input_list.append(item.strip("'\"")) + + # Process output + try: + # Try to convert output to numeric or boolean value + if expected_output.lower() == "true": + expected_output = True + elif expected_output.lower() == "false": + expected_output = False + elif "." 
in expected_output: + expected_output = float(expected_output) + else: + expected_output = int(expected_output) + except ValueError: + # If unable to convert, keep as string + expected_output = expected_output.strip("'\"") + + test_cases.append([func_name, input_list, expected_output]) + + return test_cases + + +def test_cases_2_test_functions(solution: str, test_cases: str): + tester_function = f""" +{solution} + +{test_cases} +""" + return tester_function + + +def test_case_2_test_function(solution: str, test_case: str, entry_point: str): + tester_function = f""" +{solution} + + +def check(candidate): + {test_case} + +def test_check(): + check({entry_point}) + +test_check() +""" + return tester_function diff --git a/metagpt/ext/aflow/scripts/workflow.py b/metagpt/ext/aflow/scripts/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..47b54021b280c93d76f131cff1af7e08ab1a95e7 --- /dev/null +++ b/metagpt/ext/aflow/scripts/workflow.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# @Date : 6/27/2024 22:07 PM +# @Author : didi +# @Desc : Basic Graph Class + + +from metagpt.ext.aflow.scripts.evaluator import DatasetType +from metagpt.provider.llm_provider_registry import create_llm_instance +from metagpt.utils.cost_manager import CostManager + + +class Workflow: + def __init__( + self, + name: str, + llm_config, + dataset: DatasetType, + ) -> None: + self.name = name + self.dataset = dataset + self.llm = create_llm_instance(llm_config) + self.llm.cost_manager = CostManager() + + async def __call__(self, problem: str): + """ + Implementation of the workflow + """ + raise NotImplementedError("This method should be implemented by the subclass") diff --git a/metagpt/ext/android_assistant/.DS_Store b/metagpt/ext/android_assistant/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9b2c5291a9477658442018c1c6ecf2bc26f6c7ec Binary files /dev/null and b/metagpt/ext/android_assistant/.DS_Store differ diff --git a/metagpt/ext/android_assistant/README.md b/metagpt/ext/android_assistant/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fe8b4b3e32c9dded2fc82edd22792ff1d1ab5a4b --- /dev/null +++ b/metagpt/ext/android_assistant/README.md @@ -0,0 +1,118 @@ +# MetaGPT Android Assistant + +The MetaGPT Android Assistant is an intelligent assistance tool driven by a multi-modal large language model based on the advanced MetaGPT framework. It has the ability to self-learn, mastering users' daily usage patterns through learning, and can automatically complete various application operations according to user instructions, achieving comprehensive liberation of users' hands. +Next, we will introduce the functions of the MetaGPT Android Assistant and how to use it. + +## Features + +The operation of the MetaGPT Android Assistant mainly includes two stages: learning and automatic execution. Below, we introduce the specific features of the MetaGPT Android Assistant from these two stages. + +### Learning Stage + +By learning from human demonstrations or exploring apps based on human instructions, the MetaGPT Android Assistant can learn the functionality of apps, generate corresponding operation documents for use in the subsequent "automatic execution" stage. Approximately 20 rounds of exploration for any given task objective can significantly improve performance. + +By setting the `stage` to `learn`, you can ask the Android Assistant to enter the learning stage. 
By setting the `mode` to `auto`, you can instruct the Android Assistant to learn through automatic exploration; by setting the mode to manual, you can instruct the Android Assistant to learn through human manual demonstration. In the usage section, we provide detailed explanations of the script parameters. You can try experimenting with automatic exploration and manual demonstration modes on the "Messenger" app with the following commands: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "learn" --mode "auto or manual" --app-name "Messenger" +``` + +#### Learning Based on Human Demonstration +When asking the Android Assistant to perform self-exploration during the learning stage, you can free your hands. However, when instructing it to learn according to your commands, you need to follow the instructions in the terminal for the Android Assistant to accurately learn your operation methods. +A possible example is as follows: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "learn" --mode "manual" --app-name "Messenger" +``` + +After running this command, you will first see a screenshot of an Android screen that has been marked at various interactive locations, as shown in the figure below: + + + +After remembering the location where you want to operate, a request similar to the one below will be output in the terminal. Reply to it and thereby direct the Android assistant to learn your demonstration action: + +```bash +| INFO | examples.android_assistant.actions.manual_record:run:96 - Which element do you want to tap? Choose a numeric tag from 1 to 11: +user_input: 8 +| INFO | examples.android_assistant.actions.manual_record:run:81 - Choose one of the following actions you want to perform on the current screen: +tap, text, long_press, swipe, stop +user_input: tap +``` + +### Automatic Execution Stage +After the Android Assistant completes the learning stage, you can command it to complete tasks on the phone through text descriptions. By configuring the operation documents from the self-learning stage, the Android Assistant has richer prior knowledge, and its execution capabilities are further enhanced. +You can instruct the Android Assistant to send messages in the "Messenger" app with the following command: +```bash +python run_assistant.py "Send 'When will we release this feature?' to +86 8888888" --stage "act" --mode "auto or manual" --app-name "Messenger" +``` +Specifically, by selecting `auto` for `mode`, the Android assistant will employ the operational records compiled through self-exploration. Alternatively, if `manual` is chosen as the `mode`, the Android assistant will leverage the operation manuals accrued from learning via human demonstration. + +## Installation +To use the Android Assistant, you first need to meet the following conditions: +1. Complete the installation of the MetaGPT environment. +2. Install [Android Debug Bridge (ADB)](https://developer.android.com/tools/adb?hl=zh-cn) on your PC, which enables interaction between your PC and Android devices. +3. Install Android Studio and within it, install the Android emulator to provide an environment for the Android Assistant to learn and execute. For information on how to install the Android emulator, refer to [Quick Installation of Android Studio & Emulator](https://docs.expo.dev/workflow/android-studio-emulator/). +4. 
(Optional) Connect your Android device to the USB port of your PC, which can also provide an environment for the Android Assistant to learn and execute.
+
+Note ⚠️: When working with the Android emulator, we use the Medium Phone device profile; first-time users are advised to start with this model.
+
+After completing these operations, you can enter the following command to check if ADB is installed successfully and if the Android device is connected:
+```bash
+adb devices
+```
+
+## Usage
+The MetaGPT Android Assistant is designed within the MetaGPT framework as a collection of one `Role` and multiple `Action`s. You can run it by executing the `run_assistant.py` script. The specific parameter description of this script is as follows:
+```text
+Usage: run_assistant.py [OPTIONS] TASK_DESC
+
+  Run an Android Assistant
+
+Arguments:
+  TASK_DESC  the task description you want the android assistant to learn or
+             act  [required]
+
+Options:
+  --n-round INTEGER               The max round to do an app operation task.
+                                  [default: 20]
+  --stage TEXT                    stage: learn / act  [default: learn]
+  --mode TEXT                     mode: auto / manual, when stage=learn
+                                  [default: auto]
+  --app-name TEXT                 the name of app you want to run  [default:
+                                  demo]
+  --investment FLOAT              Dollar amount to invest in the AI company.
+                                  [default: 5.0]
+  --refine-doc / --no-refine-doc  Refine existing operation docs based on the
+                                  latest observation if True.  [default: no-
+                                  refine-doc]
+  --min-dist INTEGER              The minimum distance between elements to
+                                  prevent overlapping during the labeling
+                                  process.  [default: 30]
+  --android-screenshot-dir TEXT   The path to store screenshots on android
+                                  device. Make sure it exists.  [default:
+                                  /sdcard/Pictures/Screenshots]
+  --android-xml-dir TEXT          The path to store xml files for determining
+                                  UI elements location. Make sure it exists.
+                                  [default: /sdcard]
+  --device-id TEXT                The Android device_id  [default:
+                                  emulator-5554]
+  --help                          Show this message and exit.
+```
+
+## Acknowledgements
+The MetaGPT Android Assistant has referenced some ideas and code from the [AppAgent](https://github.com/mnotgod96/AppAgent) project. We thank the developers of the AppAgent project.
+
+### Citation
+
+```bib
+@misc{yang2023appagent,
+      title={AppAgent: Multimodal Agents as Smartphone Users},
+      author={Chi Zhang and Zhao Yang and Jiaxuan Liu and Yucheng Han and Xin Chen and Zebiao Huang and Bin Fu and Gang Yu},
+      year={2023},
+      eprint={2312.13771},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+}
+```
\ No newline at end of file
diff --git a/metagpt/ext/android_assistant/README_CN.md b/metagpt/ext/android_assistant/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..a1abbe3b0bfa3b61bd76e15a11d71f9c43281190
--- /dev/null
+++ b/metagpt/ext/android_assistant/README_CN.md
@@ -0,0 +1,113 @@
+# MetaGPT 安卓助理
+
+MetaGPT安卓助理是一款依托于先进的MetaGPT框架构建的多模态大语言模型驱动的智能辅助工具。
+它具备自我学习的能力,能够通过学习掌握用户的日常使用方式,同时能够根据用户的指令自动完成各类应用程序的操作任务,实现了用户双手的全面解放。
+接下来,我们将介绍MetaGPT安卓助理的功能以及如何使用它。
+
+## 功能
+
+MetaGPT 安卓助理的执行主要包含两个阶段,分别为自我学习与自动执行。下面,我们将从这两个阶段介绍MetaGPT 安卓助理的具体功能。
+
+### 自我学习阶段
+
+通过学习人类演示或基于人类指令对app进行探索,MetaGPT安卓助理可以对app的功能进行学习,生成相应的操作文档,为后续的“自动执行”阶段使用。对于任何给定的任务目标,进行约20轮的探索可以显著提高性能。
+
+通过设定`stage`为`learn`可要求安卓助理进入自我学习阶段。通过设定`mode`为`auto`,可要求安卓助理通过自动探索学习,通过设定`mode`为`manual`,可要求安卓助理通过人类手动演示学习。在使用章节,我们对脚本的参数进行了详细的说明。
+您可以尝试对“Messenger”应用程序进行自动探索和手动演示模式的实验,具体命令如下:
+
+```bash
+cd examples/android_assistant
+python run_assistant.py "Send 'When will we release this feature?
to +86 8888888'" --stage "learn" --mode "auto or manual" --app-name "Messenger" +``` + +#### 基于人类演示的学习 +在要求安卓助理在自我学习阶段执行自我探索时,您可以解放您的双手,但在要求他根据您的指令进行学习时,你需要根据终端中的指令进行输入,以便安卓助理能够准确地学习您的操作方式。 +一个可能的例子如下: + +```bash +cd examples/android_assistant +python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "learn" --mode "manual" --app-name "Messenger" +``` + +在运行这一指令后,你将首先看到一个在各个可交互的位置进行了标记的安卓屏幕的截图,如下图: + + + +在记住你要操作的位置之后,终端中将会输出与下面类似的要求,回复它,进而指挥安卓助理学习你的演示行为: + +```bash +| INFO | examples.android_assistant.actions.manual_record:run:96 - Which element do you want to tap? Choose a numeric tag from 1 to 11: +user_input: 8 +| INFO | examples.android_assistant.actions.manual_record:run:81 - Choose one of the following actions you want to perform on the current screen: +tap, text, long_press, swipe, stop +user_input: tap +``` +### 自动执行阶段 +在安卓助理完成了自我学习阶段之后,您可以通过文本描述的方式,指挥安卓助理在手机中完成任务。通过为其配置自我学习阶段的操作文档,安卓助理具备了更丰富的前置知识,执行能力进一步得到提升。 +你可以通过以下指令,指挥安卓助理在“Messenger”应用中发送信息: +```bash +python run_assistant.py "Send 'When will we release this feature? to +86 8888888'" --stage "act" --mode "auto or manual" --app-name "Messenger" +``` +其中,`mode`选择`auto`,安卓助理将使用自我探索中积累的操作文档;`mode`选择`manual`,安卓助理将使用人类演示学习中积累的操作文档。 + +## 安装 +为了使用安卓助理,你首先需要满足以下条件: +1. 完成MetaGPT环境的安装 +2. 在你的PC上安装[Android Debug Bridge(ADB)](https://developer.android.com/tools/adb?hl=zh-cn),ADB可以使你的PC与安卓设备进行交互。 +3. 安装Android Studio,在其中安装Android模拟器,以为安卓助手提供学习与执行的环境。关于如何安装Android模拟器,可以参考[快速安装Android Studio & Emulator](https://dev.weixin.qq.com/docs/framework/dev/framework/env/android-simulator.html)。 +4. (Optional) 将你的安卓设备连接到PC的USB端口上,这同样可以为安卓助手提供学习与执行的环境。 + +注意 ⚠️:在使用Android模拟器进行操作时,我们使用的模拟器型号为Medium Phone,建议第一次尝试此类应用的用户使用这一型号完成操作。 + +在完成这一系列操作之后,你可以输入以下命令检查ADB是否安装成功,以及安卓设备是否连接 +```bash +adb devices +``` +## 使用 +MetaGPT 安卓助理在MetaGPT框架中被设计为一个`Role`与多个`Action`的集合,你可以通过运行`run_assistant.py`脚本来运行它。这一脚本具体的参数说明如下: +```text +用法:run_assistant.py [选项] 任务描述 + + 运行一个安卓助手 + +参数: + TASK_DESC 你希望安卓助手学习或执行的任务描述 + [必需] + +选项: + --n-round 整数 执行应用程序操作任务的最大轮数。 + [默认值:20] + --stage 文本 阶段:learn/act [默认值:learn] + --mode 文本 模式:auto/manual,当状态=learn时 [默认值:auto] + --app-name 文本 你想要运行的应用程序名称 [默认值: + 演示] + --investment 浮点数 投资于人工智能公司的美元金额。 + [默认值:5.0] + --refine-doc / --no-refine-doc 如果为真,则根据最新的观察结果优化现有操作文档。 + [默认值:--no-refine-doc] + --min-dist 整数 在标记过程中防止元素重叠的最小元素间距。 + [默认值:30] + --android-screenshot-dir 文本 在安卓设备上存储截图的路径。确保其存在。 + [默认值:/sdcard/Pictures/Screenshots] + --android-xml-dir 文本 存储用于确定UI元素位置的XML文件的路径。 + 确保其存在。[默认值:/sdcard] + --device-id 文本 安卓device_id [默认值: + 模拟器-5554] + --help 显示此信息并退出。 +``` + +## 致谢 +MetaGPT 安卓助理参考了 [AppAgent](https://github.com/mnotgod96/AppAgent) 项目的部分思路与代码,感谢 Appagent 项目的开发者们。 + +### 引用 + +```bib +@misc{yang2023appagent, + title={AppAgent: Multimodal Agents as Smartphone Users}, + author={Chi Zhang and Zhao Yang and Jiaxuan Liu and Yucheng Han and Xin Chen and Zebiao Huang and Bin Fu and Gang Yu}, + year={2023}, + eprint={2312.13771}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` \ No newline at end of file diff --git a/metagpt/ext/android_assistant/__init__.py b/metagpt/ext/android_assistant/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/android_assistant/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/android_assistant/actions/__init__.py b/metagpt/ext/android_assistant/actions/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/android_assistant/actions/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/android_assistant/actions/manual_record.py b/metagpt/ext/android_assistant/actions/manual_record.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfb2ed893ae259b3401c890912414461f3cff5e --- /dev/null +++ b/metagpt/ext/android_assistant/actions/manual_record.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : manual record user interaction in stage=learn & mode=manual, LIKE scripts/step_recorder.py +import time +from pathlib import Path + +import cv2 + +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android.android_env import AndroidEnv +from metagpt.environment.android.const import ADB_EXEC_FAIL +from metagpt.environment.android.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) +from metagpt.ext.android_assistant.utils.schema import ( + ActionOp, + AndroidActionOutput, + RunState, + SwipeOp, +) +from metagpt.ext.android_assistant.utils.utils import ( + draw_bbox_multi, + elem_list_from_xml_tree, +) +from metagpt.logs import logger + + +class ManualRecord(Action): + """do a human operation on the screen with human input""" + + name: str = "ManualRecord" + + useless_list: list[str] = [] # store useless elements uid + record_path: Path = "" + task_desc_path: Path = "" + screenshot_before_path: Path = "" + screenshot_after_path: Path = "" + xml_path: Path = "" + + async def run(self, task_desc: str, task_dir: Path, env: AndroidEnv): + self.record_path = Path(task_dir) / "record.txt" + self.task_desc_path = Path(task_dir) / "task_desc.txt" + self.screenshot_before_path = Path(task_dir) / "raw_screenshots" + self.screenshot_after_path = Path(task_dir) / "labeled_screenshots" + self.xml_path = Path(task_dir) / "xml" + for path in [self.screenshot_before_path, self.screenshot_after_path, self.xml_path]: + path.mkdir(parents=True, exist_ok=True) + + self.record_path.write_text("") + record_file = open(self.record_path, "w") + self.task_desc_path.write_text(task_desc) + + step = 0 + extra_config = config.extra + while True: + step += 1 + screenshot_path: Path = env.observe( + EnvObsParams( + obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{step}", local_save_dir=self.screenshot_before_path + ) + ) + xml_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{step}", local_save_dir=self.xml_path) + ) + if not screenshot_path.exists() or not xml_path.exists(): + return AndroidActionOutput(action_state=RunState.FAIL) + + elem_list = elem_list_from_xml_tree(xml_path, self.useless_list, extra_config.get("min_dist", 30)) + + screenshot_labeled_path = Path(self.screenshot_after_path).joinpath(f"{step}_labeled.png") + labeled_img = draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list) + + cv2.namedWindow("image", cv2.WINDOW_NORMAL) + cv2.imshow("image", labeled_img) + cv2.waitKey(0) + cv2.destroyAllWindows() + + user_input = "xxx" + logger.info( + "Choose one of the following actions you want to perform on the current screen:\n" + "tap, text, long_press, swipe, stop" + ) + + while ( + user_input.lower() != ActionOp.TAP.value + and user_input.lower() != ActionOp.TEXT.value + and user_input.lower() != ActionOp.LONG_PRESS.value + and user_input.lower() != ActionOp.SWIPE.value + and 
user_input.lower() != ActionOp.STOP.value + ): + user_input = input("user_input: ") + + if user_input.lower() == ActionOp.TAP.value: + logger.info(f"Which element do you want to tap? Choose a numeric tag from 1 to {len(elem_list)}:") + user_input = "xxx" + while not user_input.isnumeric() or int(user_input) > len(elem_list) or int(user_input) < 1: + user_input = input("user_input: ") + tl, br = elem_list[int(user_input) - 1].bbox + x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2 + action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y)) + log_str = f"tap({int(user_input)}):::{elem_list[int(user_input) - 1].uid}\n" + elif user_input.lower() == ActionOp.TEXT.value: + logger.info( + f"Which element do you want to input the text string? Choose a numeric tag from 1 to " + f"{len(elem_list)}:" + ) + input_area = "xxx" + while not input_area.isnumeric() or int(input_area) > len(elem_list) or int(input_area) < 1: + input_area = input("user_input: ") + logger.info("Enter your input text below:") + user_input = "" + while not user_input: + user_input = input("user_input: ") + action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=user_input) + log_str = f"text({input_area}:sep:'{user_input}'):::{elem_list[int(input_area) - 1].uid}\n" + elif user_input.lower() == ActionOp.LONG_PRESS.value: + logger.info( + f"Which element do you want to long press? Choose a numeric tag from 1 to {len(elem_list)}:" + ) + user_input = "xxx" + while not user_input.isnumeric() or int(user_input) > len(elem_list) or int(user_input) < 1: + user_input = input("user_input: ") + tl, br = elem_list[int(user_input) - 1].bbox + x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2 + action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y)) + log_str = f"long_press({int(user_input)}):::{elem_list[int(user_input) - 1].uid}\n" + elif user_input.lower() == ActionOp.SWIPE.value: + logger.info( + "What is the direction of your swipe? Choose one from the following options:\n" + "up, down, left, right" + ) + user_input = "" + while ( + user_input != SwipeOp.UP.value + and user_input != SwipeOp.DOWN.value + and user_input != SwipeOp.LEFT.value + and user_input != SwipeOp.RIGHT.value + ): + user_input = input("user_input: ") + swipe_dir = user_input + logger.info(f"Which element do you want to swipe? 
Choose a numeric tag from 1 to {len(elem_list)}:")
+                while not user_input.isnumeric() or int(user_input) > len(elem_list) or int(user_input) < 1:
+                    user_input = input("user_input: ")
+                tl, br = elem_list[int(user_input) - 1].bbox
+                x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2
+
+                action = EnvAction(action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=swipe_dir)
+                log_str = f"swipe({int(user_input)}:sep:{swipe_dir}):::{elem_list[int(user_input) - 1].uid}\n"
+            elif user_input.lower() == ActionOp.STOP.value:
+                record_file.write("stop\n")
+                record_file.close()
+                break
+            else:
+                break
+
+            obs, _, _, _, info = env.step(action)
+            action_res = info["res"]
+            if action_res == ADB_EXEC_FAIL:
+                return AndroidActionOutput(action_state=RunState.FAIL)
+            record_file.write(log_str)
+
+            time.sleep(1)
+
+        return AndroidActionOutput(action_state=RunState.SUCCESS)
diff --git a/metagpt/ext/android_assistant/actions/parse_record.py b/metagpt/ext/android_assistant/actions/parse_record.py
new file mode 100644
index 0000000000000000000000000000000000000000..304daf65563281d45af85864b6a252be7947a4b9
--- /dev/null
+++ b/metagpt/ext/android_assistant/actions/parse_record.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc   : parse record to generate learned standard operations in stage=learn & mode=manual,
+#           LIKE scripts/document_generation.py
+
+import ast
+import re
+from pathlib import Path
+
+from metagpt.actions.action import Action
+from metagpt.config2 import config
+from metagpt.ext.android_assistant.actions.parse_record_an import RECORD_PARSE_NODE
+from metagpt.ext.android_assistant.prompts.operation_prompt import (
+    long_press_doc_template,
+    refine_doc_suffix,
+    swipe_doc_template,
+    tap_doc_template,
+    text_doc_template,
+)
+from metagpt.ext.android_assistant.utils.schema import (
+    ActionOp,
+    AndroidActionOutput,
+    RecordLogItem,
+    RunState,
+    SwipeOp,
+)
+from metagpt.logs import logger
+from metagpt.utils.common import encode_image
+
+
+class ParseRecord(Action):
+    name: str = "ParseRecord"
+    record_path: Path = ""
+    task_desc_path: Path = ""
+    screenshot_before_path: Path = ""
+    screenshot_after_path: Path = ""
+
+    async def run(self, task_dir: Path, docs_dir: Path):
+        doc_count = 0
+        self.record_path = Path(task_dir) / "record.txt"
+        self.task_desc_path = Path(task_dir) / "task_desc.txt"
+        self.screenshot_before_path = Path(task_dir) / "raw_screenshots"
+        self.screenshot_after_path = Path(task_dir) / "labeled_screenshots"
+        for path in [self.screenshot_before_path, self.screenshot_after_path]:
+            path.mkdir(parents=True, exist_ok=True)
+
+        task_desc = self.task_desc_path.read_text()
+        extra_config = config.extra
+
+        with open(self.record_path, "r") as record_file:
+            record_step_count = len(record_file.readlines()) - 1
+            record_file.seek(0)
+            for step in range(1, record_step_count + 1):
+                img_before_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step}_labeled.png"))
+                img_after_base64 = encode_image(self.screenshot_after_path.joinpath(f"{step + 1}_labeled.png"))
+                rec = record_file.readline().strip()
+                action, resource_id = rec.split(":::")
+                action_type = action.split("(")[0]
+                # Build the prompt for this action type
+                action_param = re.findall(r"\((.*?)\)", action)[0]
+                if action_type == ActionOp.TAP.value:
+                    prompt_template = tap_doc_template
+                    context = prompt_template.format(ui_element=action_param)
+                elif action_type == ActionOp.TEXT.value:
+                    input_area, input_text = action_param.split(":sep:")
+                    prompt_template = text_doc_template
+                    context =
prompt_template.format(ui_element=input_area) + elif action_type == ActionOp.LONG_PRESS.value: + prompt_template = long_press_doc_template + context = prompt_template.format(ui_element=action_param) + elif action_type == ActionOp.SWIPE.value: + swipe_area, swipe_dir = action_param.split(":sep:") + if swipe_dir == SwipeOp.UP.value or swipe_dir == SwipeOp.DOWN.value: + action_type = ActionOp.VERTICAL_SWIPE.value + elif swipe_dir == SwipeOp.LEFT.value or swipe_dir == SwipeOp.RIGHT.value: + action_type = ActionOp.HORIZONTAL_SWIPE.value + prompt_template = swipe_doc_template + context = prompt_template.format(swipe_dir=swipe_dir, ui_element=swipe_area) + else: + break + context = context.format(task_desc=task_desc) + + doc_name = resource_id + ".txt" + doc_path = docs_dir.joinpath(doc_name) + + if doc_path.exists(): + try: + doc_content = ast.literal_eval(doc_path.read_text()) + except Exception as exp: + logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}") + continue + + if doc_content[action_type]: + if extra_config.get("doc_refine", False): + refine_context = refine_doc_suffix.format(old_doc=doc_content[action_type]) + context += refine_context + logger.info( + f"Documentation for the element {resource_id} already exists. The doc will be " + f"refined based on the latest demo." + ) + else: + logger.info( + f"Documentation for the element {resource_id} already exists. Turn on DOC_REFINE " + f"in the config file if needed." + ) + continue + else: + doc_content = {"tap": "", "text": "", "v_swipe": "", "h_swipe": "", "long_press": ""} + + logger.info(f"Waiting for GPT-4V to generate documentation for the element {resource_id}") + node = await RECORD_PARSE_NODE.fill( + context=context, llm=self.llm, images=[img_before_base64, img_after_base64] + ) + if "error" in node.content: + return AndroidActionOutput(action_state=RunState.FAIL) + log_path = task_dir.joinpath("log_parse_record.txt") + prompt = node.compile(context=context, schema="json", mode="auto") + msg = node.content + doc_content[action_type] = msg + + with open(log_path, "a") as logfile: + log_item = RecordLogItem( + step=step, + prompt=prompt, + image_before=img_before_base64, + image_after=img_after_base64, + response=node.content, + ) + logfile.write(log_item.model_dump_json() + "\n") + with open(doc_path, "w") as outfile: + outfile.write(str(doc_content)) + doc_count += 1 + logger.info(f"Documentation generated and saved to {doc_path}") + + logger.info(f"Documentation generation phase completed. {doc_count} docs generated.") + + return AndroidActionOutput(action_state=RunState.FINISH) diff --git a/metagpt/ext/android_assistant/actions/parse_record_an.py b/metagpt/ext/android_assistant/actions/parse_record_an.py new file mode 100644 index 0000000000000000000000000000000000000000..210c93e236db761163dd8b5788ae46f209d58c0c --- /dev/null +++ b/metagpt/ext/android_assistant/actions/parse_record_an.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the ActionNode to parse record + +from metagpt.actions.action_node import ActionNode + +OBSERVATION = ActionNode( + key="Observation", + expected_type=str, + instruction="Provide a description of your observations of the two images. 
" + "Subsequently, delineate the distinctions between the first image and the second one.", + example="", +) + +THOUGHT = ActionNode( + key="Thought", + expected_type=str, + instruction="Consider the impact of Action acting on UI elements.", + example="", +) + +DESCRIPTION = ActionNode( + key="Description", + expected_type=str, + instruction="Describe the functionality of the UI element concisely in one or two sentences Do not include " + "the numeric tag in your description", + example="", +) + +NODES = [OBSERVATION, THOUGHT, DESCRIPTION] + +RECORD_PARSE_NODE = ActionNode.from_children("RecordParse", NODES) diff --git a/metagpt/ext/android_assistant/actions/screenshot_parse.py b/metagpt/ext/android_assistant/actions/screenshot_parse.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8bb0e1eb993cd285dd9e807da81638baa1a3fb --- /dev/null +++ b/metagpt/ext/android_assistant/actions/screenshot_parse.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : LIKE scripts/task_executor.py in stage=act + +import ast +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android.android_env import AndroidEnv +from metagpt.environment.android.const import ADB_EXEC_FAIL +from metagpt.environment.android.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) +from metagpt.ext.android_assistant.actions.screenshot_parse_an import ( + SCREENSHOT_PARSE_NODE, +) +from metagpt.ext.android_assistant.prompts.assistant_prompt import ( + screenshot_parse_template, + screenshot_parse_with_grid_template, +) +from metagpt.ext.android_assistant.utils.schema import ( + AndroidActionOutput, + AndroidElement, + GridOpParam, + LongPressGridOpParam, + LongPressOpParam, + OpLogItem, + RunState, + SwipeGridOpParam, + SwipeOpParam, + TapGridOpParam, + TapOpParam, + TextOpParam, +) +from metagpt.ext.android_assistant.utils.utils import ( + area_to_xy, + draw_bbox_multi, + draw_grid, + elem_bbox_to_xy, + screenshot_parse_extract, + traverse_xml_tree, +) +from metagpt.logs import logger +from metagpt.utils.common import encode_image + + +class ScreenshotParse(Action): + name: str = "ScreenshotParse" + + def _makeup_ui_document(self, elem_list: list[AndroidElement], docs_idr: Path, use_exist_doc: bool = True) -> str: + if not use_exist_doc: + return "" + + ui_doc = """ +You also have access to the following documentations that describes the functionalities of UI +elements you can interact on the screen. These docs are crucial for you to determine the target of your +next action. You should always prioritize these documented elements for interaction: """ + for i, elem in enumerate(elem_list): + doc_path = docs_idr.joinpath(f"{elem.uid}.txt") + if not doc_path.exists(): + continue + try: + doc_content = ast.literal_eval(doc_path.read_text()) + except Exception as exp: + logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}") + continue + + ui_doc += f"Documentation of UI element labeled with the numeric tag '{i + 1}':\n" + if doc_content["tap"]: + ui_doc += f"This UI element is clickable. {doc_content['tap']}\n\n" + if doc_content["text"]: + ui_doc += ( + f"This UI element can receive text input. The text input is used for the following " + f"purposes: {doc_content['text']}\n\n" + ) + if doc_content["long_press"]: + ui_doc += f"This UI element is long clickable. 
{doc_content['long_press']}\n\n" + if doc_content["v_swipe"]: + ui_doc += ( + f"This element can be swiped directly without tapping. You can swipe vertically on " + f"this UI element. {doc_content['v_swipe']}\n\n" + ) + if doc_content["h_swipe"]: + ui_doc += ( + f"This element can be swiped directly without tapping. You can swipe horizontally on " + f"this UI element. {doc_content['h_swipe']}\n\n" + ) + return ui_doc + + async def run( + self, + round_count: int, + task_desc: str, + last_act: str, + task_dir: Path, + docs_dir: Path, + grid_on: bool, + env: AndroidEnv, + ): + extra_config = config.extra + for path in [task_dir, docs_dir]: + path.mkdir(parents=True, exist_ok=True) + screenshot_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir) + ) + xml_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{round_count}", local_save_dir=task_dir) + ) + if not screenshot_path.exists() or not xml_path.exists(): + return AndroidActionOutput(action_state=RunState.FAIL) + + clickable_list = [] + focusable_list = [] + traverse_xml_tree(xml_path, clickable_list, "clickable", True) + traverse_xml_tree(xml_path, focusable_list, "focusable", True) + elem_list: list[AndroidElement] = clickable_list.copy() + for elem in focusable_list: + bbox = elem.bbox + center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + close = False + for e in clickable_list: + bbox = e.bbox + center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 + if dist <= extra_config.get("min_dist", 30): + close = True + break + if not close: + elem_list.append(elem) + + screenshot_labeled_path = task_dir.joinpath(f"{round_count}_labeled.png") + draw_bbox_multi(screenshot_path, screenshot_labeled_path, elem_list) + img_base64 = encode_image(screenshot_labeled_path) + + parse_template = screenshot_parse_with_grid_template if grid_on else screenshot_parse_template + + if grid_on: + env.rows, env.cols = draw_grid(screenshot_path, task_dir / f"{round_count}_grid.png") + + ui_doc = self._makeup_ui_document(elem_list, docs_dir) + context = parse_template.format(ui_document=ui_doc, task_description=task_desc, last_act=last_act) + node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64]) + + if "error" in node.content: + return AndroidActionOutput(action_state=RunState.FAIL) + + prompt = node.compile(context=context, schema="json", mode="auto") + OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_labeled_path), response=node.content) + + op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on) + if op_param.param_state == RunState.FINISH: + logger.info(f"op_param: {op_param}") + return AndroidActionOutput(action_state=RunState.FINISH) + if op_param.param_state == RunState.FAIL: + return AndroidActionOutput(action_state=RunState.FAIL) + + last_act = op_param.last_act + if isinstance(op_param, TapOpParam): + x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) + action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y)) + elif isinstance(op_param, TextOpParam): + action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=op_param.input_str) + elif isinstance(op_param, LongPressOpParam): + x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) + action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y)) + 
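# e.g. swipe(21, "up", "medium") parses to SwipeOpParam(area=21, swipe_orient="up", dist="medium"), + # which becomes EnvAction(USER_SWIPE, coord=<center of element 21's bbox>, orient="up", dist="medium") + 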
elif isinstance(op_param, SwipeOpParam): + x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) + action = EnvAction( + action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=op_param.swipe_orient, dist=op_param.dist + ) + elif isinstance(op_param, GridOpParam): + grid_on = True + elif isinstance(op_param, TapGridOpParam) or isinstance(op_param, LongPressGridOpParam): + x, y = area_to_xy(op_param.area, op_param.subarea, env.width, env.height, env.rows, env.cols) + if isinstance(op_param, TapGridOpParam): + action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y)) + else: + # LongPressGridOpParam + action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y)) + elif isinstance(op_param, SwipeGridOpParam): + start_x, start_y = area_to_xy( + op_param.start_area, op_param.start_subarea, env.width, env.height, env.rows, env.cols + ) + end_x, end_y = area_to_xy( + op_param.end_area, op_param.end_subarea, env.width, env.height, env.rows, env.cols + ) + action = EnvAction( + action_type=EnvActionType.USER_SWIPE_TO, coord=(start_x, start_y), tgt_coord=(end_x, end_y) + ) + + if not grid_on: + obs, _, _, _, info = env.step(action) + action_res = info["res"] + if action_res == ADB_EXEC_FAIL: + return AndroidActionOutput(action_state=RunState.FAIL) + + if op_param.act_name != "grid": + grid_on = False + + return AndroidActionOutput(data={"grid_on": grid_on, "last_act": last_act}) diff --git a/metagpt/ext/android_assistant/actions/screenshot_parse_an.py b/metagpt/ext/android_assistant/actions/screenshot_parse_an.py new file mode 100644 index 0000000000000000000000000000000000000000..eb23ba93445c112d0e469b5628ed3e752f03fd3f --- /dev/null +++ b/metagpt/ext/android_assistant/actions/screenshot_parse_an.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the ActionNode to parse screenshot + +from metagpt.actions.action_node import ActionNode + +OBSERVATION = ActionNode( + key="Observation", expected_type=str, instruction="Describe what you observe in the image", example="" +) + +THOUGHT = ActionNode( + key="Thought", + expected_type=str, + instruction="To complete the given task, what is the next step I should do", + example="", +) + +ACTION = ActionNode( + key="Action", + expected_type=str, + instruction="The function call with the correct parameters to proceed with the task. If you believe the task is " + "completed or there is nothing to be done, you should output FINISH. You cannot output anything else " + "except a function call or FINISH in this field.", + example="", +) + +SUMMARY = ActionNode( + key="Summary", + expected_type=str, + instruction="Summarize your past actions along with your latest action in one or two sentences. Do not include " + "the numeric tag in your summary", + example="", +) + +SUMMARY_GRID = ActionNode( + key="Summary", + expected_type=str, + instruction="Summarize your past actions along with your latest action in one or two sentences. 
Do not include " + "the grid area number in your summary", + example="", +) + +NODES = [OBSERVATION, THOUGHT, ACTION, SUMMARY] + +NODES_GRID = [OBSERVATION, THOUGHT, ACTION, SUMMARY_GRID] + +SCREENSHOT_PARSE_NODE = ActionNode.from_children("ScreenshotParse", NODES) +SCREENSHOT_PARSE_GRID_NODE = ActionNode.from_children("ScreenshotParseGrid", NODES_GRID) diff --git a/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py b/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py new file mode 100644 index 0000000000000000000000000000000000000000..5e9cfbb4547599d388598ac53d0c187542645197 --- /dev/null +++ b/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py @@ -0,0 +1,231 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : LIKE scripts/self_explorer.py in stage=learn & mode=auto self_explore_task stage + +import ast +from pathlib import Path + +from metagpt.actions.action import Action +from metagpt.config2 import config +from metagpt.environment.android.android_env import AndroidEnv +from metagpt.environment.android.const import ADB_EXEC_FAIL +from metagpt.environment.android.env_space import ( + EnvAction, + EnvActionType, + EnvObsParams, + EnvObsType, +) +from metagpt.ext.android_assistant.actions.screenshot_parse_an import ( + SCREENSHOT_PARSE_NODE, +) +from metagpt.ext.android_assistant.actions.self_learn_reflect_an import ( + SELF_LEARN_REFLECT_NODE, +) +from metagpt.ext.android_assistant.prompts.assistant_prompt import ( + screenshot_parse_self_explore_reflect_template as reflect_template, +) +from metagpt.ext.android_assistant.prompts.assistant_prompt import ( + screenshot_parse_self_explore_template, +) +from metagpt.ext.android_assistant.utils.schema import ( + ActionOp, + AndroidActionOutput, + AndroidElement, + Decision, + DocContent, + LongPressOpParam, + OpLogItem, + ReflectLogItem, + RunState, + SwipeOp, + SwipeOpParam, + TapOpParam, + TextOpParam, +) +from metagpt.ext.android_assistant.utils.utils import ( + draw_bbox_multi, + elem_bbox_to_xy, + elem_list_from_xml_tree, + reflect_parse_extarct, + screenshot_parse_extract, +) +from metagpt.logs import logger +from metagpt.utils.common import encode_image + + +class SelfLearnAndReflect(Action): + name: str = "SelfLearnAndReflect" + + useless_list: list[str] = [] # store useless elements uid + + screenshot_before_path: str = "" + screenshot_before_base64: str = "" + elem_list: list[AndroidElement] = [] + swipe_orient: str = "up" + act_name: str = "" + ui_area: int = -1 + + async def run( + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + ) -> AndroidActionOutput: + for path in [task_dir, docs_dir]: + path.mkdir(parents=True, exist_ok=True) + resp = await self.run_self_learn(round_count, task_desc, last_act, task_dir, env) + if resp.action_state != RunState.SUCCESS: + return resp + + resp = await self.run_reflect(round_count, task_desc, last_act, task_dir, docs_dir, env) + return resp + + async def run_self_learn( + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv + ) -> AndroidActionOutput: + extra_config = config.extra + screenshot_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir) + ) + xml_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_XML, xml_name=f"{round_count}", local_save_dir=task_dir) + ) + if not screenshot_path.exists() or not xml_path.exists(): + return 
AndroidActionOutput(action_state=RunState.FAIL) + + elem_list = elem_list_from_xml_tree(xml_path, self.useless_list, extra_config.get("min_dist", 30)) + + screenshot_before_labeled_path = task_dir.joinpath(f"{round_count}_before_labeled.png") + draw_bbox_multi(screenshot_path, screenshot_before_labeled_path, elem_list) + img_base64 = encode_image(screenshot_before_labeled_path) + self.screenshot_before_base64 = img_base64 + self.screenshot_before_path = screenshot_before_labeled_path + + self_explore_template = screenshot_parse_self_explore_template + context = self_explore_template.format(task_description=task_desc, last_act=last_act) + + node = await SCREENSHOT_PARSE_NODE.fill(context=context, llm=self.llm, images=[img_base64]) + logger.debug(f"fill result: {node}") + if "error" in node.content: + return AndroidActionOutput(action_state=RunState.FAIL) + prompt = node.compile(context=context, schema="json", mode="auto") + # cast WindowsPath to str for the log item; note the OpLogItem is built here but not yet persisted + OpLogItem(step=round_count, prompt=prompt, image=str(screenshot_before_labeled_path), response=node.content) + op_param = screenshot_parse_extract(node.instruct_content.model_dump(), grid_on=False) + # TODO: rework op_param handling and decide what to do when the parsed action is FINISH + if op_param.param_state == RunState.FINISH: + return AndroidActionOutput(action_state=RunState.FINISH) + if op_param.param_state == RunState.FAIL: + return AndroidActionOutput(action_state=RunState.FAIL) + + if isinstance(op_param, TapOpParam): + self.ui_area = op_param.area + x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) + action = EnvAction(action_type=EnvActionType.SYSTEM_TAP, coord=(x, y)) + elif isinstance(op_param, TextOpParam): + action = EnvAction(action_type=EnvActionType.USER_INPUT, input_txt=op_param.input_str) + elif isinstance(op_param, LongPressOpParam): + self.ui_area = op_param.area + x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) + action = EnvAction(action_type=EnvActionType.USER_LONGPRESS, coord=(x, y)) + elif isinstance(op_param, SwipeOpParam): + self.ui_area = op_param.area + self.swipe_orient = op_param.swipe_orient + x, y = elem_bbox_to_xy(elem_list[op_param.area - 1].bbox) + action = EnvAction( + action_type=EnvActionType.USER_SWIPE, coord=(x, y), orient=op_param.swipe_orient, dist=op_param.dist + ) + + obs, _, _, _, info = env.step(action) + action_res = info["res"] + if action_res == ADB_EXEC_FAIL: + return AndroidActionOutput(action_state=RunState.FAIL) + + self.elem_list = elem_list + self.act_name = op_param.act_name + return AndroidActionOutput() + + async def run_reflect( + self, round_count: int, task_desc: str, last_act: str, task_dir: Path, docs_dir: Path, env: AndroidEnv + ) -> AndroidActionOutput: + screenshot_path: Path = env.observe( + EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_after", local_save_dir=task_dir) + ) + if not screenshot_path.exists(): + return AndroidActionOutput(action_state=RunState.FAIL) + + screenshot_after_labeled_path = task_dir.joinpath(f"{round_count}_after_labeled.png") + draw_bbox_multi(screenshot_path, screenshot_after_labeled_path, elem_list=self.elem_list) + img_base64 = encode_image(screenshot_after_labeled_path) + if self.act_name == ActionOp.TAP.value: + action = "tapping" + elif self.act_name == ActionOp.LONG_PRESS.value: + action = "long pressing" + elif self.act_name == ActionOp.SWIPE.value: + action = "swiping" + if self.swipe_orient == SwipeOp.UP.value or self.swipe_orient == SwipeOp.DOWN.value: + action = "v_swipe" + elif self.swipe_orient == 
SwipeOp.LEFT.value or self.swipe_orient == SwipeOp.RIGHT.value: + action = "h_swipe" + else: + # TODO: verify this fallback assignment; this failure is coupled with the IndexError noted below + logger.warning(f"Current action name parse failed, it's `{self.act_name}`") + action = None + context = reflect_template.format( + action=action, ui_element=str(self.ui_area), task_desc=task_desc, last_act=last_act + ) + node = await SELF_LEARN_REFLECT_NODE.fill( + context=context, llm=self.llm, images=[self.screenshot_before_base64, img_base64] + ) + + if "error" in node.content: + return AndroidActionOutput(action_state=RunState.FAIL) + + prompt = node.compile(context=context, schema="json", mode="auto") + ReflectLogItem( + step=round_count, + prompt=prompt, + image_before=str(self.screenshot_before_path), + image_after=str(screenshot_after_labeled_path), + response=node.content, + ) + + op_param = reflect_parse_extract(node.instruct_content.model_dump()) + if op_param.param_state == RunState.FINISH: + return AndroidActionOutput(action_state=RunState.FINISH) + if op_param.param_state == RunState.FAIL: + return AndroidActionOutput(action_state=RunState.FAIL) + + logger.info( + f"reflect_parse_extract decision: {op_param.decision}, " + f"elem_list size: {len(self.elem_list)}, ui_area: {self.ui_area}" + ) + # TODO: this indexing can raise `IndexError: list index out of range`; + # you may need to click back to the desktop in the simulator + resource_id = self.elem_list[int(self.ui_area) - 1].uid + if op_param.decision == Decision.INEFFECTIVE.value: + self.useless_list.append(resource_id) + last_act = "NONE"  # TODO: make this global + elif op_param.decision in [Decision.BACK.value, Decision.CONTINUE.value, Decision.SUCCESS.value]: + if op_param.decision in [Decision.BACK.value, Decision.CONTINUE.value]: + self.useless_list.append(resource_id) + last_act = "NONE" + if op_param.decision == Decision.BACK.value: + action = EnvAction(action_type=EnvActionType.SYSTEM_BACK) + obs, _, _, _, info = env.step(action) + if info["res"] == ADB_EXEC_FAIL: + return AndroidActionOutput(action_state=RunState.FAIL) + doc = op_param.documentation + doc_path = docs_dir.joinpath(f"{resource_id}.txt") + if doc_path.exists(): + try: + doc_content = ast.literal_eval(doc_path.read_text()) + except Exception as exp: + logger.error(f"ast parse doc: {doc_path} failed, exp: {exp}") + return AndroidActionOutput(action_state=RunState.FAIL) + + if doc_content[self.act_name]: + logger.info(f"Documentation for the element {resource_id} already exists.") + return AndroidActionOutput(action_state=RunState.FAIL) + else: + doc_content = DocContent() + setattr(doc_content, self.act_name, doc) + doc_path.write_text(str(doc_content)) + return AndroidActionOutput(data={"last_act": last_act}) diff --git a/metagpt/ext/android_assistant/actions/self_learn_reflect_an.py b/metagpt/ext/android_assistant/actions/self_learn_reflect_an.py new file mode 100644 index 0000000000000000000000000000000000000000..305b7376af469fd3d03bbf04907a3d486fbd173b --- /dev/null +++ b/metagpt/ext/android_assistant/actions/self_learn_reflect_an.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the ActionNode to parse Reflection + +from metagpt.actions.action_node import ActionNode + +DECISION = ActionNode( + key="Decision", expected_type=str, instruction="the decision you made, one of BACK, INEFFECTIVE, CONTINUE, SUCCESS", example="BACK" +) + + +THOUGHT = ActionNode(key="Thought", expected_type=str, instruction="explain why you made this decision", example="") + + +DOCUMENTATION = ActionNode( + key="Documentation", 
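# the doc text that SelfLearnAndReflect.run_reflect persists into the per-element doc file under docs_dir + 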
expected_type=str, instruction="describe the function of the UI element", example="" +) + + +NODES = [DECISION, THOUGHT, DOCUMENTATION] +SELF_LEARN_REFLECT_NODE = ActionNode.from_children("SelfLearnReflect", NODES) diff --git a/metagpt/ext/android_assistant/prompts/__init__.py b/metagpt/ext/android_assistant/prompts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/android_assistant/prompts/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/android_assistant/prompts/assistant_prompt.py b/metagpt/ext/android_assistant/prompts/assistant_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..34baf58417ca1e0206d2d5730ce29fe11e78530a --- /dev/null +++ b/metagpt/ext/android_assistant/prompts/assistant_prompt.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the prompt templates of assistant learning and acting + +screenshot_parse_template = """You are an agent that is trained to perform some basic tasks on a smartphone. You will be given a +smartphone screenshot. The interactive UI elements on the screenshot are labeled with numeric tags starting from 1. The +numeric tag of each interactive element is located in the center of the element. + +You can call the following functions to control the smartphone: + +1. tap(element: int) +This function is used to tap an UI element shown on the smartphone screen. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be tap(5), which taps the UI element labeled with the number 5. + +2. text(text_input: str) +This function is used to insert text input in an input field/box. text_input is the string you want to insert and must +be wrapped with double quotation marks. A simple use case can be text("Hello, world!"), which inserts the string +"Hello, world!" into the input area on the smartphone screen. This function is usually callable when you see a keyboard +showing in the lower half of the screen. + +3. long_press(element: int) +This function is used to long press an UI element shown on the smartphone screen. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be long_press(5), which long presses the UI element labeled with the number 5. + +4. swipe(element: int, direction: str, dist: str) +This function is used to swipe an UI element shown on the smartphone screen, usually a scroll view or a slide bar. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. "direction" is a string that +represents one of the four directions: up, down, left, right. "direction" must be wrapped with double quotation +marks. "dist" determines the distance of the swipe and can be one of the three options: short, medium, long. You should +choose the appropriate distance option according to your need. +A simple use case can be swipe(21, "up", "medium"), which swipes up the UI element labeled with the number 21 for a +medium distance. + +5. grid() +You should call this function when you find the element you want to interact with is not labeled with a numeric tag and +other elements with numeric tags cannot help with the task. The function will bring up a grid overlay to divide the +smartphone screen into small areas and this will give you more freedom to choose any part of the screen to tap, long +press, or swipe. 
+{ui_document} +The task you need to complete is to: {task_description}. Your past actions to proceed with this task are summarized as +follows: {last_act} +Now, given the documentation and the following labeled screenshot, you need to think and call the function needed to +proceed with the task. Your output should include three parts in the given format: + +You can only take one action at a time, so please directly call the function.""" + +screenshot_parse_with_grid_template = """You are an agent that is trained to perform some basic tasks on a smartphone. You will be given +a smartphone screenshot overlaid by a grid. The grid divides the screenshot into small square areas. Each area is +labeled with an integer in the top-left corner. + +You can call the following functions to control the smartphone: + +1. tap(area: int, subarea: str) +This function is used to tap a grid area shown on the smartphone screen. "area" is the integer label assigned to a grid +area shown on the smartphone screen. "subarea" is a string representing the exact location to tap within the grid area. +It can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, bottom, and +bottom-right. +A simple use case can be tap(5, "center"), which taps the exact center of the grid area labeled with the number 5. + +2. long_press(area: int, subarea: str) +This function is used to long press a grid area shown on the smartphone screen. "area" is the integer label assigned to +a grid area shown on the smartphone screen. "subarea" is a string representing the exact location to long press within +the grid area. It can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, bottom, +and bottom-right. +A simple use case can be long_press(7, "top-left"), which long presses the top left part of the grid area labeled with +the number 7. + +3. swipe(start_area: int, start_subarea: str, end_area: int, end_subarea: str) +This function is used to perform a swipe action on the smartphone screen, especially when you want to interact with a +scroll view or a slide bar. "start_area" is the integer label assigned to the grid area which marks the starting +location of the swipe. "start_subarea" is a string representing the exact location to begin the swipe within the grid +area. "end_area" is the integer label assigned to the grid area which marks the ending location of the swipe. +"end_subarea" is a string representing the exact location to end the swipe within the grid area. +The two subarea parameters can take one of the nine values: center, top-left, top, top-right, left, right, bottom-left, +bottom, and bottom-right. +A simple use case can be swipe(21, "center", 25, "right"), which performs a swipe starting from the center of grid area +21 to the right part of grid area 25. + +The task you need to complete is to: {task_description}. Your past actions to proceed with this task are summarized as +follows: {last_act} +Now, given the following labeled screenshot, you need to think and call the function needed to proceed with the task. +Your output should include three parts in the given format: + +You can only take one action at a time, so please directly call the function.""" + +screenshot_parse_self_explore_template = """You are an agent that is trained to complete certain tasks on a smartphone. You will be +given a screenshot of a smartphone app. The interactive UI elements on the screenshot are labeled with numeric tags +starting from 1. 
+ +You can call the following functions to interact with those labeled elements to control the smartphone: + +1. tap(element: int) +This function is used to tap an UI element shown on the smartphone screen. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be tap(5), which taps the UI element labeled with the number 5. + +2. text(text_input: str) +This function is used to insert text input in an input field/box. text_input is the string you want to insert and must +be wrapped with double quotation marks. A simple use case can be text("Hello, world!"), which inserts the string +"Hello, world!" into the input area on the smartphone screen. This function is only callable when you see a keyboard +showing in the lower half of the screen. + +3. long_press(element: int) +This function is used to long press an UI element shown on the smartphone screen. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. +A simple use case can be long_press(5), which long presses the UI element labeled with the number 5. + +4. swipe(element: int, direction: str, dist: str) +This function is used to swipe an UI element shown on the smartphone screen, usually a scroll view or a slide bar. +"element" is a numeric tag assigned to an UI element shown on the smartphone screen. "direction" is a string that +represents one of the four directions: up, down, left, right. "direction" must be wrapped with double quotation +marks. "dist" determines the distance of the swipe and can be one of the three options: short, medium, long. You should +choose the appropriate distance option according to your need. +A simple use case can be swipe(21, "up", "medium"), which swipes up the UI element labeled with the number 21 for a +medium distance. + +The task you need to complete is to {task_description}. Your past actions to proceed with this task are summarized as +follows: {last_act} +Now, given the following labeled screenshot, you need to think and call the function needed to proceed with the task. +Your output should include three parts in the given format: + +You can only take one action at a time, so please directly call the function.""" + +screenshot_parse_self_explore_reflect_template = """I will give you screenshots of a mobile app before and after {action} the UI +element labeled with the number '{ui_element}' on the first screenshot. The numeric tag of each element is located at +the center of the element. The action of {action} this UI element was described as follows: +{last_act} +The action was also an attempt to proceed with a larger task, which is to {task_desc}. Your job is to carefully analyze +the difference between the two screenshots to determine if the action is in accord with the description above and at +the same time effectively moved the task forward. Your output should be determined based on the following situations: +1. BACK +If you think the action navigated you to a page where you cannot proceed with the given task, you should go back to the +previous interface. At the same time, describe the functionality of the UI element concisely in one or two sentences by +observing the difference between the two screenshots. Notice that your description of the UI element should focus on +the general function. Never include the numeric tag of the UI element in your description. You can use pronouns such as +"the UI element" to refer to the element. 
Your output should be in the following format: +Decision: BACK +Thought: +Documentation: +2. INEFFECTIVE +If you find the action changed nothing on the screen (screenshots before and after the action are identical), you +should continue to interact with other elements on the screen. Notice that if you find the location of the cursor +changed between the two screenshots, then they are not identical. Your output should be in the following format: +Decision: INEFFECTIVE +Thought: +Documentation: +3. CONTINUE +If you find the action changed something on the screen but did not reflect the action description above and did not +move the given task forward, you should continue to interact with other elements on the screen. At the same time, +describe the functionality of the UI element concisely in one or two sentences by observing the difference between the +two screenshots. Notice that your description of the UI element should focus on the general function. Never include the +numeric tag of the UI element in your description. You can use pronouns such as "the UI element" to refer to the +element. Your output should be in the following format: +Decision: CONTINUE +Thought: +Documentation: +4. SUCCESS +If you think the action successfully moved the task forward (even though it did not complete the task), you should +describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the UI +element should focus on the general function. Never include the numeric tag of the UI element in your description. You +can use pronouns such as "the UI element" to refer to the element. Your output should be in the following format: +Decision: SUCCESS +Thought: +Documentation: +""" diff --git a/metagpt/ext/android_assistant/prompts/operation_prompt.py b/metagpt/ext/android_assistant/prompts/operation_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..1bde53f04197b50e75d6caf3ce1847402b4a3a9d --- /dev/null +++ b/metagpt/ext/android_assistant/prompts/operation_prompt.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the prompt templates of phone operation + +tap_doc_template = """I will give you the screenshot of a mobile app before and after tapping the UI element labeled +with the number {ui_element} on the screen. The numeric tag of each element is located at the center of the element. +Tapping this UI element is a necessary part of proceeding with a larger task, which is to {{task_desc}}. Your task is to +describe the functionality of the UI element concisely in one or two sentences. Notice that your description of the UI +element should focus on the general function. For example, if the UI element is used to navigate to the chat window +with John, your description should not include the name of the specific person. Just say: "Tapping this area will +navigate the user to the chat window". Never include the numeric tag of the UI element in your description. You can use +pronouns such as "the UI element" to refer to the element.""" + +text_doc_template = """I will give you the screenshot of a mobile app before and after typing in the input area labeled +with the number {ui_element} on the screen. The numeric tag of each element is located at the center of the element. +Typing in this UI element is a necessary part of proceeding with a larger task, which is to {{task_desc}}. Your task is +to describe the functionality of the UI element concisely in one or two sentences. 
Notice that your description of the +UI element should focus on the general function. For example, if the change of the screenshot shows that the user typed +"How are you?" in the chat box, you do not need to mention the actual text. Just say: "This input area is used for the +user to type a message to send to the chat window." Never include the numeric tag of the UI element in your +description. You can use pronouns such as "the UI element" to refer to the element.""" + +long_press_doc_template = """I will give you the screenshot of a mobile app before and after long pressing the UI +element labeled with the number {ui_element} on the screen. The numeric tag of each element is located at the center of +the element. Long pressing this UI element is a necessary part of proceeding with a larger task, which is to +{{task_desc}}. Your task is to describe the functionality of the UI element concisely in one or two sentences. Notice +that your description of the UI element should focus on the general function. For example, if long pressing the UI +element redirects the user to the chat window with John, your description should not include the name of the specific +person. Just say: "Long pressing this area will redirect the user to the chat window". Never include the numeric tag of +the UI element in your description. You can use pronouns such as "the UI element" to refer to the element.""" + +swipe_doc_template = """I will give you the screenshot of a mobile app before and after swiping the UI +element labeled with the number {ui_element} on the screen. The numeric tag of each element is located at the center of +the element. Swiping this UI element is a necessary part of proceeding with a larger task, which is to {{task_desc}}. +Your task is to describe the functionality of the UI element concisely in one or two sentences. Notice that your +description of the UI element should be as general as possible. For example, if swiping the UI element increases the +contrast ratio of an image of a building, your description should be just like this: "Swiping this area enables the +user to tune a specific parameter of the image". Never include the numeric tag of the UI element in your description. +You can use pronouns such as "the UI element" to refer to the element.""" + +refine_doc_suffix = """\nA documentation of this UI element generated from previous demos is shown below. Your +generated description should be based on this previous doc and optimize it. Notice that it is possible that your +understanding of the function of the UI element derived from the given screenshots conflicts with the previous doc, +because the function of a UI element can be flexible. In this case, your generated description should combine both. 
+Old documentation of this UI element: {old_doc}""" diff --git a/metagpt/ext/android_assistant/roles/__init__.py b/metagpt/ext/android_assistant/roles/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/android_assistant/roles/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/android_assistant/roles/android_assistant.py b/metagpt/ext/android_assistant/roles/android_assistant.py new file mode 100644 index 0000000000000000000000000000000000000000..97d66d30e41ca1dec635f59fff2ed2ab9094b3bd --- /dev/null +++ b/metagpt/ext/android_assistant/roles/android_assistant.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : android assistant to learn from app operations and operate apps +import time +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import Field + +from metagpt.actions.add_requirement import UserRequirement +from metagpt.config2 import config +from metagpt.const import EXAMPLE_PATH +from metagpt.ext.android_assistant.actions.manual_record import ManualRecord +from metagpt.ext.android_assistant.actions.parse_record import ParseRecord +from metagpt.ext.android_assistant.actions.screenshot_parse import ScreenshotParse +from metagpt.ext.android_assistant.actions.self_learn_and_reflect import ( + SelfLearnAndReflect, +) +from metagpt.ext.android_assistant.utils.schema import AndroidActionOutput, RunState +from metagpt.logs import logger +from metagpt.roles.role import Role, RoleReactMode +from metagpt.schema import Message + + +class AndroidAssistant(Role): + name: str = "Nick" + profile: str = "AndroidAssistant" + goal: str = "operate the mobile phone's apps with self-learning" + + task_desc: str = "" + round_count: int = 0 + last_act: str = "None" + output_root_dir: Optional[Path] = Field(default=None) + task_dir: Optional[Path] = Field(default=None) + docs_dir: Optional[Path] = Field(default=None) + grid_on: bool = Field(default=False) + + def __init__(self, **data): + super().__init__(**data) + self._watch([UserRequirement, AndroidActionOutput]) + extra_config = config.extra + self.task_desc = extra_config.get("task_desc", "Just explore any app in this phone!") + app_name = extra_config.get("app_name", "demo") + data_dir = ( + self.output_root_dir.absolute().joinpath("output") + if self.output_root_dir + else EXAMPLE_PATH.joinpath("android_assistant/output") + ) + cur_datetime = datetime.fromtimestamp(int(time.time())).strftime("%Y-%m-%d_%H-%M-%S") + + """For now, the stage is decided by the user config; later it could be decided automatically, e.g. for a + new app, run the learn stage first and then the act stage, or learn during acting. + """ + stage = extra_config.get("stage") + mode = extra_config.get("mode") + if stage == "learn" and mode == "manual": + # choose ManualRecord and then run ParseRecord + # Remember: run each action only once; there is no need to run n_round. 
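+ # The extra-config keys read above, sketched for reference (values are illustrative): + # extra: + #   stage: "learn"     # "learn" or "act" + #   mode: "manual"     # "manual" or "auto" + #   app_name: "demo" + #   task_desc: "Just explore any app in this phone!"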
+ self.set_actions([ManualRecord, ParseRecord]) + self.task_dir = data_dir.joinpath(app_name, f"manual_learn_{cur_datetime}") + self.docs_dir = data_dir.joinpath(app_name, "manual_docs") + elif stage == "learn" and mode == "auto": + # choose SelfLearnAndReflect to run + self.set_actions([SelfLearnAndReflect]) + self.task_dir = data_dir.joinpath(app_name, f"auto_learn_{cur_datetime}") + self.docs_dir = data_dir.joinpath(app_name, "auto_docs") + elif stage == "act": + # choose ScreenshotParse to run + self.set_actions([ScreenshotParse]) + self.task_dir = data_dir.joinpath(app_name, f"act_{cur_datetime}") + if mode == "manual": + self.docs_dir = data_dir.joinpath(app_name, "manual_docs") + else: + self.docs_dir = data_dir.joinpath(app_name, "auto_docs") + else: + raise ValueError(f"invalid stage: {stage}, mode: {mode}") + + self._check_dir() + + self._set_react_mode(RoleReactMode.BY_ORDER) + + def _check_dir(self): + self.task_dir.mkdir(parents=True, exist_ok=True) + self.docs_dir.mkdir(parents=True, exist_ok=True) + + async def react(self) -> Message: + self.round_count += 1 + result = await super().react() + logger.debug(f"react result: {result}") + return result + + async def _observe(self, ignore_memory=True) -> int: + """ignore old memory so the role can run multiple rounds""" + newest_msgs = self.rc.memory.get(k=1) + newest_msg = newest_msgs[0] if newest_msgs else None + if newest_msg and (RunState.SUCCESS.value.upper() not in newest_msg.content): + ignore_memory = False + state_val = newest_msg.content.split(".")[-1]  # RoundCount: 1, action_state: RunState.SUCCESS + logger.warning(f"Latest action_state is {state_val}, will run the remaining rounds without `react`") + return await super()._observe(ignore_memory) + + async def _act(self) -> Message: + logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})") + todo = self.rc.todo + if isinstance(todo, ManualRecord): + resp = await todo.run(task_dir=self.task_dir, task_desc=self.task_desc, env=self.rc.env) + elif isinstance(todo, ParseRecord): + resp = await todo.run( + task_dir=self.task_dir, + docs_dir=self.docs_dir, + ) + elif isinstance(todo, SelfLearnAndReflect): + resp = await todo.run( + round_count=self.round_count, + task_desc=self.task_desc, + last_act=self.last_act, + task_dir=self.task_dir, + docs_dir=self.docs_dir, + env=self.rc.env, + ) + if resp.action_state == RunState.SUCCESS: + self.last_act = resp.data.get("last_act") + elif isinstance(todo, ScreenshotParse): + resp = await todo.run( + round_count=self.round_count, + task_desc=self.task_desc, + last_act=self.last_act, + task_dir=self.task_dir, + docs_dir=self.docs_dir, + grid_on=self.grid_on, + env=self.rc.env, + ) + if resp.action_state == RunState.SUCCESS: + logger.info(f"grid_on: {resp.data.get('grid_on')}") + self.grid_on = resp.data.get("grid_on", False) + self.last_act = resp.data.get("last_act", "None") + msg = Message( + content=f"RoundCount: {self.round_count}, action_state: {resp.action_state}", + role=self.profile, + cause_by=type(resp), + send_from=self.name, + send_to=self.name, + ) + + self.rc.memory.add(msg) + return msg diff --git a/metagpt/ext/android_assistant/utils/__init__.py b/metagpt/ext/android_assistant/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/android_assistant/utils/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git 
a/metagpt/ext/android_assistant/utils/schema.py b/metagpt/ext/android_assistant/utils/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c066f98b626acc02c6dedfc553edf8f249db4524 --- /dev/null +++ b/metagpt/ext/android_assistant/utils/schema.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from enum import Enum + +from pydantic import BaseModel, Field, field_validator + + +class ActionOp(Enum): + TAP = "tap" + LONG_PRESS = "long_press" + TEXT = "text" + SWIPE = "swipe" + VERTICAL_SWIPE = "v_swipe" + HORIZONTAL_SWIPE = "h_swipe" + GRID = "grid" + STOP = "stop" + + +class SwipeOp(Enum): + UP = "up" + DOWN = "down" + LEFT = "left" + RIGHT = "right" + + +class Decision(Enum): + BACK = "BACK" + INEFFECTIVE = "INEFFECTIVE" + CONTINUE = "CONTINUE" + SUCCESS = "SUCCESS" + + @classmethod + def values(cls): + return [item.value for item in cls] + + +class AndroidElement(BaseModel): + """UI Element""" + + uid: str = Field(default="") + bbox: tuple[tuple[int, int], tuple[int, int]] = Field(default=((0, 0), (0, 0))) + attrib: str = Field(default="") + + +class OpLogItem(BaseModel): + """log content for self-learn or task act""" + + step: int = Field(default=0) + prompt: str = Field(default="") + image: str = Field(default="") + response: str = Field(default="") + + +class ReflectLogItem(BaseModel): + """log content for self-learn-reflect""" + + step: int = Field(default=0) + prompt: str = Field(default="") + image_before: str = Field(default="") + image_after: str = Field(default="") + response: str = Field(default="") + + +class RecordLogItem(BaseModel): + """log content for record parse, same as ReflectLogItem""" + + step: int = Field(default=0) + prompt: str = Field(default="") + image_before: str = Field(default="") + image_after: str = Field(default="") + response: str = Field(default="") + + +class DocContent(BaseModel): + tap: str = Field(default="") + text: str = Field(default="") + v_swipe: str = Field(default="") + h_swipe: str = Field(default="") + long_press: str = Field(default="") + + +# start =================== define different Action Op and its params ============= +class RunState(Enum): + """run state""" + + SUCCESS = "success" + FINISH = "finish" + FAIL = "fail" + + +class BaseOpParam(BaseModel): + act_name: str = Field(default="", validate_default=True) + last_act: str = Field(default="None") + param_state: RunState = Field(default=RunState.SUCCESS, description="return state when extracting params") + + +class TapOpParam(BaseOpParam): + area: int = Field(default=-1) + + +class TextOpParam(BaseOpParam): + input_str: str = Field(default="") + + +class LongPressOpParam(BaseOpParam): + area: int = Field(default=-1) + + +# renamed from SwipeOp to SwipeOpParam; TODO: find a better name +class SwipeOpParam(BaseOpParam): + area: int = Field(default=-1) + swipe_orient: str = Field(default="up") + dist: str = Field(default="") + + +class GridOpParam(BaseOpParam): + act_name: str = Field(default="") + + +class BaseGridOpParam(BaseOpParam): + @field_validator("act_name", mode="before") + @classmethod + def check_act_name(cls, act_name: str) -> str: + return f"{act_name}_grid" + + +class TapGridOpParam(BaseGridOpParam): + area: int = Field(default=-1) + subarea: str = Field(default="") + + +class LongPressGridOpParam(BaseGridOpParam): + area: int = Field(default=-1) + subarea: str = Field(default="") + + +class SwipeGridOpParam(BaseGridOpParam): + start_area: int = Field(default=-1) + start_subarea: str = Field(default="") + end_area: int = Field(default=-1) + 
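# e.g. swipe(21, "center", 25, "right") parses into start_area=21, start_subarea="center", + # end_area=25, end_subarea="right" (see screenshot_parse_extract_with_grid in utils/utils.py) + 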
end_subarea: str = Field(default="") + + +# end =================== define different Action Op and its params ============= + + +class ReflectOp(BaseModel): + decision: str = "" + thought: str = "" + documentation: str = "" + param_state: RunState = RunState.SUCCESS + + +class AndroidActionOutput(BaseModel): + data: dict = Field(default=dict()) + action_state: RunState = Field(default=RunState.SUCCESS) diff --git a/metagpt/ext/android_assistant/utils/utils.py b/metagpt/ext/android_assistant/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f1fa138692ca3187144e5a9f75211fa607f26b5a --- /dev/null +++ b/metagpt/ext/android_assistant/utils/utils.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import re +from pathlib import Path +from typing import Union +from xml.etree.ElementTree import Element, iterparse + +import cv2 +import pyshine as ps + +from metagpt.config2 import config +from metagpt.ext.android_assistant.utils.schema import ( + ActionOp, + AndroidElement, + BaseGridOpParam, + BaseOpParam, + Decision, + GridOpParam, + LongPressGridOpParam, + LongPressOpParam, + ReflectOp, + RunState, + SwipeGridOpParam, + SwipeOpParam, + TapGridOpParam, + TapOpParam, + TextOpParam, +) +from metagpt.logs import logger + + +def get_id_from_element(elem: Element) -> str: + bounds = elem.attrib["bounds"][1:-1].split("][") + x1, y1 = map(int, bounds[0].split(",")) + x2, y2 = map(int, bounds[1].split(",")) + elem_w, elem_h = x2 - x1, y2 - y1 + if "resource-id" in elem.attrib and elem.attrib["resource-id"]: + elem_id = elem.attrib["resource-id"].replace(":", ".").replace("/", "_") + else: + elem_id = f"{elem.attrib['class']}_{elem_w}_{elem_h}" + if "content-desc" in elem.attrib and elem.attrib["content-desc"] and len(elem.attrib["content-desc"]) < 20: + content_desc = elem.attrib["content-desc"].replace("/", "_").replace(" ", "").replace(":", "_") + elem_id += f"_{content_desc}" + return elem_id + + +def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: str, add_index=False): + path = [] + extra_config = config.extra + for event, elem in iterparse(str(xml_path), ["start", "end"]): + if event == "start": + path.append(elem) + if attrib in elem.attrib and elem.attrib[attrib] == "true": + parent_prefix = "" + if len(path) > 1: + parent_prefix = get_id_from_element(path[-2]) + bounds = elem.attrib["bounds"][1:-1].split("][") + x1, y1 = map(int, bounds[0].split(",")) + x2, y2 = map(int, bounds[1].split(",")) + center = (x1 + x2) // 2, (y1 + y2) // 2 + elem_id = get_id_from_element(elem) + if parent_prefix: + elem_id = parent_prefix + "_" + elem_id + if add_index: + elem_id += f"_{elem.attrib['index']}" + close = False + for e in elem_list: + bbox = e.bbox + center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 + if dist <= extra_config.get("min_dist", 30): + close = True + break + if not close: + elem_list.append(AndroidElement(uid=elem_id, bbox=((x1, y1), (x2, y2)), attrib=attrib)) + + if event == "end": + path.pop() + + +def elem_list_from_xml_tree(xml_path: Path, useless_list: list[str], min_dist: int) -> list[AndroidElement]: + clickable_list = [] + focusable_list = [] + traverse_xml_tree(xml_path, clickable_list, "clickable", True) + traverse_xml_tree(xml_path, focusable_list, "focusable", True) + elem_list = [] + for elem in clickable_list: + if elem.uid in useless_list: + continue + elem_list.append(elem) + for 
elem in focusable_list: + if elem.uid in useless_list: + continue + bbox = elem.bbox + center = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + close = False + for e in clickable_list: + bbox = e.bbox + center_ = (bbox[0][0] + bbox[1][0]) // 2, (bbox[0][1] + bbox[1][1]) // 2 + dist = (abs(center[0] - center_[0]) ** 2 + abs(center[1] - center_[1]) ** 2) ** 0.5 + if dist <= min_dist: + close = True + break + if not close: + elem_list.append(elem) + return elem_list + + +def draw_bbox_multi( + img_path: Path, + output_path: Path, + elem_list: list[AndroidElement], + record_mode: bool = False, + dark_mode: bool = False, +): + imgcv = cv2.imread(str(img_path)) + count = 1 + for elem in elem_list: + try: + top_left = elem.bbox[0] + bottom_right = elem.bbox[1] + left, top = top_left[0], top_left[1] + right, bottom = bottom_right[0], bottom_right[1] + label = str(count) + if record_mode: + if elem.attrib == "clickable": + color = (250, 0, 0) + elif elem.attrib == "focusable": + color = (0, 0, 250) + else: + color = (0, 250, 0) + imgcv = ps.putBText( + imgcv, + label, + text_offset_x=(left + right) // 2 + 10, + text_offset_y=(top + bottom) // 2 + 10, + vspace=10, + hspace=10, + font_scale=1, + thickness=2, + background_RGB=color, + text_RGB=(255, 250, 250), + alpha=0.5, + ) + else: + text_color = (10, 10, 10) if dark_mode else (255, 250, 250) + bg_color = (255, 250, 250) if dark_mode else (10, 10, 10) + imgcv = ps.putBText( + imgcv, + label, + text_offset_x=(left + right) // 2 + 10, + text_offset_y=(top + bottom) // 2 + 10, + vspace=10, + hspace=10, + font_scale=1, + thickness=2, + background_RGB=bg_color, + text_RGB=text_color, + alpha=0.5, + ) + except Exception as e: + logger.error(f"ERROR: An exception occurs while labeling the image\n{e}") + count += 1 + cv2.imwrite(str(output_path), imgcv) + return imgcv + + +def draw_grid(img_path: Path, output_path: Path) -> tuple[int, int]: + def get_unit_len(n): + for i in range(1, n + 1): + if n % i == 0 and 120 <= i <= 180: + return i + return -1 + + image = cv2.imread(str(img_path)) + height, width, _ = image.shape + color = (255, 116, 113) + unit_height = get_unit_len(height) + if unit_height < 0: + unit_height = 120 + unit_width = get_unit_len(width) + if unit_width < 0: + unit_width = 120 + thick = int(unit_width // 50) + rows = height // unit_height + cols = width // unit_width + for i in range(rows): + for j in range(cols): + label = i * cols + j + 1 + left = int(j * unit_width) + top = int(i * unit_height) + right = int((j + 1) * unit_width) + bottom = int((i + 1) * unit_height) + cv2.rectangle(image, (left, top), (right, bottom), color, thick // 2) + cv2.putText( + image, + str(label), + (left + int(unit_width * 0.05) + 3, top + int(unit_height * 0.3) + 3), + 0, + int(0.01 * unit_width), + (0, 0, 0), + thick, + ) + cv2.putText( + image, + str(label), + (left + int(unit_width * 0.05), top + int(unit_height * 0.3)), + 0, + int(0.01 * unit_width), + color, + thick, + ) + cv2.imwrite(str(output_path), image) + return rows, cols + + +def area_to_xy(area: int, subarea: str, width: int, height: int, rows: int, cols: int) -> tuple[int, int]: + area -= 1 + row, col = area // cols, area % cols + x_0, y_0 = col * (width // cols), row * (height // rows) + if subarea == "top-left": + x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) // 4 + elif subarea == "top": + x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 4 + elif subarea == "top-right": + x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) // 4 + elif 
subarea == "left": + x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) // 2 + elif subarea == "right": + x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) // 2 + elif subarea == "bottom-left": + x, y = x_0 + (width // cols) // 4, y_0 + (height // rows) * 3 // 4 + elif subarea == "bottom": + x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) * 3 // 4 + elif subarea == "bottom-right": + x, y = x_0 + (width // cols) * 3 // 4, y_0 + (height // rows) * 3 // 4 + else: + x, y = x_0 + (width // cols) // 2, y_0 + (height // rows) // 2 + return x, y + + +def elem_bbox_to_xy(bbox: tuple[tuple[int, int], tuple[int, int]]) -> tuple[int, int]: + tl, br = bbox + x, y = (tl[0] + br[0]) // 2, (tl[1] + br[1]) // 2 + return x, y + + +def reflect_parse_extarct(parsed_json: dict) -> ReflectOp: + decision = parsed_json.get("Decision") + if decision not in Decision.values(): + op = ReflectOp(param_state=RunState.FAIL) + else: + op = ReflectOp( + decision=parsed_json.get("Decision"), + thought=parsed_json.get("Thought"), + documentation=parsed_json.get("Documentation"), + ) + return op + + +def screenshot_parse_extract( + parsed_json: dict, grid_on: bool = False +) -> Union[BaseOpParam, BaseGridOpParam, GridOpParam]: + act = parsed_json.get("Action") + last_act = parsed_json.get("Summary") + act_name = act.split("(")[0] + + if RunState.FINISH.value.upper() in act: + return BaseOpParam(param_state=RunState.FINISH) + + if grid_on: + return screenshot_parse_extract_with_grid(act_name, act, last_act) + else: + return screenshot_parse_extract_without_grid(act_name, act, last_act) + + +def op_params_clean(params: list[str]) -> list[Union[int, str]]: + param_values = [] + for param_value in params: + if '"' in param_value or "'" in param_value: # remove `"` + param_values.append(param_value.strip()[1:-1]) + else: + param_values.append(int(param_value)) + return param_values + + +def screenshot_parse_extract_without_grid(act_name: str, act: str, last_act: str) -> Union[BaseOpParam, GridOpParam]: + if act_name == ActionOp.TAP.value: + area = int(re.findall(r"tap\((.*?)\)", act)[0]) + op = TapOpParam(act_name=act_name, area=area, last_act=last_act) + elif act_name == ActionOp.TEXT.value: + input_str = re.findall(r"text\((.*?)\)", act)[0][1:-1] + op = TextOpParam(act_name=act_name, input_str=input_str, last_act=last_act) + elif act_name == ActionOp.LONG_PRESS.value: + area = int(re.findall(r"long_press\((.*?)\)", act)[0]) + op = LongPressOpParam(act_name=act_name, area=area, last_act=last_act) + elif act_name == ActionOp.SWIPE.value: + params = re.findall(r"swipe\((.*?)\)", act)[0].split(",") + params = op_params_clean(params) # area, swipe_orient, dist + op = SwipeOpParam(act_name=act_name, area=params[0], swipe_orient=params[1], dist=params[2], last_act=last_act) + elif act_name == ActionOp.GRID.value: + op = GridOpParam(act_name=act_name) + else: + op = BaseOpParam(param_state=RunState.FAIL) + return op + + +def screenshot_parse_extract_with_grid(act_name: str, act: str, last_act: str) -> Union[BaseGridOpParam, GridOpParam]: + if act_name == ActionOp.TAP.value: + params = re.findall(r"tap\((.*?)\)", act)[0].split(",") + params = op_params_clean(params) + op = TapGridOpParam(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act) + elif act_name == ActionOp.LONG_PRESS.value: + params = re.findall(r"long_press\((.*?)\)", act)[0].split(",") + params = op_params_clean(params) + op = LongPressGridOpParam(act_name=act_name, area=params[0], subarea=params[1], last_act=last_act) + elif 
act_name == ActionOp.SWIPE.value: + params = re.findall(r"swipe\((.*?)\)", act)[0].split(",") + params = op_params_clean(params) + op = SwipeGridOpParam( + act_name=act_name, start_area=params[0], start_subarea=params[1], end_area=params[2], end_subarea=params[3] + ) + elif act_name == ActionOp.GRID.value: + op = GridOpParam(act_name=act_name) + else: + op = BaseGridOpParam(param_state=RunState.FAIL) + return op diff --git a/metagpt/ext/cr/.DS_Store b/metagpt/ext/cr/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..702b92bb40c45cdd7127c7c4a7dd2a37429f5964 Binary files /dev/null and b/metagpt/ext/cr/.DS_Store differ diff --git a/metagpt/ext/cr/__init__.py b/metagpt/ext/cr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/cr/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/cr/__pycache__/__init__.cpython-310.pyc b/metagpt/ext/cr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0046f5902db3d1b2a64ddba8961241d01fde6e Binary files /dev/null and b/metagpt/ext/cr/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/ext/cr/__pycache__/__init__.cpython-39.pyc b/metagpt/ext/cr/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..babd95521926b14d2423f4c69a85a55bb2e03308 Binary files /dev/null and b/metagpt/ext/cr/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/__init__.py b/metagpt/ext/cr/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/metagpt/ext/cr/actions/__init__.py @@ -0,0 +1 @@ + diff --git a/metagpt/ext/cr/actions/__pycache__/__init__.cpython-310.pyc b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed2359157865826eaaab7aed6657645241fd149 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/__init__.cpython-39.pyc b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e8d701978b8a014c2b4303cb214d180be3ab414 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/code_review.cpython-310.pyc b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c75b37469afbc02d5087bae94e4aff14e7d57789 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-310.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/code_review.cpython-39.pyc b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421c975c8bfcb2c028ac5f6f39c238a8fe1e62ca Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-310.pyc b/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b8783c57dc700d186a9b085213ef44fcdf82f0 Binary files /dev/null and 
diff --git a/metagpt/ext/cr/.DS_Store b/metagpt/ext/cr/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..702b92bb40c45cdd7127c7c4a7dd2a37429f5964 Binary files /dev/null and b/metagpt/ext/cr/.DS_Store differ diff --git a/metagpt/ext/cr/__init__.py b/metagpt/ext/cr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/cr/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/cr/__pycache__/__init__.cpython-310.pyc b/metagpt/ext/cr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0046f5902db3d1b2a64ddba8961241d01fde6e Binary files /dev/null and b/metagpt/ext/cr/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/ext/cr/__pycache__/__init__.cpython-39.pyc b/metagpt/ext/cr/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..babd95521926b14d2423f4c69a85a55bb2e03308 Binary files /dev/null and b/metagpt/ext/cr/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/__init__.py b/metagpt/ext/cr/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/metagpt/ext/cr/actions/__init__.py @@ -0,0 +1 @@ + diff --git a/metagpt/ext/cr/actions/__pycache__/__init__.cpython-310.pyc b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ed2359157865826eaaab7aed6657645241fd149 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/__init__.cpython-39.pyc b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e8d701978b8a014c2b4303cb214d180be3ab414 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/__init__.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/code_review.cpython-310.pyc b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c75b37469afbc02d5087bae94e4aff14e7d57789 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-310.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/code_review.cpython-39.pyc b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421c975c8bfcb2c028ac5f6f39c238a8fe1e62ca Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/code_review.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-310.pyc b/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b8783c57dc700d186a9b085213ef44fcdf82f0 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-310.pyc differ diff --git a/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-39.pyc b/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a3227c2c2777bdc6d77a6119f2a971abfb27a3 Binary files /dev/null and b/metagpt/ext/cr/actions/__pycache__/modify_code.cpython-39.pyc differ diff --git a/metagpt/ext/cr/actions/code_review.py b/metagpt/ext/cr/actions/code_review.py new file mode 100644 index 0000000000000000000000000000000000000000..0235dc2c605f191195861857dd03835ee2ea1f1a --- /dev/null +++ b/metagpt/ext/cr/actions/code_review.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import json +import re +from pathlib import Path + +import aiofiles +from unidiff import PatchSet + +from metagpt.actions.action import Action +from metagpt.ext.cr.utils.cleaner import ( + add_line_num_on_patch, + get_code_block_from_patch, + rm_patch_useless_part, +) +from metagpt.ext.cr.utils.schema import Point +from metagpt.logs import logger +from metagpt.utils.common import parse_json_code_block +from metagpt.utils.report import EditorReporter + +CODE_REVIEW_PROMPT_TEMPLATE = """ +NOTICE +Let's think and work step by step. +Given the pull-request (PR) Patch and the referenced Points (code standards), compare each point with the code one by one, within 4000 tokens. + +Each line of the Patch is prefixed with its line number for ease of reading, but the review should focus on the newly added code inside the `Patch` (lines starting with a line number and '+'). +Each point starts with its id, followed by the point description. + +## Patch +``` +{patch} +``` + +## Points +{points} + +## Output Format +```json +[ + {{ + "commented_file": "The file path which you give a comment from the patch", + "comment": "The comment (in Chinese) on code that does not meet the point description, with modification suggestions", + "code_start_line": "the start line number of the commented code in the Patch, e.g. `10`", + "code_end_line": "the end line number of the commented code in the Patch, e.g. `15`", + "point_id": "The point id which the `comment` references to" + }} +] +``` + +CodeReview guidelines: +- Generate a `comment` only for code that does not meet the point description. +- Each `comment` should be restricted to its `commented_file`. +- Try to provide diverse and insightful comments across different `commented_file`. +- Don't suggest adding docstrings unless truly necessary. +- If the same error occurs multiple times, none of the occurrences may be omitted; identify all of them. But don't duplicate the same comment at the same place! +- Every line of code in the patch needs to be checked carefully; do not lazily skip any place. +- The `comment` and `point_id` in the Output must correspond to and belong to the same `Point`. + +Strictly Observe: +Print only the PR Patch comments, in JSON, exactly like **Output Format**. +The output JSON must be parseable by json.loads() without any errors. +""" + +CODE_REVIEW_CONFIRM_SYSTEM_PROMPT = """ +You are a professional engineer with a {code_language} stack, good at judging the results of code-review comments. Let's think and work step by step.
+""" + +CODE_REVIEW_COMFIRM_TEMPLATE = """ +## Code +``` +{code} +``` +## Code Review Comments +{comment} + +## Description of Defects +{desc} + +## Reference Example for Judgment +{example} + +## Your Task: +1. First, check if the code meets the requirements and does not violate any defects. If it meets the requirements and does not violate any defects, print `False` and do not proceed with further judgment. +2. Based on the `Reference Example for Judgment` provided, determine if the `Code` and `Code Review Comments` match. If they match, print `True`; otherwise, print `False`. + +Note: Your output should only be `True` or `False` without any explanations. +""" + + +class CodeReview(Action): + name: str = "CodeReview" + + def format_comments(self, comments: list[dict], points: list[Point], patch: PatchSet): + new_comments = [] + logger.debug(f"original comments: {comments}") + for cmt in comments: + try: + if cmt.get("commented_file").endswith(".py"): + points = [p for p in points if p.language == "Python"] + elif cmt.get("commented_file").endswith(".java"): + points = [p for p in points if p.language == "Java"] + else: + continue + for p in points: + point_id = int(cmt.get("point_id", -1)) + if point_id == p.id: + code_start_line = cmt.get("code_start_line") + code_end_line = cmt.get("code_end_line") + code = get_code_block_from_patch(patch, code_start_line, code_end_line) + + new_comments.append( + { + "commented_file": cmt.get("commented_file"), + "code": code, + "code_start_line": code_start_line, + "code_end_line": code_end_line, + "comment": cmt.get("comment"), + "point_id": p.id, + "point": p.text, + "point_detail": p.detail, + } + ) + break + except Exception: + pass + + logger.debug(f"new_comments: {new_comments}") + return new_comments + + async def confirm_comments(self, patch: PatchSet, comments: list[dict], points: list[Point]) -> list[dict]: + points_dict = {point.id: point for point in points} + new_comments = [] + for cmt in comments: + try: + point = points_dict[cmt.get("point_id")] + + code_start_line = cmt.get("code_start_line") + code_end_line = cmt.get("code_end_line") + # 如果代码位置为空的话,那么就将这条记录丢弃掉 + if not code_start_line or not code_end_line: + logger.info("False") + continue + + # 代码增加上下文,提升confirm的准确率 + code = get_code_block_from_patch( + patch, str(max(1, int(code_start_line) - 3)), str(int(code_end_line) + 3) + ) + pattern = r"^[ \t\n\r(){}[\];,]*$" + if re.match(pattern, code): + code = get_code_block_from_patch( + patch, str(max(1, int(code_start_line) - 5)), str(int(code_end_line) + 5) + ) + code_language = "Java" + code_file_ext = cmt.get("commented_file", ".java").split(".")[-1] + if code_file_ext == ".java": + code_language = "Java" + elif code_file_ext == ".py": + code_language = "Python" + prompt = CODE_REVIEW_COMFIRM_TEMPLATE.format( + code=code, + comment=cmt.get("comment"), + desc=point.text, + example=point.yes_example + "\n" + point.no_example, + ) + system_prompt = [CODE_REVIEW_COMFIRM_SYSTEM_PROMPT.format(code_language=code_language)] + resp = await self.llm.aask(prompt, system_msgs=system_prompt) + if "True" in resp or "true" in resp: + new_comments.append(cmt) + except Exception: + logger.info("False") + logger.info(f"original comments num: {len(comments)}, confirmed comments num: {len(new_comments)}") + return new_comments + + async def cr_by_points(self, patch: PatchSet, points: list[Point]): + comments = [] + valid_patch_count = 0 + for patched_file in patch: + if not patched_file: + continue + if patched_file.path.endswith(".py"): + points = [p 
for p in points if p.language == "Python"] + valid_patch_count += 1 + elif patched_file.path.endswith(".java"): + file_points = [p for p in points if p.language == "Java"] + valid_patch_count += 1 + else: + continue + group_points = [file_points[i : i + 3] for i in range(0, len(file_points), 3)] + for group_point in group_points: + points_str = "id description\n" + points_str += "\n".join([f"{p.id} {p.text}" for p in group_point]) + prompt = CODE_REVIEW_PROMPT_TEMPLATE.format(patch=str(patched_file), points=points_str) + resp = await self.llm.aask(prompt) + json_str = parse_json_code_block(resp)[0] + comments_batch = json.loads(json_str) + if comments_batch: + patched_file_path = patched_file.path + for c in comments_batch: + c["commented_file"] = patched_file_path + comments.extend(comments_batch) + + if valid_patch_count == 0: + raise ValueError("Only code reviews for Python and Java languages are supported.") + + return comments + + async def run(self, patch: PatchSet, points: list[Point], output_file: str): + patch: PatchSet = rm_patch_useless_part(patch) + patch: PatchSet = add_line_num_on_patch(patch) + + result = [] + async with EditorReporter(enable_llm_stream=True) as reporter: + log_cr_output_path = Path(output_file).with_suffix(".log") + await reporter.async_report( + {"src_path": str(log_cr_output_path), "filename": log_cr_output_path.name}, "meta" + ) + comments = await self.cr_by_points(patch=patch, points=points) + log_cr_output_path.parent.mkdir(exist_ok=True, parents=True) + async with aiofiles.open(log_cr_output_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(comments, ensure_ascii=False, indent=2)) + await reporter.async_report(log_cr_output_path) + + if len(comments) != 0: + comments = self.format_comments(comments, points, patch) + comments = await self.confirm_comments(patch=patch, comments=comments, points=points) + for comment in comments: + if comment["code"] and not comment["code"].isspace(): + result.append(comment) + + async with EditorReporter() as reporter: + src_path = output_file + cr_output_path = Path(output_file) + await reporter.async_report( + {"type": "CodeReview", "src_path": src_path, "filename": cr_output_path.name}, "meta" + ) + async with aiofiles.open(cr_output_path, "w", encoding="utf-8") as f: + await f.write(json.dumps(comments, ensure_ascii=False, indent=2)) + await reporter.async_report(cr_output_path) + return result
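`cr_by_points` fans the rule set out in batches of three points per LLM call, which keeps each review prompt small. A minimal standalone sketch of just that batching step (assumption: `Point` is reduced to a namedtuple here; the real class is a pydantic model from `metagpt.ext.cr.utils.schema`):

```python
from collections import namedtuple

Point = namedtuple("Point", ["id", "text", "language"])
points = [Point(i, f"rule {i}", "Python") for i in range(1, 8)]

# Same slicing as cr_by_points: one small "id description" table per batch.
group_points = [points[i : i + 3] for i in range(0, len(points), 3)]
for group in group_points:
    points_str = "id description\n"
    points_str += "\n".join(f"{p.id} {p.text}" for p in group)
    print(points_str, end="\n---\n")
```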
diff --git a/metagpt/ext/cr/actions/modify_code.py b/metagpt/ext/cr/actions/modify_code.py new file mode 100644 index 0000000000000000000000000000000000000000..820bdae4a1def534f2ce583a0e633ad5430302c2 --- /dev/null +++ b/metagpt/ext/cr/actions/modify_code.py @@ -0,0 +1,112 @@ +import datetime +import itertools +import re +from pathlib import Path +from typing import Optional + +from unidiff import PatchSet + +from metagpt.actions.action import Action +from metagpt.ext.cr.utils.cleaner import ( + add_line_num_on_patch, + get_code_block_from_patch, + rm_patch_useless_part, +) +from metagpt.utils.common import CodeParser +from metagpt.utils.report import EditorReporter + +SYSTEM_MSGS_PROMPT = """ +You're an adaptive software developer who excels at refining code based on user inputs. You're proficient in creating Git patches to represent code modifications. +""" + +MODIFY_CODE_PROMPT = """ +NOTICE +Given the pull-request (PR) Patch and the referenced Comments (code standards), you should modify the code according to the Comments. + +Each line of the Patch is prefixed with its line number for ease of reading, but the modification should focus on the newly added code inside the `Patch` (lines starting with a line number and '+'). + +## Patch +``` +{patch} +``` + +## Comments +{comments} + +## Output Format +The modified patch, printed as a fenced diff code block. + +Code Modification guidelines: +- Use `code_start_line` and `code_end_line` to locate the problematic code, then fix it according to the `point_detail` in the Comments. Strictly: the fix plan given by `point_detail` must be handled in every comment. +- Create a patch that satisfies the git patch standard; your fixes need to be marked with '+' and '-', but note: don't change the hunk header! +- Do not print line numbers in the new patch code. + +Just print the Patch in the format like **Output Format**. +""" + + +class ModifyCode(Action): + name: str = "Modify Code" + pr: str + + async def run(self, patch: PatchSet, comments: list[dict], output_dir: Optional[str] = None) -> str: + patch: PatchSet = rm_patch_useless_part(patch) + patch: PatchSet = add_line_num_on_patch(patch) + + for comment in comments: + code_start_line = comment.get("code_start_line") + code_end_line = comment.get("code_end_line") + # If the code location is empty, this record is effectively dropped + if code_start_line and code_end_line: + code = get_code_block_from_patch( + patch, str(max(1, int(code_start_line) - 3)), str(int(code_end_line) + 3) + ) + pattern = r"^[ \t\n\r(){}[\];,]*$" + if re.match(pattern, code): + code = get_code_block_from_patch( + patch, str(max(1, int(code_start_line) - 5)), str(int(code_end_line) + 5) + ) + # Add context around the code to improve fix accuracy + comment["code"] = code + # Drop the LLM's CR comment so the predefined fix plan in point_detail is used instead + comment.pop("comment") + + # Group the comments by the commented_file field + comments.sort(key=lambda x: x["commented_file"]) + grouped_comments = { + key: list(group) for key, group in itertools.groupby(comments, key=lambda x: x["commented_file"]) + } + resp = None + for patched_file in patch: + patch_target_file_name = str(patched_file.path).split("/")[-1] + if patched_file.path not in grouped_comments: + continue + comments_prompt = "" + index = 1 + for grouped_comment in grouped_comments[patched_file.path]: + comments_prompt += f""" + + {grouped_comment} + \n + """ + index += 1 + prompt = MODIFY_CODE_PROMPT.format(patch=patched_file, comments=comments_prompt) + output_dir = ( + Path(output_dir) + if output_dir + else self.config.workspace.path / "modify_code" / str(datetime.date.today()) / self.pr + ) + patch_file = output_dir / f"{patch_target_file_name}.patch" + patch_file.parent.mkdir(exist_ok=True, parents=True) + async with EditorReporter(enable_llm_stream=True) as reporter: + await reporter.async_report( + {"type": "Patch", "src_path": str(patch_file), "filename": patch_file.name}, "meta" + ) + resp = await self.llm.aask(msg=prompt, system_msgs=[SYSTEM_MSGS_PROMPT]) + resp = CodeParser.parse_code(resp, "diff") + with open(patch_file, "w", encoding="utf-8") as file: + file.writelines(resp) + await reporter.async_report(patch_file) + return resp
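`ModifyCode.run` leans on `itertools.groupby`, which only merges adjacent runs; that is why the comments are sorted on `commented_file` immediately beforehand. A minimal standalone sketch of that grouping (the comment dicts here are hypothetical):

```python
import itertools

comments = [
    {"commented_file": "a.py", "point_detail": "fix 1"},
    {"commented_file": "b.py", "point_detail": "fix 2"},
    {"commented_file": "a.py", "point_detail": "fix 3"},
]
# groupby only groups adjacent items, so sort by the grouping key first.
comments.sort(key=lambda c: c["commented_file"])
grouped_comments = {
    key: list(group)
    for key, group in itertools.groupby(comments, key=lambda c: c["commented_file"])
}
print(grouped_comments["a.py"])  # both a.py comments end up in one batch
```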
"Examples of being judged as 'avoid unused temporary variables'", + "no_example": "Examples that cannot be judged as 'avoiding unused temporary variables'\n\npublic void setTransientVariablesLocal(Map transientVariables) {\n throw new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}\nThis code's 'transientVariables' is a function parameter rather than a temporary variable. Although 'transientVariables' is not used or referenced, this cannot be judged as 'avoiding unused temporary variables'\n\n\n\npublic class TriggerCmd extends NeedsActiveExecutionCmd {\n protected Map transientVariables;\n public TriggerCmd(Map transientVariables) {\n this.transientVariables = transientVariables;\n }\n}\nIn the above code, 'transientVariables' is not a temporary variable; it is a class attribute and is used in the constructor, so this cannot be judged as 'avoiding unused temporary variables'\n" + }, + { + "id": 2, + "text": "Do not use System.out.println to print", + "language": "Java", + "detail": "Defect type: Do not use System.out.println to print; Corresponding Fixer: SystemPrintlnFixer; Fixing solution: Comment out the System.out.println code", + "yes_example": "Example of being judged as 'Do not use System.out.println for printing'", + "no_example": "Examples that cannot be judged as 'Do not use System.out.println to print'\n\nthrow new IllegalStateException(\"There is no authenticated user, we need a user authenticated to find tasks\");\nThe above code is throwing an exception, not using 'System.out.print', so this cannot be judged as 'Do not use System.out.println to print'\n" + }, + { + "id": 3, + "text": "Avoid unused formal parameters in functions", + "language": "Java", + "detail": "Defect type: Avoid unused formal parameters in functions; Fix solution: Ignore", + "yes_example": "Examples of being judged as 'avoiding unused formal parameters' in functions\n\n\npublic void setTransientVariablesLocal(Map transientVariables) {\n throw new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}In this code, the formal parameter \"transientVariables\" does not appear in the function body, so this is judged as 'avoiding unused formal parameters'\n\n\n\nprotected void modifyFetchPersistencePackageRequest(PersistencePackageRequest ppr, Map pathVars) {}\nIn this code, the formal parameters \"ppr\" and \"pathVars\" do not appear in the function body, so this is judged as 'avoiding unused formal parameters'\n", + "no_example": "Examples that cannot be judged as 'avoiding unused parameters in functions'\n\npublic String processFindForm(@RequestParam(value = \"pageNo\", defaultValue = \"1\") int pageNo) {\n\tlastName = owner.getLastName();\n\treturn addPaginationModel(pageNo, paginationModel, lastName, ownersResults);\n}In this code, the parameter 'pageNo' is used within the current function 'processFindForm' in the statement 'return addPaginationModel(pageNo, paginationModel, lastName, ownersResults);', although pageNo is not used for logical calculations, it is used as a parameter in a function call to another function, so this cannot be judged as 'avoiding unused parameters in functions'\n\n\npublic void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}In this code, the parameter 'date' is referenced in the statement 'System.out.println(\"Formatted date: \" + sdf.format(date))', so this cannot be judged as 'avoiding unused parameters in 
functions'\n" + }, + { + "id": 4, + "text": "if statement block cannot be empty", + "language": "Java", + "detail": "Defect type: if statement block cannot be empty; Corresponding Fixer: EmptyIfStmtFixer; Fixing solution: delete the if statement block or handle the logic appropriately or comment to explain why it is empty", + "yes_example": "Examples of being judged as 'if statement block cannot be empty'\n\npublic void emptyIfStatement() {\n\tif (getSpecialties().isEmpty()) {\n\t}\n}\nThis code's if statement block is empty, so it is judged as 'if statement block cannot be empty'\n\n\n\npublic void judgePersion() {\n\tif (persion != null) {\n\t\t// judge persion if not null\n\t}\n}\nAlthough this code's if statement block has content, the '// judge persion if not null' is just a code comment, and there is no actual logic code inside the if statement block, so it is judged as 'if statement block cannot be empty'\n", + "no_example": "Example that cannot be judged as 'if statement block cannot be empty'" + }, + { + "id": 5, + "text": "Loop body cannot be empty", + "language": "Java", + "detail": "Defect type: loop body cannot be empty; Corresponding Fixer: EmptyStatementNotInLoopFixer; Repair solution: delete the corresponding while, for, foreach loop body or add appropriate logical processing or comment explaining why it is empty", + "yes_example": "Examples of being judged as 'Loop body cannot be empty'\n\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t}\n}\nThis code's for loop body is empty, so it is judged as 'Loop body cannot be empty'\n\n\n\npublic void emptyLoopBody() {\n\twhile (true) {\n\t\t// this is a code example\n\t}\n}\nThe while loop body in this code is not empty, but the content is just a code comment with no logical content, so it is judged as 'Loop body cannot be empty'\n\n\n\npublic void emptyLoopBody() {\n\twhile (true) {\n\t\t\n\t}\n}\nThe while loop body in this code is empty, so it is judged as 'Loop body cannot be empty'\n", + "no_example": "Example that cannot be judged as 'loop body cannot be empty'\n\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t\ta = 1;\n\t\tif (a == 1) {\n\t\t\treturn a;\n\t\t}\n\t}\n}\nThe content of the for loop in the above code is not empty, and the content is not entirely code comments, so this cannot be judged as 'loop body cannot be empty'\n" + }, + { + "id": 6, + "text": "Avoid using printStackTrace(), and instead use logging to record.", + "language": "Java", + "detail": "Defect type: Avoid using printStackTrace(), should use logging to record; Repair solution: Use logging to record", + "yes_example": "Example of being judged as 'Avoid using printStackTrace(), should use logging to record'", + "no_example": "Example that cannot be judged as 'avoid using printStackTrace(), should use logging to record'\n\npublic void usePrintStackTrace() {\n\ttry {\n\t\tthrow new Exception(\"Fake exception\");\n\t} catch (Exception e) {\n\t\tlogging.info(\"info\");\n\t}\n}\nThis code uses logging in the catch statement, so it cannot be judged as 'avoid using printStackTrace(), should use logging to record'\n" + }, + { + "id": 7, + "text": "The catch block cannot be empty", + "language": "Java", + "detail": "Defect type: catch block cannot be empty; Corresponding Fixer: EmptyCatchBlockFixer; Fix solution: Add a comment inside the catch block", + "yes_example": "Examples of being judged as 'catch block cannot be empty'\n\n\n\ntry {\n int[] array = new int[5];\n int number = 
array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n \n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\n\ntry {\n String str = null;\n str.length();\n} catch (NullPointerException e) {\n \n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\npublic class EmptyCatchExample {\n public static void main(String[] args) {\n try {\n // Attempt to divide by zero to trigger an exception\n int result = 10 / 0;\n } catch (ArithmeticException e) {\n \n }\n }\n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\n\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n \n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n\n\n\n\ntry {\n Object obj = \"string\";\n Integer num = (Integer) obj;\n} catch (ClassCastException e) {\n\t\n}\nThis code has an empty catch block, so it is judged as 'catch block cannot be empty'\n", + "no_example": "Examples that cannot be judged as 'catch block cannot be empty'\n\npersionNum = 1\ntry {\n\treturn true;\n} catch (Exception e) {\n\t// If the number of people is 1, return false\n\tif (persionNum == 1){\n\t\treturn false;\n\t}\n}This catch statement is not empty, so it cannot be judged as 'catch block cannot be empty'\n\n\n\ntry {\n\tthrow new Exception(\"Fake exception\");\n} catch (Exception e) {\n\te.printStackTrace();\n}Although this catch statement only has 'e.printStackTrace();', it is indeed not empty, so it cannot be judged as 'catch block cannot be empty'\n" + }, + { + "id": 8, + "text": "Avoid unnecessary tautologies/contradictions", + "language": "Java", + "detail": "Defect type: Avoid unnecessary true/false judgments; Corresponding Fixer: UnconditionalIfStatementFixer; Fixing solution: Delete true/false judgment logic", + "yes_example": "Examples of being judged as 'avoiding unnecessary always true/always false judgments'", + "no_example": "Examples that cannot be judged as 'avoiding unnecessary always true/always false judgments'" + }, + { + "id": 9, + "text": "In a switch statement, default must be placed at the end", + "language": "Java", + "detail": "Defect type: The default in switch must be placed at the end; Corresponding Fixer: DefaultLabelNotLastInSwitchStmtFixer; Fixing solution: Place default at the end in switch", + "yes_example": "Example of being judged as 'default in switch must be placed at the end'", + "no_example": "Example that cannot be judged as 'the default in switch must be placed at the end'" + }, + { + "id": 10, + "text": "Comparison of String without using equals() function", + "language": "Java", + "detail": "Defect type: Not using the equals() function to compare Strings; Corresponding Fixer: UnSynStaticDateFormatterFixer; Fix solution: Use the equals() function to compare Strings", + "yes_example": "Examples of being judged as 'not using the equals() function to compare Strings'\n\n\nif (existingPet != null && existingPet.getName() == petName) {\n result.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}\nIn this code, both existingPet.getName() and petName are strings, but the comparison in the if statement uses == instead of equals() to compare the strings, so this is judged as 'not using the equals() function to compare Strings'.\n\n\n\nString isOk = \"ok\";\nif (\"ok\" == isOk) {\n result.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}\nIn this code, isOk is a string, but in the 
if statement, it is compared with \"ok\" using ==, not using equals() to compare the strings, it should use \"ok\".equals(isOk), so this is judged as 'not using the equals() function to compare Strings'.\n\n\n\nString str1 = \"Hello\";\nString str2 = \"Hello\";\nif (str1 == str2) {\n System.out.println(\"str1 and str2 reference the same object\");\n} else {\n System.out.println(\"str1 and str2 reference different objects\");\n}\nIn this code, if (str1 == str2) uses == to compare str1 and str2, not using equals() to compare the strings, it should use str1.equals(str2), so this is judged as 'not using the equals() function to compare Strings'.\n\n\n\nString str = \"This is string\";\nif (str == \"This is not str\") {\n return str;\n}\nIn this code, if (str == \"This is not str\") uses == to compare the strings, not using equals() to compare the strings, it should use \"This is not str\".equals(str), so this is judged as 'not using the equals() function to compare Strings'.\n", + "no_example": "Examples that cannot be judged as 'not using the equals() function to compare Strings'\n\n\nif (PROPERTY_VALUE_YES.equalsIgnoreCase(readWriteReqNode))\n formProperty.setRequired(true);\nIn this code, both PROPERTY_VALUE_YES and readWriteReqNode are strings. The comparison between PROPERTY_VALUE_YES and readWriteReqNode in the if statement uses equalsIgnoreCase (case-insensitive string comparison), which is also in line with using the equals() function to compare Strings. Therefore, this cannot be judged as 'not using the equals() function to compare Strings'\n\n\n\nString isOk = \"ok\";\nif (\"ok\".equals(isOk)) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}In this code, isOk is a string. In the if statement, the comparison with \"ok\" uses the equals() function to compare Strings, so this cannot be judged as 'not using the equals() function to compare Strings'\n" + }, + { + "id": 11, + "text": "Prohibit the direct use of string output for exceptions in logs, please use placeholders to pass the exception object", + "language": "Java", + "detail": "Defect type: Do not directly output exceptions as strings in logs, use placeholders to pass the exception object; Corresponding Fixer: ConcatExceptionFixer; Fix solution: Use placeholders to pass the exception object", + "yes_example": "Example of being judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'\n\ntry {\n listenersNode = objectMapper.readTree(listenersNode.asText());\n} catch (Exception e) {\n LOGGER.info(\"Listeners node can not be read\", e);\n}In this code, the log output content is directly concatenated using the string \"Listeners node can not be read\". When outputting exceptions in logs, placeholders should be used to output exception information, rather than directly concatenating strings. 
Therefore, this is judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'.\n", + "no_example": "Examples that cannot be judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects':\n\n\nPerson person = personService.getPerson(1);\nif (person == null) {\n LOGGER.error(PERSION_NOT_EXIT);\n}\nIn this code, PERSION_NOT_EXIT is a user-defined exception constant representing that the person does not exist, and it does not directly use the string 'person not exit' for concatenation, so this cannot be judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'.\n\n\n\ntry {\n a = a + 1;\n} catch (Exception e) {\n Person person = personService.getPerson(1);\n LOGGER.info(person);\n}\nIn this code, the log output does not directly use string concatenation, but rather uses the Person object for output, so this cannot be judged as 'Prohibited to directly output exceptions using string in logs, please use placeholders to pass exception objects'.\n" + }, + { + "id": 12, + "text": "The finally block cannot be empty", + "language": "Java", + "detail": "Defect type: finally block cannot be empty; Corresponding Fixer: EmptyFinallyBlockFixer; Fix solution: Delete the empty finally block", + "yes_example": "Examples of being judged as 'finally block cannot be empty'\n\n\n\ntry {\n Persion persion = persionService.getPersion(1);\n return persion;\n} finally {\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n System.out.println(\"Inside try block\");\n} finally {\n // Empty finally block with no statements, this is a defect\n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n int result = 10 / 0;\n} catch (ArithmeticException e) {\n e.printStackTrace();\n} finally {\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n String str = null;\n System.out.println(str.length());\n} catch (NullPointerException e) {\n e.printStackTrace();\n} finally {\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n e.printStackTrace();\n} finally {\n // Finally block with only comments\n // This is an empty finally block\n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n\n\n\n\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n e.printStackTrace();\n} finally {\n // Finally block with only empty lines\n \n}\nThis code has an empty finally block, so it is judged as 'finally block cannot be empty'\n\n", + "no_example": "Example that cannot be judged as 'finally block cannot be empty'\n\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){\n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\treturn null;\n\t}\n}\nThis code's finally block contains non-comment content 'return null;', so this cannot be judged as 'finally block cannot be empty'\n" + }, + { + "id": 13, + "text": "try block cannot be empty", + "language": "Java", + "detail": "Defect type: try block cannot be empty; Corresponding Fixer: EmptyTryBlockFixer; Fix solution: Delete 
the entire try statement", + "yes_example": "Examples of being judged as 'try block cannot be empty'\n\npublic void getPersion() {\n\ttry {\n\n\t}\n\treturn null;\n}This code's try block is empty, so it is judged as 'try block cannot be empty'\n\n\n\npublic void demoFinallyBlock() {\n\ttry {\n\n\t} finally {\n\t\treturn null;\n\t}\n}This code's try block is empty, so it is judged as 'try block cannot be empty'\n\n\n\ntry {\n \n} catch (Exception e) {\n e.printStackTrace();\n}This code's try block is empty, so it is judged as 'try block cannot be empty'\n\n\n\ntry {\n // try block with only comments\n\t\n} catch (Exception e) {\n e.printStackTrace();\n}This code's try block contains only comments and blank lines, which can also be considered as having no content in the try block, so it is judged as 'try block cannot be empty'\n", + "no_example": "Example that cannot be judged as 'try block cannot be empty'\n\ntry {\n\ta = a + 1;\n} catch (Exception e) {\n\te.printStackTrace();\n}\nThis code snippet contains non-comment content 'a = a + 1;' in the try block, so it cannot be judged as 'try block cannot be empty'\n" + }, + { + "id": 14, + "text": "Avoid unnecessary NULL or null checks on objects", + "language": "Java", + "detail": "Defect type: Avoid unnecessary NULL or null checks on objects; Corresponding Fixer: LogicalOpNpeFixer; Fix solution: Remove the logic of unnecessary NULL checks on objects", + "yes_example": "Examples of being judged as 'avoiding unnecessary NULL or null checks':", + "no_example": "Example that cannot be judged as 'avoiding unnecessary NULL or null checks'\n\nCat cat = catService.get(1);\nif (cat != null){\n\treturn cat;\n}In this code, the object 'cat' is obtained through the service and it is uncertain whether it is null or not, so the condition 'cat != null' in the if statement is necessary, therefore this cannot be judged as 'avoiding unnecessary NULL or null checks'\n" + }, + { + "id": 15, + "text": "Avoid return in finally block", + "language": "Java", + "detail": "Defect type: Avoid return in finally block; Repair solution: No need for repair", + "yes_example": "Example judged as 'avoid return in finally block'", + "no_example": "Example that cannot be judged as 'avoiding return in finally block'\n\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\tLOGGER.info(PERSION_NOT_EXIT);\n\t}\n}\nThis code's finally block does not contain 'return', so it cannot be judged as 'avoiding return in finally block'\n" + }, + { + "id": 16, + "text": "Avoid empty static initialization", + "language": "Java", + "detail": "Defect type: Avoid empty static initialization; Corresponding Fixer: EmptyInitializerFixer; Fix solution: Delete the entire empty initialization block", + "yes_example": "Examples of being judged as 'Avoid empty static initialization'", + "no_example": "Example that cannot be judged as 'avoiding empty static initialization'\n\npublic class Cat {\n\tstatic {\n\t\t// Static initialization block\n\t\tcat = null;\n\t}\n}\nThis code has a static block with content, not empty, and the static initialization block contains non-commented code with actual logic, so this cannot be judged as 'avoiding empty static initialization'\n" + }, + { + "id": 17, + "text": "Avoid risks of improper use of calendar", + "language": "Java", + "detail": "Defect type: Avoid improper usage risks of calendar classes; Fix solution: Use LocalDate from the java.time 
package in Java 8 and above", + "yes_example": "Examples of being judged as 'avoiding improper use of calendar class risks'\n\nprivate static final Calendar calendar = new GregorianCalendar(2020, Calendar.JANUARY, 1);\nThe Calendar and GregorianCalendar in this code are not thread-safe, so this is judged as 'avoiding improper use of calendar class risks'\n", + "no_example": "Examples that cannot be judged as 'avoiding improper use of calendar class risks'" + }, + { + "id": 18, + "text": "To convert a collection to an array, you must use the toArray(T[] array) method of the collection, passing in an array of the exact same type, with a size equal to list.size()", + "language": "Java", + "detail": "Defect type: When converting a collection to an array, you must use the toArray(T[] array) method of the collection, passing an array of the exact same type, with a size equal to list.size(); Corresponding Fixer: ClassCastExpWithToArrayFixer; Repair solution: Use the toArray(T[] array) method of the collection, and pass an array of the exact same type", + "yes_example": "Example judged as 'When converting a collection to an array, you must use the collection's toArray(T[] array) method, passing an array of exactly the same type, with the size being list.size()'", + "no_example": "Example that cannot be judged as 'using the method of converting a collection to an array, you must use the toArray(T[] array) of the collection, passing in an array of exactly the same type, and the size is list.size()':" + }, + { + "id": 19, + "text": "Prohibit the use of NULL or null for comparison in equals()", + "language": "Java", + "detail": "Defect type: Prohibit using NULL or null for comparison in equals(); Corresponding Fixer: EqualsNullFixer; Fixing solution: Use Object's null check function for comparison", + "yes_example": "Examples of being judged as 'Prohibited to use NULL or null for comparison in equals()'", + "no_example": "Examples that cannot be judged as 'prohibiting the use of NULL or null for comparison in equals()'" + }, + { + "id": 20, + "text": "switch statement block cannot be empty", + "language": "Java", + "detail": "Defect type: switch statement block cannot be empty; Corresponding Fixer: EmptySwitchStatementsFix; Fix solution: Delete the entire empty switch statement block", + "yes_example": "Examples of being judged as 'switch statement block cannot be empty'\n\nswitch (number) {\n \n}This code is a switch statement block, but it contains no content, so it is judged as 'switch statement block cannot be empty'\n\n\n\nswitch (number) {\n // This is a switch statement block\n}This code is a switch statement block, which contains content, but the content is only comments without actual logic, so it is judged as 'switch statement block cannot be empty'\n", + "no_example": "Example that cannot be judged as 'switch statement block cannot be empty'\n\nswitch (number) {\n\tcase 1:\n\t\tSystem.out.println(\"Number one\");\n\t\tbreak;\n\tdefault:\n\t\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\t\tbreak;\n}\nThis code is a switch statement block that contains content, and the content includes non-commented code with actual logic, so it cannot be judged as 'switch statement block cannot be empty'.\n" + }, + { + "id": 21, + "text": "When performing type coercion, no spaces are needed between the right parenthesis and the coercion value.", + "detail": "Defect type: When performing type coercion, no space is required between the right parenthesis and the 
coercion value; Fix solution: When performing type coercion, no space is required between the right parenthesis and the coercion value.", + "language": "Java", + "yes_example": "Examples judged as 'When performing type casting, no space is needed between the closing parenthesis and the cast value'", + "no_example": "Examples that cannot be judged as 'When performing type coercion, no spaces are required between the right parenthesis and the coercion value'" + }, + { + "id": 22, + "text": "Method parameters must have a space after the comma when defined and passed", + "detail": "Defect type: In the definition and passing of method parameters, a space must be added after the comma for multiple parameters; Repair solution: In the definition and passing of method parameters, a space must be added after the comma for multiple parameters.", + "language": "Java", + "yes_example": "Example of being judged as 'Method parameters must have a space after the comma when defined and passed'", + "no_example": "Examples that cannot be judged as 'Method parameters must have a space after the comma both in definition and when passed'" + }, + { + "id": 23, + "text": "Prohibit the use of the BigDecimal(double) constructor to convert a double value to a BigDecimal object", + "detail": "Defect type: Do not use the constructor BigDecimal(double) to convert a double value to a BigDecimal object; Repair solution: It is recommended to use the valueOf method of BigDecimal.", + "language": "Java", + "yes_example": "Example of being judged as 'Prohibited to use the constructor BigDecimal(double) to convert a double value to a BigDecimal object'", + "no_example": "Examples that cannot be considered as 'prohibiting the use of the BigDecimal(double) constructor to convert a double value to a BigDecimal object'" + }, + { + "id": 24, + "text": "No extra semicolons allowed", + "detail": "Defect type: extra semicolon; Fix solution: remove extra semicolon", + "yes_example": "Example of being judged as 'cannot have extra semicolons'", + "no_example": "Examples that cannot be judged as 'cannot have extra semicolons'\n\nwhile (true) {\n\ta = a + 1;\n\tbreak;\n}This code requires every semicolon it contains, so it cannot be judged as 'cannot have extra semicolons'\n" + }, + { + "id": 25, + "text": "Non-thread-safe SimpleDateFormat usage must be synchronized at the function or code block level", + "detail": "Defect type: Non-thread-safe SimpleDateFormat usage; Fix solution: Add synchronized modifier at the function or code block level or use other thread-safe methods", + "yes_example": "Example of 'Non-thread-safe SimpleDateFormat usage, must be used with synchronized at the function or block level'", + "no_example": "Example that cannot be judged as 'Unsafe use of SimpleDateFormat, which must be used at the function or code block level with synchronized':\n\npublic synchronized void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}\nThis code is protected by a synchronized block on the function 'formatDate', ensuring thread safety, so it cannot be judged as 'Unsafe use of SimpleDateFormat, which must be used at the function or code block level with synchronized'.\n" + }, + { + "id": 26, + "text": "Naming does not follow the camel case specification. 
Class names should use UpperCamelCase style, while method names, parameter names, member variables, and local variables should all use lowerCamelCase style.", + "detail": "Defect type: Not following camel case naming convention; Fix solution: Class names should use UpperCamelCase style, method names, parameter names, member variables, and local variables should use lowerCamelCase style.", + "language": "Java", + "yes_example": "Examples of being judged as 'not following the camel case naming convention'\n\npublic class myClass {\n private int MyVariable;\n public void MyMethod() {}\n}\nThis code does not follow the camel case naming convention for class names, member variables, and method names, so it is judged as a naming convention issue.\n", + "no_example": "Examples that cannot be judged as 'not following the camel case naming convention'\n\npublic class MyClass {\n private int myVariable;\n public void myMethod() {}\n}\nThe class name, member variable, and method name in this code all follow the camel case naming convention, so it cannot be judged as a naming convention issue.\n" + }, + { + "id": 27, + "text": "Abstract class names start with Abstract or Base; exception class names end with Exception; test class names begin with the name of the class they are testing and end with Test", + "detail": "Defect type: Naming convention; Solution: Abstract class names should start with Abstract or Base, exception class names should end with Exception, and test class names should start with the name of the class they are testing and end with Test.", + "language": "Java", + "yes_example": "Examples of being judged as 'naming conventions'\n\npublic class MyAbstractClass {}\npublic class MyExceptionClass {}\npublic class TestMyClass {}\nThe naming of the abstract class, exception class, and test class in this code does not conform to the conventions, so it is judged as a naming convention issue.\n", + "no_example": "Examples that cannot be judged as 'naming conventions'" + }, + { + "id": 28, + "text": "Avoid adding the 'is' prefix to any boolean type variables in POJO classes", + "detail": "Defect type: Naming convention; Fix solution: Do not prefix boolean variables in POJO classes with 'is'.", + "language": "Java", + "yes_example": "Examples of being judged as 'naming convention' issues\n\npublic class User {\n private boolean isActive;\n}\nIn this code, the boolean type variable has the 'is' prefix, so it is judged as a naming convention issue.\n", + "no_example": "Examples that cannot be judged as 'naming conventions'" + }, + { + "id": 29, + "text": "Eliminate completely non-standard English abbreviations to avoid confusion when interpreting them.", + "detail": "Defect type: Naming conventions; Solution: Avoid using non-standard English abbreviations to ensure code readability.", + "language": "Java", + "yes_example": "Examples of being judged as 'naming conventions'\n\npublic class CfgMgr {\n private int cnt;\n}\nIn this code, the class name and variable name use non-standard English abbreviations, so they are judged as naming convention issues.\n", + "no_example": "Examples that cannot be judged as 'naming conventions'" + }, + { + "id": 30, + "text": "Avoid using magic characters and numbers, they should be declared as constants", + "detail": "Defect type: Avoid using magic characters and numbers, they should be declared as constants; Fix solution: Define magic values as constants.", + "language": "Java", + "yes_example": "Examples of being judged as 'avoiding magic characters and numbers, 
should be declared as constants'", + "no_example": "Examples that cannot be judged as 'avoiding magic characters and numbers, should be declared as constants'" + }, + { + "id": 31, + "text": "When assigning values to long or Long, use uppercase L after the number, not lowercase l. The suffix for floating-point numbers should be uppercase D or F.", + "detail": "Defect type: Code specification; Repair solution: Use uppercase L when assigning values to long or Long, and use uppercase D or F as suffixes for floating-point type values.", + "language": "Java", + "yes_example": "Examples of being judged as 'code specification'", + "no_example": "Examples that cannot be judged as 'code specification'" + }, + { + "id": 32, + "text": "If the curly braces are empty, simply write {} without line breaks or spaces inside the braces; if it is a non-empty code block, then: 1) Do not line break before the left curly brace. 2) Line break after the left curly brace. 3) Line break before the right curly brace. 4) Do not line break after the right curly brace if there is code like 'else' following it; the right curly brace indicating termination must be followed by a line break.", + "detail": "Defect type: code formatting; Fix solution: follow the curly brace usage standard.", + "language": "Java", + "yes_example": "Example of being judged as 'code format'", + "no_example": "Examples that cannot be judged as 'code format' issues\n\npublic class BracketExample {\n public void method() {\n if (true) {\n // do something\n }\n }\n}\nThe use of curly braces in this code is in accordance with the standards, so it cannot be judged as a code format issue.\n" + }, + { + "id": 33, + "text": "No space is needed between the left parenthesis and the adjacent character; no space is needed between the right parenthesis and the adjacent character; and a space is required before the left brace.", + "detail": "Defect type: code formatting; Fix solution: follow the usage rules for brackets and spaces.", + "language": "Java", + "yes_example": "Example of being judged as 'code format'\n\npublic class SpaceExample {\n public void method (){\n }\n}\nThe use of brackets and spaces in this code does not conform to the standard, so it is judged as a code format issue.\n", + "no_example": "Examples that cannot be judged as 'code specification'\n\npublic class SpaceExample {\n public void method() {}\n}\nThis code uses brackets and spaces in accordance with the specification, so it cannot be judged as a code format issue.\n" + }, + { + "id": 34, + "text": "Reserved words such as if / for / while / switch / do must be separated from the parentheses on both sides by spaces.", + "detail": "Defect type: code format; Fix solution: add spaces between reserved words and parentheses.", + "language": "Java", + "yes_example": "Example of being judged as 'code specification'\n\npublic class KeywordExample {\n public void method() {\n if(true) {\n }\n }\n}\nIn this code, there is no space between the if keyword and the parentheses, so it is judged as a code formatting issue.\n", + "no_example": "Examples that cannot be judged as 'code specification'" + }, + { + "id": 35, + "text": "All value comparisons between integer wrapper class objects should be done using the equals method", + "detail": "Defect type: Code specification; Repair solution: Use the equals method for value comparison between integer wrapper class objects.", + "language": "Java", + "yes_example": "Examples of being judged as 'code specification'", + "no_example": "### Example that 
cannot be judged as 'code specification'\n\npublic class IntegerComparison {\n public void compare() {\n Integer a = 100;\n Integer b = 100;\n if (a.equals(b)) {\n }\n }\n}\nIn this code, the equals method is used to compare integer wrapper class objects, so it cannot be judged as a code specification issue.\n" + }, + { + "id": 36, + "text": "For comparing BigDecimal values, the compareTo() method should be used instead of the equals() method.", + "detail": "Defect type: The equality comparison of BigDecimal should use the compareTo() method instead of the equals() method; Fix solution: Use the compareTo() method for comparison.", + "language": "Java", + "yes_example": "Example of being judged as 'For BigDecimal equality comparison, the compareTo() method should be used instead of the equals() method'\n\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.equals(b)) {\n // This code will return false because the equals() method compares precision\n}\n", + "no_example": "Examples that cannot be judged as 'For BigDecimal equality comparison, the compareTo() method should be used instead of the equals() method'" + }, + { + "id": 37, + "text": "Prohibit having both isXxx() and getXxx() methods for the same attribute xxx in a POJO class.", + "detail": "Defect type: Duplicate getter methods in POJO class; Fix solution: Ensure only one getter method exists.", + "language": "Java", + "yes_example": "Example of being judged as 'Prohibited to have both isXxx() and getXxx() methods for the corresponding attribute xxx in a POJO class'", + "no_example": "Examples that cannot be judged as 'Prohibiting the existence of both isXxx() and getXxx() methods for the corresponding attribute xxx in a POJO class'" + }, + { + "id": 38, + "text": "When formatting dates, use the lowercase 'y' uniformly to represent the year in the pattern.", + "detail": "Defect type: date formatting error; Fix solution: use lowercase y to represent the year.", + "language": "Java", + "yes_example": "Example judged as 'When formatting dates, use lowercase y for the year in the pattern'", + "no_example": "Examples that cannot be judged as 'When formatting dates, use lowercase y for the year in the pattern'" + }, + { + "id": 39, + "text": "Prohibited from using in any part of the program: 1) java.sql.Date 2) java.sql.Time 3) java.sql.Timestamp.", + "detail": "Defect type: used date classes from the java.sql package; Fix solution: use date classes from the java.time package.", + "language": "Java", + "yes_example": "Examples of being judged as \"Prohibited from using in any part of the program: 1) java.sql.Date 2) java.sql.Time 3) java.sql.Timestamp\"", + "no_example": "Examples that cannot be judged as 'Prohibited to use in any part of the program: 1) java.sql.Date 2) java.sql.Time 3) java.sql.Timestamp'" + }, + { + "id": 40, + "text": "Determine if all elements within a collection are empty using the isEmpty() method, rather than using the size() == 0 approach.", + "detail": "Defect type: Incorrect method for checking empty collection; Fix solution: Use isEmpty() method.", + "language": "Java", + "yes_example": "Example of being judged as 'To determine if all elements within a collection are empty, use the isEmpty() method instead of the size() == 0 approach'\n\nList list = new ArrayList<>();\nif (list.size() == 0) {\n // Empty logic\n}\n", + "no_example": "Examples that cannot be considered as 'judging whether all elements within a set are empty using the isEmpty() method instead of the size() == 
0 approach'" + }, + { + "id": 41, + "text": "Whenever you override equals, you must also override hashCode.", + "detail": "Defect type: hashCode method not overridden; Fix solution: Override both equals and hashCode methods.", + "language": "Java", + "yes_example": "An example where it is judged that 'if you override equals, you must also override hashCode'", + "no_example": "An example where it cannot be judged as 'Whenever you override equals, you must also override hashCode'" + }, + { + "id": 42, + "text": "When using the Map methods keySet() / values() / entrySet() to return a collection object, you cannot perform element addition operations on it, otherwise a UnsupportedOperationException will be thrown.", + "detail": "Defect type: Adding operations to the collections returned by keySet() / values() / entrySet() of a Map; Repair solution: Avoid adding operations to these collections.", + "language": "Java", + "yes_example": "Example of being judged as 'When using the Map methods keySet() / values() / entrySet() to return a collection object, you cannot perform element addition operations on it, otherwise a UnsupportedOperationException exception will be thrown'", + "no_example": "Example that cannot be judged as 'When using the methods keySet() / values() / entrySet() of Map to return a collection object, it is not allowed to perform element addition operations on it, otherwise a UnsupportedOperationException will be thrown'" + }, + { + "id": 43, + "text": "Do not perform element removal / addition operations within a foreach loop. Use the iterator method for removing elements. If concurrent operations are required, the iterator must be synchronized.", + "detail": "Defect type: performing remove / add operations on elements within a foreach loop; Repair solution: use iterator to perform remove operations on elements.", + "language": "Java", + "yes_example": "Example of being judged as 'Do not perform element remove / add operations within a foreach loop. Use the iterator method for removing elements; if concurrent operations are required, the iterator must be synchronized.'", + "no_example": "Example that cannot be judged as 'Do not perform element remove / add operations inside a foreach loop. Use the iterator method for removing elements. 
If concurrent operations are required, the iterator should be synchronized.'\n\nList list = new ArrayList<>(Arrays.asList(\"a\", \"b\", \"c\"));\nIterator iterator = list.iterator();\nwhile (iterator.hasNext()) {\n String s = iterator.next();\n if (s.equals(\"a\")) {\n iterator.remove();\n }\n}\n" + }, + { + "id": 44, + "text": "Class, class attributes, and class methods must use Javadoc specifications for comments, using the format /** content */, and must not use the // xxx format.", + "detail": "Defect type: Comments do not conform to Javadoc standards; Solution: Use Javadoc-compliant comment format.", + "language": "Java", + "yes_example": "Examples of being judged as 'class, class attribute, class method annotations must use Javadoc specification, using the format /** content */, not using the // xxx method'", + "no_example": "Examples that cannot be judged as 'Class, class attribute, and class method comments must follow the Javadoc specification, using the /** content */ format, not the // xxx format'" + }, + { + "id": 45, + "text": "All abstract methods (including methods in interfaces) must be annotated with Javadoc comments", + "detail": "Defect type: All abstract methods (including methods in interfaces) must be annotated with Javadoc; Repair solution: Add Javadoc comments to all abstract methods (including methods in interfaces), in addition to the return value, parameter exception description, it must also indicate what the method does and what function it implements.", + "language": "Java", + "yes_example": "Example of being judged as 'All abstract methods (including methods in interfaces) must be annotated with Javadoc'", + "no_example": "Example that cannot be judged as 'all abstract methods (including methods in interfaces) must be annotated with Javadoc comments'" + }, + { + "id": 46, + "text": "Usage guidelines for single-line and multi-line comments within methods", + "detail": "Defect type: Improper use of comments; Repair solution: Single-line comments inside the method, start a new line above the commented statement, use // for comments. 
Multi-line comments inside the method use /* */ comments, and pay attention to aligning with the code.", + "language": "Java", + "yes_example": "### Examples of being judged as 'Improper Use of Comments'\n\npublic void exampleMethod() {\n int a = 1; // Initialize variable a\n int b = 2; /* Initialize variable b */\n}\nThe single-line and multi-line comments in this code are not used according to the standard, so they are judged as improper use of comments.\n", + "no_example": "Examples that cannot be judged as 'improper use of comments'\n\npublic void exampleMethod() {\n // Initialize variable a\n int a = 1;\n /*\n * Initialize variable b\n */\n int b = 2;\n}\nThis code uses single-line and multi-line comments according to the standard, so it cannot be judged as improper use of comments.\n" + }, + { + "id": 47, + "text": "All enumeration type fields must have comments", + "detail": "Defect type: Enumeration type field lacks comments; Fix plan: Add comments to all enumeration type fields to explain the purpose of each data item.", + "language": "Java", + "yes_example": "Example of being judged as 'Enumeration type field lacks comments'\n\npublic enum Status {\n ACTIVE,\n INACTIVE\n}\nThe enumeration type fields in this code are not commented, so they are judged as lacking comments for enumeration type fields.\n", + "no_example": "Examples that cannot be judged as 'missing comments for enum fields'\n\npublic enum Status {\n /**\n * Active status\n */\n ACTIVE,\n /**\n * Inactive status\n */\n INACTIVE\n}\nThis code has comments for the enum fields, so it cannot be judged as missing comments for enum fields.\n" + }, + { + "id": 48, + "text": "The finally block must close resource objects and stream objects.", + "detail": "Defect type: resource objects and stream objects are not closed in the finally block; Fix solution: Close resource objects and stream objects in the finally block, and use try-catch for exceptions.", + "language": "Java", + "yes_example": "Example of being judged as 'resource object, stream object not closed in finally block'", + "no_example": "Examples that cannot be judged as 'resource objects, stream objects not closed in the finally block'" + }, + { + "id": 49, + "text": "Constant names should be in all uppercase, with words separated by underscores.", + "detail": "Defect type: Constant naming is not standardized; Fix solution: Constant names should be all uppercase, words separated by underscores, and strive for complete and clear semantic expression, do not be afraid of long names.", + "language": "Java", + "yes_example": "Examples of being judged as 'Constant names should be in all uppercase, with words separated by underscores'", + "no_example": "Examples that cannot be judged as 'constant names should be all uppercase, with words separated by underscores'" + }, + { + "id": 50, + "text": "Spaces are required on both sides of any binary or ternary operator.", + "detail": "Defect type: Lack of space around operators; Fix solution: Any binary or ternary operator should have a space on both sides.", + "language": "Java", + "yes_example": "Examples of being judged as 'Any binary or ternary operator must have spaces on both sides'", + "no_example": "Examples that cannot be judged as 'any binary, ternary operator needs a space on both sides'" + }, + { + "id": 51, + "text": "Avoid using from import *", + "detail": "Defect type: Avoid using 'from import *', importing everything can cause naming conflicts; Solution: Each sub-dependency used should be imported separately.", + 
"language": "Python", + "yes_example": "Example of being judged as 'avoid using from import *'", + "no_example": "Examples that cannot be judged as 'avoid using from import *'" + }, + { + "id": 52, + "text": "Avoid using the __import__() function to dynamically import modules", + "detail": "Defect type: Avoid using __import__() function to dynamically import modules; Repair solution: Use standard import statements.", + "language": "Python", + "yes_example": "Example of being judged as 'dynamically importing modules using the __import__() function'", + "no_example": "Examples that cannot be judged as 'dynamically importing modules using the __import__() function'" + }, + { + "id": 53, + "text": "Import statements are not grouped in the order of standard library imports, related third-party imports, and local application/library specific imports.", + "detail": "Defect type: Import statements are not grouped in the order of standard library imports, related third-party imports, and local application/library specific imports; Solution: Group import statements in order.", + "language": "Python", + "yes_example": "Examples of being judged as 'import statements not grouped in the order of standard library imports, related third-party imports, and local application/library specific imports'", + "no_example": "Example that cannot be judged as 'import statements not grouped in the order of standard library imports, related third-party imports, local application/library specific imports'" + }, + { + "id": 54, + "text": "Avoid unused function parameters", + "detail": "Defect type: Avoid unused function parameters; Fix solution: Remove unused function parameters.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid unused function parameters'", + "no_example": "Examples that cannot be judged as 'avoiding unused function parameters'" + }, + { + "id": 55, + "text": "Use is not None to check if a variable is not None", + "detail": "Defect type: Not using 'is not None' to check if a variable is not None; Fix solution: Use 'is not None' to check.", + "language": "Python", + "yes_example": "Example of being judged as 'not using is not None to check if a variable is not None'", + "no_example": "Examples that cannot be judged as 'not using is not None to check if a variable is not None'" + }, + { + "id": 56, + "text": "Avoid using == or != to compare the equivalence of object instances", + "detail": "Defect type: Using == or != to compare object instances for equivalence; Fix solution: Should use equals for comparison.", + "language": "Python", + "yes_example": "Example of being judged as 'using == or != to compare the equivalence of object instances'", + "no_example": "Examples that cannot be judged as 'using == or != to compare the equivalence of object instances'" + }, + { + "id": 57, + "text": "Avoid using single-letter variable names, use descriptive variable names", + "detail": "Defect type: Avoid using single-letter variable names, use descriptive variable names; Fix solution: Use descriptive variable names.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid using single-letter variable names, use descriptive variable names'", + "no_example": "Examples that cannot be judged as 'avoid using single-letter variable names, use descriptive variable names'" + }, + { + "id": 58, + "text": "Constant names use all uppercase letters and separate words with underscores", + "detail": "Defect type: Constant naming does not use all uppercase letters or does not 
use underscores to separate; Repair solution: Use all uppercase letters for constant naming and separate with underscores.", + "language": "Python", + "yes_example": "Example of being judged as 'Constant naming not using all uppercase letters and separated by underscores'", + "no_example": "Examples that cannot be judged as 'constant naming not using all uppercase letters and separated by underscores'" + }, + { + "id": 59, + "text": "Class names should use camel case (CamelCase)", + "detail": "Defect type: Class name not using camel case; Repair solution: Use camel case for class names.", + "language": "Python", + "yes_example": "Examples of being judged as 'class name not using CamelCase'", + "no_example": "Examples that cannot be judged as 'class name not using CamelCase'" + }, + { + "id": 60, + "text": "Try to use the with statement to manage resources as much as possible", + "detail": "Defect type: Not using the with statement to manage resources; Fix solution: Use the with statement to manage resources.", + "language": "Python", + "yes_example": "Example of being judged as 'not using the with statement to manage resources'", + "no_example": "Examples that cannot be judged as 'not using the with statement to manage resources'" + }, + { + "id": 61, + "text": "Avoid using except or generic Exception to catch all exceptions, specify the exception type instead.", + "detail": "Defect type: catch all exceptions; Fix solution: specify specific exception types.", + "language": "Python", + "yes_example": "Examples judged as 'catching all exceptions using except:' and 'throwing a generic Exception exception'", + "no_example": "Example that cannot be judged as 'using except: to catch all exceptions'" + }, + { + "id": 62, + "text": "Avoid manual string concatenation whenever possible", + "detail": "Defect type: manual string concatenation; Fix solution: use formatted strings or join method.", + "language": "Python", + "yes_example": "Examples of being judged as 'manual string concatenation'", + "no_example": "Examples that cannot be judged as 'manual string concatenation'" + }, + { + "id": 63, + "text": "Avoid using magic characters and numbers, should be declared as constants", + "detail": "Defect type: Using magic characters and numbers; Fix solution: Declare them as constants.", + "language": "Python", + "yes_example": "Examples of being judged as 'having magic characters and numbers'", + "no_example": "Examples that cannot be judged as 'containing magic characters and numbers'" + }, + { + "id": 64, + "text": "Boolean variable judgment does not require explicit comparison", + "detail": "Defect type: explicit comparison of boolean variables; fix solution: directly use boolean variables for judgment.", + "language": "Python", + "yes_example": "Examples of being judged as 'explicit comparison of boolean variables'", + "no_example": "Examples that cannot be judged as 'explicit comparison of boolean variables'" + }, + { + "id": 65, + "text": "Avoid using type() to check object types", + "detail": "Defect type: Avoid using type() to check object type; Fix solution: Use isinstance() function.", + "language": "Python", + "yes_example": "Example of being judged as 'avoid using type() to check object type'", + "no_example": "Examples that cannot be judged as 'avoid using type() to check object type'" + }, + { + "id": 66, + "text": "Avoid using os.system() to call external commands", + "detail": "Defect type: Using os.system() to call external commands; Fix solution: Use the subprocess module.", + 
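For rules 58-65, whose example fields are placeholders, a short Python sketch bundling the compliant forms (BoardGame and the moves file are hypothetical):

MAX_SIZE = 100  # rules 58/63: an upper-case constant replaces a magic number


class BoardGame:  # rule 59: CamelCase class name
    def __init__(self):
        self.is_over = False

    def load_moves(self, path):
        # Rule 60: manage the file handle with a "with" statement.
        with open(path, "r") as handle:
            lines = handle.readlines()

        # Rule 61: catch specific exception types, never a bare "except:".
        try:
            first_move = int(lines[0])
        except (ValueError, IndexError):
            first_move = 0

        # Rule 62: f-string instead of manual "+" concatenation.
        summary = f"loaded {len(lines)} moves, first={first_move}"

        # Rule 64: test booleans directly, not "if self.is_over == True:".
        if not self.is_over and len(lines) <= MAX_SIZE:
            return summary
        return ""


def is_move_list(obj):
    return isinstance(obj, list)  # rule 65: isinstance(), not type(obj) == list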
"language": "Python", + "yes_example": "Examples of being judged as 'using os.system() to call external commands'\nos.system('ls -l')\nos.system('ls -l')", + "no_example": "Examples that cannot be judged as 'using os.system() to call external commands'" + }, + { + "id": 67, + "text": "Create read-only properties using the @property decorator instead of modifying properties", + "detail": "Defect type: Creating modifiable properties using the @property decorator; Fix solution: Only use the @property decorator to create read-only properties.", + "language": "Python", + "yes_example": "Examples of being judged as 'using the @property decorator to create modifiable attributes'", + "no_example": "Examples that cannot be judged as 'using the @property decorator to create a modifiable attribute'" + }, + { + "id": 68, + "text": "When using indexing or slicing, do not add spaces inside the brackets or colons.", + "detail": "Defect type: adding spaces inside brackets or colons for indexing or slicing; Repair solution: remove spaces inside brackets or colons.", + "language": "Python", + "yes_example": "Examples judged as 'using spaces inside brackets or colons when using indexing or slicing'", + "no_example": "Examples that cannot be judged as 'adding spaces inside brackets or colons when using indexes or slices'" + }, + { + "id": 69, + "text": "Do not add a space before a comma, semicolon, or colon, but add a space after them", + "detail": "Defect type: adding a space before a comma, semicolon, or colon, or not adding a space after them; Fix solution: do not add a space before a comma, semicolon, or colon, but add a space after them.", + "language": "Python", + "yes_example": "Examples judged as 'adding a space before a comma, semicolon, or colon, or not adding a space after them'", + "no_example": "Examples that cannot be judged as 'adding a space before a comma, semicolon, or colon, or not adding a space after them'" + }, + { + "id": 70, + "text": "For binary operators, there should be spaces on both sides", + "detail": "Defect type: no spaces around binary operators; Fix solution: add spaces around binary operators", + "language": "Python", + "yes_example": "Example of being judged as 'no space around binary operator'", + "no_example": "Examples that cannot be judged as 'no space on both sides of the binary operator'" + }, + { + "id": 71, + "text": "Avoid using Python keywords as variable or function names", + "detail": "Defect type: Using Python keywords as variable names or function names; Repair solution: Use non-keyword names.", + "language": "Python", + "yes_example": "Examples of being judged as 'using Python keywords as variable names or function names'", + "no_example": "Examples that cannot be judged as 'using Python keywords as variable names or function names'\ndef my_function():\n pass\nnumber = 5" + }, + { + "id": 72, + "text": "Avoid using special characters as variable names/method names/class names, such as $ or @", + "detail": "Defect type: Using special characters as variable names/method names/class names; Repair solution: Use legal variable names.", + "language": "Python", + "yes_example": "Examples of being judged as 'using special characters as variable names/method names/class names, such as $ or @'", + "no_example": "Examples that cannot be judged as 'using special characters as variable names/method names/class names, such as $ or @'" + }, + { + "id": 73, + "text": "Avoid using raise to rethrow the current exception, as it will lose the original stack trace.", + "detail": 
"Defect type: Re-raise the current exception using raise; Fix solution: Use the raise ... from ... syntax.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid using raise to rethrow the current exception, as it will lose the original stack trace'", + "no_example": "Examples that cannot be judged as 'avoid using raise to rethrow the current exception, as it will lose the original stack trace'" + }, + { + "id": 74, + "text": "Avoid using pass in except block, as it will catch and ignore the exception", + "detail": "Defect type: using pass in except block; Fix solution: handle the exception or log the error.", + "language": "Python", + "yes_example": "Examples of being judged as 'using pass in except block'", + "no_example": "Examples that cannot be judged as 'using pass in an except block'" + }, + { + "id": 75, + "text": "Avoid using assert statements to perform important runtime checks", + "detail": "Defect type: Using assert statements for important runtime checks; Fix solution: Use explicit condition checks and exception handling.", + "language": "Python", + "yes_example": "Example of being judged as 'using assert statements to perform important runtime checks'", + "no_example": "Examples that cannot be judged as 'using assert statements to perform important runtime checks'" + }, + { + "id": 76, + "text": "Avoid using eval() and exec(), these functions may bring security risks", + "detail": "Defect type: Use of eval() and exec() functions; Repair solution: Use secure alternatives.", + "language": "Python", + "yes_example": "Examples of being judged as 'using eval() and exec()'\n\n eval('print(1)') \n\n \n exec('a = 1') \n", + "no_example": "Examples that cannot be judged as 'using eval() and exec()'\n\ncompiled_code = compile('print(1)', '', 'exec')\nexec(compiled_code)\n" + }, + { + "id": 77, + "text": "Avoid using sys.exit(), use exceptions to control program exit instead.", + "detail": "Defect type: Avoid using sys.exit(), should use exceptions to control program exit; Repair solution: Use exceptions to control program exit.", + "language": "Python", + "yes_example": "Examples of being judged as 'avoid using sys.exit(), should use exceptions to control program exit'", + "no_example": "Examples that cannot be judged as 'avoid using sys.exit(), should use exceptions to control program exit'" + }, + { + "id": 78, + "text": "Avoid using time.sleep() for thread synchronization, and instead use synchronization primitives such as locks or events.", + "detail": "Defect type: Using time.sleep() for thread synchronization; Fix solution: Use synchronization primitives.", + "language": "Python", + "yes_example": "Examples of being judged as 'using time.sleep() for thread synchronization'", + "no_example": "Examples that cannot be judged as 'using time.sleep() for thread synchronization'" + }, + { + "id": 79, + "text": "Avoid exceeding 79 characters per line of code", + "detail": "Defect type: Avoid exceeding 79 characters per line of code; Fix solution: Format long lines of code into multiple lines.", + "language": "Python", + "yes_example": "Example of being judged as 'avoiding more than 79 characters per line of code'", + "no_example": "Examples that cannot be judged as 'each line of code should not exceed 79 characters'" + }, + { + "id": 80, + "text": "Functions and class definitions at the module level are separated by two blank lines, and method definitions within a class are separated by one blank line", + "detail": "Defect type: There is no separation of two blank 
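Rules 73-79 again ship without concrete examples; one Python sketch covering them (parse_amount is hypothetical, and the Event stands in for the lock/event primitives rule 78 prefers over time.sleep()):

import logging
import threading

done = threading.Event()  # rule 78: synchronize on an Event, not sleep() polling


def parse_amount(text):
    # Rule 75: an explicit check and exception, not "assert text".
    if not text:
        # Rule 77: raise an exception instead of calling sys.exit().
        raise ValueError("empty amount")
    try:
        return int(text)  # rule 76: int() conversion, never eval(text) or exec()
    except ValueError as exc:
        # Rule 74: never "except: pass"; log and handle the failure.
        logging.error("bad amount %r", text)
        # Rule 73: chain with "raise ... from ..." to keep the original traceback.
        raise ValueError(f"cannot parse amount: {text!r}") from exc


# Rule 79: keep every line at 79 characters or fewer, wrapping long calls.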
lines between function and class definitions at the module level, and no separation of one blank line between method definitions within the class; Solution: Add blank lines according to the specification.", + "language": "Python", + "yes_example": "Example of being judged as 'Functions at the module level are not separated by two blank lines, and method definitions within a class are not separated by one blank line'", + "no_example": "Examples that cannot be judged as 'There is no two blank lines between module-level function and class definitions, and no one blank line between method definitions inside a class'" + }, + { + "id": 81, + "text": "Use lowercase letters and underscores to separate variable and function names", + "detail": "Defect type: Variable and function naming do not conform to the lowercase letters and underscore separation method; Repair solution: Use lowercase letters and underscore separation method for naming.", + "language": "Python", + "yes_example": "Examples of being judged as 'not using lowercase letters and underscores to separate variable and function names'", + "no_example": "Examples that cannot be judged as 'naming variables and functions without using lowercase letters and underscores to separate them'" + }, + { + "id": 82, + "text": "It is not allowed to use the print() function to record logs, use the logging module, etc. to record logs", + "detail": "Defect type: Using the print() function to log; Fix solution: Use the logging module to log.", + "language": "Python", + "yes_example": "Examples of being judged as 'using the print() function to log'", + "no_example": "Examples that cannot be judged as 'using the print() function to log'" + } +] \ No newline at end of file diff --git a/metagpt/ext/cr/points_cn.json b/metagpt/ext/cr/points_cn.json new file mode 100644 index 0000000000000000000000000000000000000000..10fc951c07ca297f84de1e9f81ff5b83bff73e9b --- /dev/null +++ b/metagpt/ext/cr/points_cn.json @@ -0,0 +1,656 @@ +[ + { + "id": 1, + "text": "避免未使用的临时变量", + "language": "Java", + "detail": "缺陷类型:避免未使用的临时变量;对应Fixer:UnusedLocalVariableFixer;修复方案:删除未使用的临时变量", + "yes_example": "### 被判定为\"避免未使用的临时变量\"的例子\n<例子1>\npublic String initCreationForm(Map model) {\n\t\tOwner owner = new Owner();\n\t\tmodel.put(\"owner\", owner);\n\t\tint unusedVar = 10;\n\t\treturn VIEWS_OWNER_CREATE_OR_UPDATE_FORM;\n\t}\n上述代码中unusedVar变量未被使用,所以这个被判定为\"避免未使用的临时变量\"\n\n<例子2>\nint unusedVariable = 10;\nSystem.out.println(\"Hello, World!\");\n这段代码的变量\"unusedVariable\"未被使用或者引用,所以这个被判定为\"避免未使用的临时变量\"\n", + "no_example": "### 不能被判定为\"避免未使用的临时变量\"的例子\n<例子1>\npublic void setTransientVariablesLocal(Map transientVariables) {\nthrow new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}\n这段代码的\"transientVariables\"是函数参数而不是临时变量,虽然transientVariables没有被使用或者引用,但是这个也不能判定为\"避免未使用的临时变量\"\n\n\n<例子2>\npublic class TriggerCmd extends NeedsActiveExecutionCmd {\n protected Map transientVariables;\n public TriggerCmd(Map transientVariables) {\n this.transientVariables = transientVariables;\n }\n}\n上述代码中transientVariables不属于临时变量,它是类属性,且它在构造函数中被使用,所以这个不能被判定为\"避免未使用的临时变量\"\n" + }, + { + "id": 2, + "text": "不要使用 System.out.println 去打印", + "language": "Java", + "detail": "缺陷类型:不要使用 System.out.println 去打印;对应Fixer:SystemPrintlnFixer;修复方案:注释System.out.println代码", + "yes_example": "### 被判定为\"不要使用 System.out.println 去打印\"的例子\n<例子1>\nSystem.out.println(\"Initializing new owner form.\");\n上述代码使用了\"System.out.println\"进行打印,所以这个被判定为\"不要使用 System.out.println 去打印\"\n", + 
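Rules 80-82 close the English rule file with placeholder examples; a last Python sketch of all three (GameSession and load_level are hypothetical):

import logging

logger = logging.getLogger(__name__)


def load_level(level_name):  # rule 81: lower_case_with_underscores names
    logger.info("loading level %s", level_name)  # rule 82: logging, not print()


class GameSession:
    # Two blank lines separate the module-level def and class above (rule 80).

    def start(self):
        logger.info("session started")

    def stop(self):
        # One blank line separates method definitions inside a class (rule 80).
        logger.info("session stopped")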
"no_example": "### 不能被判定为\"不要使用 System.out.println 去打印\"的例子\n<例子1>\nthrow new IllegalStateException(\"There is no authenticated user, we need a user authenticated to find tasks\");\n上述代码是抛出异常的代码,没有使用\"System.out.print\",所以这个不能被判定为\"不要使用 System.out.println 去打印\"\n" + }, + { + "id": 3, + "text": "避免函数中未使用的形参", + "language": "Java", + "detail": "缺陷类型:避免函数中未使用的形参;修复方案:忽略", + "yes_example": "### 被判定为\"避免函数中未使用的形参\"的例子\n<例子1>\npublic void setTransientVariablesLocal(Map transientVariables) {\n throw new UnsupportedOperationException(\"No execution active, no variables can be set\");\n}这段代码中的形参\"transientVariables\"未在函数体内出现,所以这个被判定为\"避免函数中未使用的形参\"\n\n\n<例子2>\nprotected void modifyFetchPersistencePackageRequest(PersistencePackageRequest ppr, Map pathVars) {}\n这段代码中的形参\"ppr\"和\"pathVars\"未在函数体内出现,所以这个被判定为\"避免函数中未使用的形参\"\n", + "no_example": "### 不能被判定为\"避免函数中未使用的形参\"的例子\n<例子1>\npublic String processFindForm(@RequestParam(value = \"pageNo\", defaultValue = \"1\") int pageNo) {\n\tlastName = owner.getLastName();\n\treturn addPaginationModel(pageNo, paginationModel, lastName, ownersResults);\n}这段代码中的形参\"pageNo\"在当前函数'processFindForm'内被'return addPaginationModel(pageNo, paginationModel, lastName, ownersResults);'这一句被使用,虽然pageNo没有被用于逻辑计算,但作为了函数调用其他函数的参数使用了,所以这个不能被判定为\"避免函数中未使用的形参\"\n\n<例子2>\npublic void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}这段代码中的形参date在System.out.println(\"Formatted date: \" + sdf.format(date))这一句中被引用到,所以这个不能被判定为\"避免函数中未使用的形参\"\n" + }, + { + "id": 4, + "text": "if语句块不能为空", + "language": "Java", + "detail": "缺陷类型:if 语句块不能为空;对应Fixer:EmptyIfStmtFixer;修复方案:删除if语句块 或 适当的逻辑处理 或 注释说明为何为空", + "yes_example": "### 被判定为\"if语句块不能为空\"的例子\n<例子1>\npublic void emptyIfStatement() {\n\tif (getSpecialties().isEmpty()) {\n\t}\n}这段代码中的if语句块内容是空的,所以这个被判定为\"if语句块不能为空\"\n\n\n<例子2>\npublic void judgePersion() {\n\tif (persion != null) {\n\t\t// judge persion if not null\n\t}\n}\n这段代码中的if语句块虽然有内容,但是\"// judge persion if not null\"只是代码注释,if语句块内并没有实际的逻辑代码,所以这个被判定为\"if语句块不能为空\"\n", + "no_example": "### 不能被判定为\"if语句块不能为空\"的例子\n<例子1>\npublic void judgePersion() {\n\tif (persion != null) {\n\t\treturn 0;\n\t}\n}这段代码中的if语句块里有内容,且里面有非注释代码的逻辑代码\"return 0;\",所以这个不能被判定为\"if语句块不能为空\"\n" + }, + { + "id": 5, + "text": "循环体不能为空", + "language": "Java", + "detail": "缺陷类型:循环体不能为空;对应Fixer:EmptyStatementNotInLoopFixer;修复方案:删除对应while、for、foreach 循环体 或 添加适当的逻辑处理或者注释说明为何为空", + "yes_example": "### 被判定为\"循环体不能为空\"的例子\n<例子1>\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t}\n}这段代码中的for循环体的内容是空的,所以这个被判定为\"循环体不能为空\"\n\n\n<例子2>\npublic void emptyLoopBody() {\n\twhile (True) {\n\t\t// this is a code example\n\t}\n}这段代码中的while循环体的内容虽然不是空的,但内容只是代码注释,无逻辑内容,所以这个被判定为\"循环体不能为空\"\n\n\n<例子3>\npublic void emptyLoopBody() {\n\twhile (True) {\n\t\t\n\t}\n}这段代码中的while循环体内容是空的,所以这个被判定为\"循环体不能为空\"\n", + "no_example": "### 不能被判定为\"循环体不能为空\"的例子\n<例子1>\npublic void emptyLoopBody() {\n\tfor (Specialty specialty : getSpecialties()) {\n\t\ta = 1;\n\t\tif (a == 1) {\n\t\t\tretrun a;\n\t\t}\n\t}\n}上述代码的for循环体的内容不为空,且内容不全是代码注释,所以这个不能被判定为\"循环体不能为空\"\n" + }, + { + "id": 6, + "text": "避免使用 printStackTrace(),应该使用日志的方式去记录", + "language": "Java", + "detail": "缺陷类型:避免使用 printStackTrace(),应该使 用日志的方式去记录;修复方案:用日志的方式去记录", + "yes_example": "### 被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"的例子\n<例子1>\npublic void usePrintStackTrace() {\n\ttry {\n\t\tthrow new Exception(\"Fake exception\");\n\t} catch (Exception e) 
{\n\t\te.printStackTrace();\n\t}\n}这段代码中的catch语句中使用了printStackTrace(),所以这个被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"\n", + "no_example": "### 不能被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"的例子\n<例子1>\npublic void usePrintStackTrace() {\n\ttry {\n\t\tthrow new Exception(\"Fake exception\");\n\t} catch (Exception e) {\n\t\tlogging.info(\"info\");\n\t}\n}这段代码的catch语句中使用的是日志记录的方式,所以这个不能被判定为\"避免使用 printStackTrace(),应该使用日志的方式去记录\"\n" + }, + { + "id": 7, + "text": "catch 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:catch 语句块不能为空;对应Fixer:EmptyCatchBlockFixer;修复方案:在catch里面添加注释", + "yes_example": "### 被判定为\"catch语句块不能为空\"的例子\n<例子1>\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n \n}\n这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n\n<例子2>\ntry {\n String str = null;\n str.length();\n} catch (NullPointerException e) {\n \n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n\n<例子3>\npublic class EmptyCatchExample {\n public static void main(String[] args) {\n try {\n // 尝试除以零引发异常\n int result = 10 / 0;\n } catch (ArithmeticException e) {\n \n }\n }\n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n<例子4>\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n \n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n\n<例子5>\ntry {\n Object obj = \"string\";\n Integer num = (Integer) obj;\n} catch (ClassCastException e) {\n\t\n}这段代码中的catch语句中没有内容,所以这个被判定为\"catch语句块不能为空\"\n", + "no_example": "### 不能被判定为\"catch语句块不能为空\"的例子\n<例子1>\npersionNum = 1\ntry {\n\treturn True;\n} catch (Exception e) {\n\t// 如果人数为1则返回false\n\tif (persionNum == 1){\n\t\treturn False;\n\t}\n}这段代码的catch语句中不为空,所以这个不能被判定为\"catch语句块不能为空\"\n\n\n<例子2>\ntry {\n\tthrow new Exception(\"Fake exception\");\n} catch (Exception e) {\n\te.printStackTrace();\n}这段代码的catch语句中虽然只有\"e.printStackTrace();\"但确实不为空,所以这个不能被判定为\"catch语句块不能为空\"\n" + }, + { + "id": 8, + "text": "避免不必要的永真/永假判断", + "language": "Java", + "detail": "缺陷类型:避免不必要的永真/永假判断;对应Fixer:UnconditionalIfStatement Fixer;修复方案:删除永真/永假判断逻辑", + "yes_example": "### 被判定为\"避免不必要的永真/永假判断\"的例子\n<例子1>\npublic void someMethod() {\n\twhile (true) {\n\t}\n}这段代码中的\"while (true)\"是一个使用true做判断条件,但是没有循环结束标记,所以这个被判定为\"避免不必要的永真/永假判断\"\n\n\n<例子2>\nif (true) {\n\tSystem.out.println(\"This is always true\");\n}这段代码中的\"if (true)\"是一个使用true做条件的永真判断,所以这个被判定为\"避免不必要的永真/永假判断\"\n\n\n<例子3>\na = 1;\nwhile(a > 0){\n\ta = a + 1\n}这段代码初始化a=1,是大于0的,while循环体的逻辑是每次加1,那么判断条件a > 0会永远是真的,不会退出循环,所以这个被判定为\"避免不必要的永真/永假判断\"\n", + "no_example": "### 不能被判定为\"避免不必要的永真/永假判断\"的例子\n<例子1>\na = 0;\nwhile (a < 5) {\n\ta = a + 1;\n}这段代码中的a<5是一个判断,当执行了5次while语句中的逻辑a=a+1之后,a不再满足a < 5,就会退出循环,所以这个不能被判定为\"避免不必要的永真/永假判断\"\n" + }, + { + "id": 9, + "text": "switch 中 default 必须放在最后", + "language": "Java", + "detail": "缺陷类型:switch 中 default 必须放在最后;对应Fixer:DefaultLabelNotLastInSwitchStmtFixer;修复方案:switch 中 default 放在最后", + "yes_example": "### 被判定为\"switch 中 default 必须放在最后\"的例子\n<例子1>\nswitch (number) {\n\tdefault:\n\t\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\t\tbreak;\n\tcase 1:\n\t\tSystem.out.println(\"Number one\");\n\t\tbreak;\n\tcase 2:\n\t\tSystem.out.println(\"Number two\");\n\t\tbreak;\n}这段代码是一个switch语句,但是里面的default没有放在最后,所以这个被判定为\"switch 中 default 必须放在最后\"\n", + "no_example": "### 不能被判定为\"switch 中 default 必须放在最后\"的例子\n<例子1>\nswitch (number) {\ncase 3:\n\tSystem.out.println(\"Number one\");\n\tbreak;\ncase 4:\n\tSystem.out.println(\"Number 
two\");\n\tbreak;\ndefault:\n\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\tbreak;\n}这段代码是一个switch语句且里面的default放在了最后,所以这个不能被判定为\"switch 中 default 必须放在最后\"\n" + }, + { + "id": 10, + "text": "未使用equals()函数对 String 作比较", + "language": "Java", + "detail": "缺陷类型:未使用equals()函数对 String 作比较;对应Fixer:UnSynStaticDateFormatter Fixer;修复方案:使用equals()函数对 String 作比较", + "yes_example": "### 被判定为\"未使用equals()函数对 String 作比较\"的例子\n<例子1>\nif (existingPet != null && existingPet.getName() == petName) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}这段代码中所涉及的existingPet.getName()和petName均是字符串,但是在if语句里做比较的时候使用了==而没有使用equals()对string做比较,所以这个被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子2>\nString isOk = \"ok\";\nif (\"ok\" == isOk) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}这段代码中的isOk是个字符串,但在if判断中与\"ok\"比较的时候使用的是==,未使用equals()对string做比较,应该使用\"ok\".equals(isOk),所以这个被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子3>\nString str1 = \"Hello\";\nString str2 = \"Hello\";\nif (str1 == str2) {\n\tSystem.out.println(\"str1 和 str2 引用相同\");\n} else {\n\tSystem.out.println(\"str1 和 str2 引用不同\");\n}\n这段代码中的if (str1 == str2) 使用了==进行str1和str2的比较,未使用equals()对string做比较,应该使用str1.equals(str2),所以这个被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子4>\nString str = \"This is string\";\nif (str == \"This is not str\") {\n\treturn str;\n}这段代码中的if (str == \"This is not str\")使用了==进行字符串比较,未使用equals()对string做比较,\"This is not str\".equals(str),所以这个被判定为\"未使用equals()函数对 String 作比较\"\n", + "no_example": "### 不能被判定为\"未使用equals()函数对 String 作比较\"的例子\n<例子1>\nif (PROPERTY_VALUE_YES.equalsIgnoreCase(readWriteReqNode))\n formProperty.setRequired(true);\n这段代码中的PROPERTY_VALUE_YES和readWriteReqNode均是字符串,在if语句里比较PROPERTY_VALUE_YES和readWriteReqNode的使用的是equalsIgnoreCase(字符串比较忽略大小写),所以equalsIgnoreCase也是符合使用equals()函数对 String 作比较的,所以这个不能被判定为\"未使用equals()函数对 String 作比较\"\n\n\n<例子2>\nString isOk = \"ok\";\nif (\"ok\".equals(isOk)) {\n\tresult.rejectValue(\"name\", \"duplicate\", \"already exists\");\n}这段代码中的isOk是个字符串,在if判断中与\"ok\"比较的时候使用的是equals()对string做比较,所以这个不能被判定为\"未使用equals()函数对 String 作比较\"\n" + }, + { + "id": 11, + "text": "禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象", + "language": "Java", + "detail": "缺陷类型:禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象 输出异常;对应Fixer:ConcatExceptionFixer;修复方案:使用占位符传递异常对象", + "yes_example": "### 被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"的例子\n<例子1>\ntry {\n listenersNode = objectMapper.readTree(listenersNode.asText());\n} catch (Exception e) {\n LOGGER.info(\"Listeners node can not be read\", e);\n}这段代码中日志输出内容内容是直接使用字符串\"Listeners node can not be read\"拼接,日志输出异常时,应使用占位符输出异常信息,而不是直接使用字符串拼接,所以这个被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"\n", + "no_example": "### 不能被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"的例子\n<例子1>\nPersion persion = persionService.getPersion(1);\nif (persion == null){\n\tLOGGER.error(PERSION_NOT_EXIT);\n}这段代码中的PERSION_NOT_EXIT是一个用户自定义的异常常量,代表persion不存在,没有直接使用字符串\"persion not exit\"拼接,所以这个不能被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"\n<例子1>\n\n<例子2>\ntry {\n a = a + 1;\n} catch (Exception e) {\n Persion persion = persionService.getPersion(1);\n LOGGER.info(persion);\n}这段代码中输出日志没有直接使用字符串拼接,而是使用的Persion对象输出,所以这个不能被判定为\"禁止在日志中直接使用字符串输出异常,请使用占位符传递异常对象\"\n" + }, + { + "id": 12, + "text": "finally 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:finally 语句块不能为空;对应Fixer:EmptyFinallyBlockFixer;修复方案:删除空 finally 语句块", + "yes_example": "### 被判定为\"finally 语句块不能为空\"的例子\n<例子1>\ntry {\n\tPersion persion = persionService.getPersion(1);\n\treturn persion;\n} finally 
{\n\t\n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子2>\ntry {\n\tSystem.out.println(\"Inside try block\");\n} finally {\n\t// 空的finally块,没有任何语句,这是一个缺陷\n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子3>\ntry {\n int result = 10 / 0;\n} catch (ArithmeticException e) {\n e.printStackTrace();\n} finally {\n \n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子4>\ntry {\n String str = null;\n System.out.println(str.length());\n} catch (NullPointerException e) {\n e.printStackTrace();\n} finally {\n \n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子5>\ntry {\n int[] array = new int[5];\n int number = array[10];\n} catch (ArrayIndexOutOfBoundsException e) {\n e.printStackTrace();\n} finally {\n // 只有注释的 finally 语句块\n // 这是一个空的 finally 块\n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n\n\n<例子6>\ntry {\n FileReader file = new FileReader(\"nonexistentfile.txt\");\n} catch (FileNotFoundException e) {\n e.printStackTrace();\n} finally {\n // 只有空行的 finally 语句块\n \n}这段代码中的finally语句块内没有内容,所以这个被判定为\"finally 语句块不能为空\"\n", + "no_example": "### 不能被判定为\"finally 语句块不能为空\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\treturn null;\n\t}\n}这段代码中的finally语句块中有非注释意外的内容\"return null;\",所以这个不能被判定为\"finally 语句块不能为空\"\n" + }, + { + "id": 13, + "text": "try 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:try 语句块不能为空;对应Fixer:EmptyTryBlockFixer;修复方案:删除整个 try 语句", + "yes_example": "### 被判定为\"try 语句块不能为空\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\n\t}\n\treturn null;\n}这段代码中的try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n\n\n<例子2>\npublic void demoFinallyBlock() {\n\ttry {\n\n\t} finally {\n\t\treturn null;\n\t}\n}这段代码中的try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n\n\n<例子3>\ntry {\n \n} catch (Exception e) {\n e.printStackTrace();\n}这段代码中的try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n\n\n<例子4>\ntry {\n // 只有注释的 try 语句块\n\t\n} catch (Exception e) {\n e.printStackTrace();\n}这段代码中的try语句块内只有注释和空行,也可以认定为这种情况是try语句块内没有内容,所以这个被判定为\"try 语句块不能为空\"\n", + "no_example": "### 不能被判定为\"try 语句块不能为空\"的例子\n<例子1>\ntry {\n\ta = a + 1;\n} catch (Exception e) {\n\te.printStackTrace();\n}\n这段代码中的try语句块中有非注释意外的内容\"return null;\",所以这个不能被判定为\"try 语句块不能为空\"\n" + }, + { + "id": 14, + "text": "避免对象进行不必要的 NULL或者null 检查", + "language": "Java", + "detail": "缺陷类型:避免对象进行不必要的 NULL或者null 检查;对应Fixer:LogicalOpNpeFixer;修复方案:删除对对象不必要的 NULL 检查的逻辑", + "yes_example": "### 被判定为\"避免对象进行不必要的 NULL或者null 检查\"的例子\n<例子1>\na = \"dog\";\nif (a != null){\n\treturn a;\n}这段代码中的对象a已经是确定的值\"dog\",所以if条件句的判断\"a != null\"是不必要的,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子2>\nif (authenticatedUserId != null && !authenticatedUserId.isEmpty() && userGroupManager!=null){\n\treturn authenticatedUserId;\n}这段代码中的\"authenticatedUserId != null\"和\"!authenticatedUserId.isEmpty()\"都是对\"authenticatedUserId\"的空判断,重复了,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子3>\nList list = new ArrayList<>();\nif (list != null) {\n list.add(1);\n}这段代码中的list已经被初始化,不需要进行 null 检查,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子4>\nif (this.type != null && this.type.getName() != null) {\n\tSystem.out.println(\"Type name is not null\");\n}这段代码中的对象type已经检查过非null,再次检查getName()是否为null是不必要的,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n\n<例子5>\nif (\"dog\".equals(null)){\n\treturn a;\n}这段代码中的\"dog\"是个确定的字符串,不需要进行null 检查,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n\n\n<例子6>\nInteger num = 10;\nif (num != null) {\n System.out.println(num);\n}这段代码中的num 已经被初始化,不需要进行 null 
检查,所以这个被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n", + "no_example": "### 不能被判定为\"避免对象进行不必要的 NULL或者null 检查\"的例子\n<例子1>\nCat cat = catService.get(1);\nif (cat != null){\n\tretrun cat;\n}这段代码中的对象\"cat\"是通过service获取到的,不确定是否为空,所以if条件句的判断的\"cat != null\"是必要的,所以这个不能被判定为\"避免对象进行不必要的 NULL或者null 检查\"\n" + }, + { + "id": 15, + "text": "避免 finally 块中出现 return", + "language": "Java", + "detail": "缺陷类型:避免 finally 块中出现 return;修复方案:无需修复", + "yes_example": "### 被判定为\"避免 finally 块中出现 return\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\treturn null;\n\t}\n}这段代码中的finally语句块内容包含\"return\",所以这个被判定为\"避免 finally 块中出现 return\"\n", + "no_example": "### 不能被判定为\"避免 finally 块中出现 return\"的例子\n<例子1>\npublic void getPersion() {\n\ttry {\n\t\tPersion persion = persionService.getPersion(1);\n\t\tif (persion != null){ \n\t\t\treturn persion;\n\t\t}\n\t} finally {\n\t\tLOGGER.info(PERSION_NOT_EXIT);\n\t}\n}这段代码中的finally语句块中内容不包含\"return\",所以这个不能被判定为\"避免 finally 块中出现 return\"\n" + }, + { + "id": 16, + "text": "避免空的 static 初始化", + "language": "Java", + "detail": "缺陷类型:避免空的 static 初始化;对应Fixer:EmptyInitializerFixer;修复方案:删除整个空初始化块", + "yes_example": "### 被判定为\"避免空的 static 初始化\"的例子\n<例子1>\npublic class PetValidator implements Validator {\n\tstatic {\n\n\t}\n}这段代码中的static语句块没有内容,是空的,所以这个被判定为\"避免空的 static 初始化\"\n\n\n<例子2>\npublic class Persion {\n\tstatic {\n\t\t// 初始化的静态块\n\t}\n}这段代码中的static语句块是有内容的,不是空的,但是static初始化语句块中只有注释代码,没有实际的逻辑,所以这个被判定为\"避免空的 static 初始化\"\n", + "no_example": "### 不能被判定为\"避免空的 static 初始化\"的例子\n<例子1>\npublic class Cat {\n\tstatic {\n\t\t// 初始化的静态块\n\t\tcat = null;\n\t}\n}这段代码中的static语句块是有内容的,不是空的,且static初始化语句块中有非注释代码,有实际的逻辑,所以这个不能被判定为\"避免空的 static 初始化\"\n" + }, + { + "id": 17, + "text": "避免日历类用法不当风险", + "language": "Java", + "detail": "缺陷类型:避免日历类用法不当风险;修复方案:使用Java 8 及以上版本中的 java.time 包的LocalDate", + "yes_example": "### 被判定为\"避免日历类用法不当风险\"的例子\n<例子1>\nprivate static final Calendar calendar = new GregorianCalendar(2020, Calendar.JANUARY, 1);\n这段代码中的Calendar和GregorianCalendar是线程不安全的,所以这个被判定为\"避免日历类用法不当风险\"\n", + "no_example": "### 不能被判定为\"避免日历类用法不当风险\"的例子\n<例子1>\nprivate static final LocalDate calendar = LocalDate.of(2020, 1, 1);\n这段代码中的LocalDate使用的是Java 8 及以上版本中的 java.time 包,LocalDate 是不可变的并且是线程安全的,不会有线程安全和性能方面的问题,所以这个不能被判定为\"避免日历类用法不当风险\"\n" + }, + { + "id": 18, + "text": "使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()", + "language": "Java", + "detail": "缺陷类型:使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size();对应Fixer:ClassCastExpWithToArrayF ixer;修复方案:使用集合的toArray(T[]array),且传入的是类型完全一样的数组", + "yes_example": "### 被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"的例子\n<例子1>\nList stringList = new ArrayList<>();\nstringList.add(\"Apple\");\nstringList.add(\"Banana\");\nObject[] objectArray = stringList.toArray(new Object[5]);\n这段代码使用集合转数组的方法的时候使用了toArray(new Object[5]),但是传入的数组类型不一致,所以这个被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"\n", + "no_example": "### 不能被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"的例子\n<例子1>\nList stringList = new ArrayList<>();\nstringList.add(\"Apple\");\nstringList.add(\"Banana\");\nString[] stringArray = stringList.toArray(new String[stringList.size()]);\n这段代码使用集合转数组的方法的时候使用了toArray(new String[stringList.size()]),传入的是类型完全一样的数组,所以这个不能被判定为\"使用集合转数组的方法,必须使用集合的toArray(T[]array),传入的是类型完全一样的数组,大小就是list.size()\"\n" + }, + { + "id": 19, + "text": "禁止在 equals()中使用 
NULL或者null 做比较", + "language": "Java", + "detail": "缺陷类型:禁止在 equals()中使用 NULL或者null 做比较;对应Fixer:EqualsNullFixer;修复方案:使用Object的判空函数 做比较", + "yes_example": "### 被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"的例子\n<例子1>\nif (\"test\".equals(null)) {\n\tSystem.out.println(\"test\");\n}这段代码中if条件中的代码\"test\".equals(null)使用equals()函数与null进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子2>\nif (!rangeValues[1].equals(\"null\")) {\n\tmaxValue = new BigDecimal(rangeValues[1]);\n}这段代码中if条件中的代码!rangeValues[1].equals(\"null\")使用equals()函数与Nnull进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子3>\nString str1 = \"example\";\nif (str1.equals(\"null\")) {\n System.out.println(\"str1 is null\");\n}这段代码中if条件中的代码str1.equals(null)使用equals()函数与null进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子4>\nString str3 = \"example\";\nif (str3 != null && str3.equals(\"null\")) {\n System.out.println(\"str3 is null\");\n}这段代码中if条件中的代码str3.equals(\"null\")使用equals()函数与\"null\"进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子5>\nInteger num1 = 10;\nif (num1.equals(null)) {\n System.out.println(\"num1 is null\");\n}这段代码中if条件中的代码num1.equals(null)使用equals()函数与\"null\"进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n\n\n<例子6>\nObject obj = new Object();\nif (obj.equals(null)) {\n System.out.println(\"obj is null\");\n}这段代码中if条件中的代码obj.equals(null)使用equals()函数与\"null\"进行了比较,所以这个被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n", + "no_example": "### 不能被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"的例子\n<例子1>\na = \"test\";\nif (a.equals(\"test\")) {\n\tSystem.out.println(\"test\");\n}这段代码中if条件中的代码a.equals(\"test\")使用equals()函数与\"test\"进行了比较,所以这个不能被判定为\"禁止在 equals()中使用 NULL或者null 做比较\"\n" + }, + { + "id": 20, + "text": "switch 语句块不能为空", + "language": "Java", + "detail": "缺陷类型:switch 语句块不能为空;对应Fixer:EmptySwitchStatementsFix;修复方案:删除整个空 switch 语句块", + "yes_example": "### 被判定为\"switch 语句块不能为空\"的例子\n<例子1>\nswitch (number) {\n\t\n}这段代码是一个switch语句块,但是里面没有内容,所以这个被判定为\"switch 语句块不能为空\"\n\n\n<例子2>\nswitch (number) {\n\t// 这是一个switch语句块\n}这段代码是一个switch语句块,里面虽然有内容,但是内容仅仅是注释内容,没有实际的逻辑,所以这个被判定为\"switch 语句块不能为空\"\n", + "no_example": "### 不能被判定为\"switch 语句块不能为空\"的例子\n<例子1>\nswitch (number) {\n\tcase 1:\n\t\tSystem.out.println(\"Number one\");\n\t\tbreak;\n\tdefault:\n\t\tSystem.out.println(\"This is the default block, which is incorrectly placed here.\");\n\t\tbreak;\n}这段代码是一个switch语句块,里面有内容,而且内容里有非注释的代码,有实际的逻辑,所以这个不能被判定为\"switch 语句块不能为空\"\n" + }, + { + "id": 21, + "text": "在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开", + "detail": "缺陷类型:在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开;修复方案:在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开。", + "language": "Java", + "yes_example": "### 被判定为\"在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开\"的例子\n<例子1>\nint a = (int) 3.0;\n\n<例子2>\nint b = (int) 4.0;\n\n<例子3>\nlong a = (long) 5;\n\n<例子4>\nstring a = (string) 3.5;\n\n<例子5>\nPersion a = (Persion) \"zhangsan\";\n", + "no_example": "### 不能被判定为\"在进行类型强制转换时,右括号与强制转换值之间不需要任何空格隔开\"的例子\n<例子1>\nint a = (int)3.0;\n" + }, + { + "id": 22, + "text": "方法参数在定义和传入时,多个参数逗号后面必须加空格", + "detail": "缺陷类型:方法参数在定义和传入时,多个参数逗号后面必须加空格;修复方案:方法参数在定义和传入时,多个参数逗号后面必须加空格。", + "language": "Java", + "yes_example": "### 被判定为\"方法参数在定义和传入时,多个参数逗号后面必须加空格\"的例子\n<例子1>\npublic void exampleMethod(int a,int b,int c) {}\n", + "no_example": "### 不能被判定为\"方法参数在定义和传入时,多个参数逗号后面必须加空格\"的例子\n<例子1>\npublic void exampleMethod(int a, int b, int c) {}\n" + }, + { + "id": 23, + "text": "禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象", + "detail": "缺陷类型:禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象;修复方案:推荐使用 BigDecimal 的 valueOf 方法。", + 
"language": "Java", + "yes_example": "### 被判定为\"禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象\"的例子\n<例子1>\nBigDecimal bd = new BigDecimal(0.1);\n", + "no_example": "### 不能被判定为\"禁止使用构造方法 BigDecimal(double) 的方式把 double 值转化为 BigDecimal 对象\"的例子\n<例子1>\nBigDecimal bd = BigDecimal.valueOf(0.1);\n" + }, + { + "id": 24, + "text": "不能有多余的分号", + "detail": "缺陷类型:多余的分号;修复方案:删除多余的分号", + "yes_example": "### 被判定为\"不能有多余的分号\"的例子\n<例子1>\npublic void trigger(String executionId, Map processVariables) {\n commandExecutor.execute(new TriggerCmd(executionId, processVariables));\n}\n;\na = 1;\nb = 2;\nsum = a + b;\n这段代码中包含一个多余的分号\";\",所以这个被判定为\"不能有多余的分号\"\n", + "no_example": "### 不能被判定为\"不能有多余的分号\"的例子\n<例子1>\nwhile (True) {\n\ta = a + 1;\n\tbreak;\n}这段代码每个分号都是必须要的,所以这个能被判定为\"不能有多余的分号\"\n" + }, + { + "id": 25, + "text": "非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized", + "detail": "缺陷类型:非线程安全的 SimpleDateFormat 使用;修复方案:在函数或代码块级别加上synchronized修饰 或 使用其他线程安全的方式", + "yes_example": "### 被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"的例子\n<例子1>\npublic void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}这段代码中的函数formatDate在未使用synchronized同步修饰的情况下使用了SimpleDateFormat,这是线程不安全的,所以这个被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"\n", + "no_example": "### 不能被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"的例子\n<例子1>\npublic synchronized void formatDate(Date date) {\n\tSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n\tSystem.out.println(\"Formatted date: \" + sdf.format(date));\n}这段代码是在synchronized同步块对函数'formatDate'进行保护,保证了线程安全,所以这个不能被判定为\"非线程安全的 SimpleDateFormat 使用,必须在函数或代码块级别使用synchronized\"\n" + }, + { + "id": 26, + "text": "未按驼峰命名规范进行命名,类名使用驼峰式UpperCamelCase风格, 方法名、参数名、成员变量、局部变量都统一使用lowerCamelCase风格", + "detail": "缺陷类型:未按驼峰命名规范进行命名;修复方案:类名使用UpperCamelCase风格,方法名、参数名、成员变量、局部变量使用lowerCamelCase风格。", + "language": "Java", + "yes_example": "### 被判定为\"未按驼峰命名规范进行命名\"的例子\n<例子1>\npublic class myClass {\n private int MyVariable;\n public void MyMethod() {}\n}\n这段代码中的类名、成员变量和方法名没有遵循驼峰命名法,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"未按驼峰命名规范进行命名\"的例子\n<例子1>\npublic class MyClass {\n private int myVariable;\n public void myMethod() {}\n}\n这段代码中的类名、成员变量和方法名都遵循了驼峰命名法,所以不能被判定为命名规范问题。\n" + }, + { + "id": 27, + "text": "抽象类命名使用 Abstract 或 Base 开头;异常类命名使用 Exception 结尾,测试类命名以它要测试的类的名称开始,以 Test 结尾", + "detail": "缺陷类型:命名规范;修复方案:抽象类命名使用 Abstract 或 Base 开头,异常类命名使用 Exception 结尾,测试类命名以它要测试的类的名称开始,以 Test 结尾。", + "language": "Java", + "yes_example": "### 被判定为\"命名规范\"的例子\n<例子1>\npublic class MyAbstractClass {}\npublic class MyExceptionClass {}\npublic class TestMyClass {}\n这段代码中的抽象类、异常类和测试类的命名不符合规范,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"命名规范\"的例子\n<例子1>\npublic abstract class AbstractMyClass {}\npublic class MyCustomException extends Exception {}\npublic class MyClassTest {}\n这段代码中的抽象类、异常类和测试类的命名都符合规范,所以不能被判定为命名规范问题。\n" + }, + { + "id": 28, + "text": "POJO 类中的任何布尔类型的变量,避免加\"is\" 前缀", + "detail": "缺陷类型:命名规范;修复方案:POJO 类中的布尔类型变量不要加 is 前缀。", + "language": "Java", + "yes_example": "### 被判定为\"命名规范\"的例子\n<例子1>\npublic class User {\n private boolean isActive;\n}\n这段代码中的布尔类型变量加了 is 前缀,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"命名规范\"的例子\n<例子1>\npublic class User {\n private boolean active;\n}\n这段代码中的布尔类型变量没有加 is 前缀,所以不能被判定为命名规范问题。\n" + }, + { + "id": 29, + "text": "杜绝完全不规范的英文缩写,避免望文不知义。", + "detail": "缺陷类型:命名规范;修复方案:避免使用不规范的英文缩写,确保代码可读性。", + "language": "Java", + 
"yes_example": "### 被判定为\"命名规范\"的例子\n<例子1>\npublic class CfgMgr {\n private int cnt;\n}\n这段代码中的类名和变量名使用了不规范的英文缩写,所以被判定为命名规范问题。\n", + "no_example": "### 不能被判定为\"命名规范\"的例子\n<例子1>\npublic class ConfigManager {\n private int count;\n}\n这段代码中的类名和变量名没有使用不规范的英文缩写,所以不能被判定为命名规范问题。\n" + }, + { + "id": 30, + "text": "避免出现魔法字符和数字,应声明为常量", + "detail": "缺陷类型:避免出现魔法字符和数字,应声明为常量;修复方案:将魔法值定义为常量。", + "language": "Java", + "yes_example": "### 被判定为\"避免出现魔法字符和数字,应声明为常量\"的例子\n<例子1>\npublic class MagicNumberExample {\n public void calculate() {\n int result = 42 * 2;\n }\n}\n这段代码中直接使用了魔法值 42,所以被判定为代码规范问题。\n\n<例子2>\npublic class MagicNumberExample {\n public void calculate() {\n String result = \"This is a result\";\n }\n}\n这段代码中直接使用了魔法值 \"This is a result\",所以被判定为代码规范问题。\n", + "no_example": "### 不能被判定为\"避免出现魔法字符和数字,应声明为常量\"的例子\n<例子1>\npublic class MagicNumberExample {\n private static final int MULTIPLIER = 42;\n public void calculate() {\n int result = MULTIPLIER * 2;\n }\n}\n这段代码中将魔法值定义为了常量,所以不能被判定为代码规范问题。\n" + }, + { + "id": 31, + "text": "long 或 Long 赋值时,数值后使用大写 L,不能是小写 l,浮点数类型的数值后缀统一为大写的 D 或 F", + "detail": "缺陷类型:代码规范;修复方案:long 或 Long 赋值时使用大写 L,浮点数类型的数值后缀使用大写的 D 或 F。", + "language": "Java", + "yes_example": "### 被判定为\"代码规范\"的例子\n<例子1>\npublic class NumberExample {\n private long value = 1000l;\n private double pi = 3.14d;\n}\n这段代码中使用了小写的 l 和 d,所以被判定为代码规范问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class NumberExample {\n private long value = 1000L;\n private double pi = 3.14D;\n}\n这段代码中使用了大写的 L 和 D,所以不能被判定为代码规范问题。\n" + }, + { + "id": 32, + "text": "如果大括号内为空,简洁地写成{}即可,大括号中间无需换行和空格;如果是非空代码块,则:1)左大括号前不换行。2)左大括号后换行。3)右大括号前换行。4)右大括号后还有 else 等代码则不换行;表示终止的右大括号后必须换行。", + "detail": "缺陷类型:代码格式;修复方案:遵循大括号的使用规范。", + "language": "Java", + "yes_example": "### 被判定为\"代码格式\"的例子\n<例子1>\npublic class BracketExample{public void method(){\n if (true) {\n }}\n}\n这段代码中的大括号使用不符合规范,所以被判定为代码格式问题。\n", + "no_example": "### 不能被判定为\"代码格式\"的例子\n<例子1>\npublic class BracketExample {\n public void method() {\n if (true) {\n // do something\n }\n }\n}\n这段代码中的大括号使用符合规范,所以不能被判定为代码格式问题。\n" + }, + { + "id": 33, + "text": "左小括号和右边相邻字符之间不需要空格;右小括号和左边相邻字符之间也不需要空格;而左大括号前需要加空格。", + "detail": "缺陷类型:代码格式;修复方案:遵循括号和空格的使用规范。", + "language": "Java", + "yes_example": "### 被判定为\"代码格式\"的例子\n<例子1>\npublic class SpaceExample {\n public void method (){\n }\n}\n这段代码中的括号和空格使用不符合规范,所以被判定为代码格式问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class SpaceExample {\n public void method() {}\n}\n这段代码中的括号和空格使用符合规范,所以不能被判定为代码格式问题。\n" + }, + { + "id": 34, + "text": "if / for / while / switch / do 等保留字与左右括号之间都必须加空格。", + "detail": "缺陷类型:代码格式;修复方案:保留字与左右括号之间加空格。", + "language": "Java", + "yes_example": "### 被判定为\"代码规范\"的例子\n<例子1>\npublic class KeywordExample {\n public void method() {\n if(true) {\n }\n }\n}\n这段代码中的 if 关键字与括号之间没有空格,所以被判定为代码格式问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class KeywordExample {\n public void method() {\n if (true) {\n }\n }\n}\n这段代码中的 if 关键字与括号之间有空格,所以不能被判定为代码格式问题。\n" + }, + { + "id": 35, + "text": "所有整型包装类对象之间值的比较,全部使用 equals 方法比较", + "detail": "缺陷类型:代码规范;修复方案:整型包装类对象之间的值比较使用 equals 方法。", + "language": "Java", + "yes_example": "### 被判定为\"代码规范\"的例子\n<例子1>\npublic class IntegerComparison {\n public void compare() {\n Integer a = 100;\n Integer b = 100;\n if (a == b) {\n }\n }\n}\n这段代码中使用了 == 比较整型包装类对象,所以被判定为代码规范问题。\n", + "no_example": "### 不能被判定为\"代码规范\"的例子\n<例子1>\npublic class IntegerComparison {\n public void compare() {\n Integer a = 100;\n Integer b = 100;\n if (a.equals(b)) {\n }\n 
}\n}\n这段代码中使用了 equals 方法比较整型包装类对象,所以不能被判定为代码规范问题。\n" + }, + { + "id": 36, + "text": "BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法。", + "detail": "缺陷类型:BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法;修复方案:使用 compareTo() 方法进行比较。", + "language": "Java", + "yes_example": "### 被判定为\"BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法\"的例子\n<例子1>\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.equals(b)) {\n // 这段代码会返回 false,因为 equals() 方法会比较精度\n}\n", + "no_example": "### 不能被判定为\"BigDecimal 的等值比较应使用 compareTo() 方法,而不是 equals() 方法\"的例子\n<例子1>\nBigDecimal a = new BigDecimal(\"1.0\");\nBigDecimal b = new BigDecimal(\"1.00\");\nif (a.compareTo(b) == 0) {\n // 这段代码会返回 true,因为 compareTo() 方法只比较数值\n}\n" + }, + { + "id": 37, + "text": "禁止在 POJO 类中,同时存在对应属性 xxx 的 isXxx() 和 getXxx() 方法。", + "detail": "缺陷类型:POJO 类中存在重复的 getter 方法;修复方案:确保只存在一个 getter 方法。", + "language": "Java", + "yes_example": "### 被判定为\"禁止在 POJO 类中,同时存在对应属性 xxx 的 isXxx() 和 getXxx() 方法\"的例子\n<例子1>\npublic class User {\n private boolean active;\n public boolean isActive() {\n return active;\n }\n public boolean getActive() {\n return active;\n }\n}\n", + "no_example": "### 不能被判定为\"禁止在 POJO 类中,同时存在对应属性 xxx 的 isXxx() 和 getXxx() 方法\"的例子\n<例子1>\npublic class User {\n private int age;\n public int getAge() {\n return age;\n }\n}\n" + }, + { + "id": 38, + "text": "日期格式化时,传入 pattern 中表示年份统一使用小写的 y。", + "detail": "缺陷类型:日期格式化错误;修复方案:使用小写的 y 表示年份。", + "language": "Java", + "yes_example": "### 被判定为\"日期格式化时,传入 pattern 中表示年份统一使用小写的 y\"的例子\n<例子1>\nSimpleDateFormat sdf = new SimpleDateFormat(\"YYYY-MM-dd\");\n", + "no_example": "### 不能被判定为\"日期格式化时,传入 pattern 中表示年份统一使用小写的 y\"的例子\n<例子1>\nSimpleDateFormat sdf = new SimpleDateFormat(\"yyyy-MM-dd\");\n" + }, + { + "id": 39, + "text": "禁止在程序任何地方中使用:1)java.sql.Date 2)java.sql.Time 3)java.sql.Timestamp。", + "detail": "缺陷类型:使用了 java.sql 包中的日期类;修复方案:使用 java.time 包中的日期类。", + "language": "Java", + "yes_example": "### 被判定为\"禁止在程序任何地方中使用:1)java.sql.Date 2)java.sql.Time 3)java.sql.Timestamp\"的例子\n<例子1>\njava.sql.Date sqlDate = new java.sql.Date(System.currentTimeMillis());\n", + "no_example": "### 不能被判定为\"禁止在程序任何地方中使用:1)java.sql.Date 2)java.sql.Time 3)java.sql.Timestamp\"的例子\n<例子1>\njava.time.LocalDate localDate = java.time.LocalDate.now();\n" + }, + { + "id": 40, + "text": "判断所有集合内部的元素是否为空,使用 isEmpty() 方法,而不是 size() == 0 的方式。", + "detail": "缺陷类型:集合判空方式错误;修复方案:使用 isEmpty() 方法。", + "language": "Java", + "yes_example": "### 被判定为\"判断所有集合内部的元素是否为空,使用 isEmpty() 方法,而不是 size() == 0 的方式\"的例子\n<例子1>\nList list = new ArrayList<>();\nif (list.size() == 0) {\n // 判空逻辑\n}\n", + "no_example": "### 不能被判定为\"判断所有集合内部的元素是否为空,使用 isEmpty() 方法,而不是 size() == 0 的方式\"的例子\n<例子1>\nList list = new ArrayList<>();\nif (list.isEmpty()) {\n // 判空逻辑\n}\n" + }, + { + "id": 41, + "text": "只要重写 equals,就必须重写 hashCode。", + "detail": "缺陷类型:未重写 hashCode 方法;修复方案:同时重写 equals 和 hashCode 方法。", + "language": "Java", + "yes_example": "### 被判定为\"只要重写 equals,就必须重写 hashCode\"的例子\n<例子1>\npublic class User {\n private String name;\n @Override\n public boolean equals(Object o) {\n if (this == o) return true;\n if (o == null || getClass() != o.getClass()) return false;\n User user = (User) o;\n return Objects.equals(name, user.name);\n }\n}\n", + "no_example": "### 不能被判定为\"只要重写 equals,就必须重写 hashCode\"的例子\n<例子1>\npublic class User {\n private String name;\n @Override\n public boolean equals(Object o) {\n if (this == o) return true;\n if (o == null || getClass() != o.getClass()) return false;\n User user = (User) o;\n 
return Objects.equals(name, user.name);\n }\n @Override\n public int hashCode() {\n return Objects.hash(name);\n }\n}\n" + }, + { + "id": 42, + "text": "使用 Map 的方法 keySet() / values() / entrySet() 返回集合对象时,不可以对其进行添加元素操作,否则会抛出 UnsupportedOperationException 异常。", + "detail": "缺陷类型:对 Map 的 keySet() / values() / entrySet() 返回的集合进行添加操作;修复方案:避免对这些集合进行添加操作。", + "language": "Java", + "yes_example": "### 被判定为\"使用 Map 的方法 keySet() / values() / entrySet() 返回集合对象时,不可以对其进行添加元素操作,否则会抛出 UnsupportedOperationException 异常\"的例子\n<例子1>\nMap map = new HashMap<>();\nmap.put(\"key1\", \"value1\");\nSet keys = map.keySet();\nkeys.add(\"key2\");\n", + "no_example": "### 不能被判定为\"使用 Map 的方法 keySet() / values() / entrySet() 返回集合对象时,不可以对其进行添加元素操作,否则会抛出 UnsupportedOperationException 异常\"的例子\n<例子1>\nMap map = new HashMap<>();\nmap.put(\"key1\", \"value1\");\nSet keys = map.keySet();\n// 不进行添加操作\n" + }, + { + "id": 43, + "text": "不要在 foreach 循环里进行元素的 remove / add 操作。remove 元素请使用 iterator 方式,如果并发操作,需要对 iterator", + "detail": "缺陷类型:在 foreach 循环中进行元素的 remove / add 操作;修复方案:使用 iterator 进行元素的 remove 操作。", + "language": "Java", + "yes_example": "### 被判定为\"不要在 foreach 循环里进行元素的 remove / add 操作。remove 元素请使用 iterator 方式,如果并发操作,需要对 iterator\"的例子\n<例子1>\nList list = new ArrayList<>(Arrays.asList(\"a\", \"b\", \"c\"));\nfor (String s : list) {\n if (s.equals(\"a\")) {\n list.remove(s);\n }\n}\n", + "no_example": "### 不能被判定为\"不要在 foreach 循环里进行元素的 remove / add 操作。remove 元素请使用 iterator 方式,如果并发操作,需要对 iterator\"的例子\n<例子1>\nList list = new ArrayList<>(Arrays.asList(\"a\", \"b\", \"c\"));\nIterator iterator = list.iterator();\nwhile (iterator.hasNext()) {\n String s = iterator.next();\n if (s.equals(\"a\")) {\n iterator.remove();\n }\n}\n" + }, + { + "id": 44, + "text": "类、类属性、类方法的注释必须使用 Javadoc 规范,使用 /** 内容 */ 格式,不得使用 // xxx方式。", + "detail": "缺陷类型:注释不符合 Javadoc 规范;修复方案:使用 Javadoc 规范的注释格式。", + "language": "Java", + "yes_example": "### 被判定为\"类、类属性、类方法的注释必须使用 Javadoc 规范,使用 /** 内容 */ 格式,不得使用 // xxx方式\"的例子\n<例子1>\npublic class Example {\n // 这是一个类注释\n private String name;\n // 这是一个属性注释\n public String getName() {\n return name;\n }\n // 这是一个方法注释\n}\n", + "no_example": "### 不能被判定为\"类、类属性、类方法的注释必须使用 Javadoc 规范,使用 /** 内容 */ 格式,不得使用 // xxx方式\"的例子\n<例子1>\n/**\n * 这是一个类注释\n */\npublic class Example {\n /**\n * 这是一个属性注释\n */\n private String name;\n /**\n * 这是一个方法注释\n */\n public String getName() {\n return name;\n }\n}\n" + }, + { + "id": 45, + "text": "所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释", + "detail": "缺陷类型:所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释;修复方案:为所有的抽象方法(包括接口中的方法)添加 Javadoc 注释,除了返回值、参数异常说明外,还必须指出该方法做什么事情,实现什么功能。", + "language": "Java", + "yes_example": "### 被判定为\"所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释\"的例子\n<例子1>\npublic interface MyInterface {\n void doSomething();\n}\n这段代码中的接口方法 doSomething() 没有 Javadoc 注释,所以被判定为缺少 Javadoc 注释。\n", + "no_example": "### 不能被判定为\"所有的抽象方法(包括接口中的方法)必须要用 Javadoc 注释\"的例子\n<例子1>\n/**\n * 执行某个操作\n * @param param 参数说明\n * @return 返回值说明\n * @throws Exception 异常说明\n */\npublic interface MyInterface {\n void doSomething(String param) throws Exception;\n}\n这段代码中的接口方法 doSomething() 有完整的 Javadoc 注释,所以不能被判定为缺少 Javadoc 注释。\n" + }, + { + "id": 46, + "text": "方法内部单行注释和多行注释的使用规范", + "detail": "缺陷类型:注释使用不规范;修复方案:方法内部单行注释,在被注释语句上方另起一行,使用 // 注释。方法内部多行注释使用 /* */注释,注意与代码对齐。", + "language": "Java", + "yes_example": "### 被判定为\"注释使用不规范\"的例子\n<例子1>\npublic void exampleMethod() {\n int a = 1; // 初始化变量a\n int b = 2; /* 初始化变量b */\n}\n这段代码中的单行注释和多行注释没有按照规范使用,所以被判定为注释使用不规范。\n", + "no_example": "### 不能被判定为\"注释使用不规范\"的例子\n<例子1>\npublic void exampleMethod() {\n // 
初始化变量a\n int a = 1;\n /*\n * 初始化变量b\n */\n int b = 2;\n}\n这段代码中的单行注释和多行注释按照规范使用,所以不能被判定为注释使用不规范。\n" + }, + { + "id": 47, + "text": "所有的枚举类型字段必须要有注释", + "detail": "缺陷类型:枚举类型字段缺少注释;修复方案:为所有的枚举类型字段添加注释,说明每个数据项的用途。", + "language": "Java", + "yes_example": "### 被判定为\"枚举类型字段缺少注释\"的例子\n<例子1>\npublic enum Status {\n ACTIVE,\n INACTIVE\n}\n这段代码中的枚举类型字段没有注释,所以被判定为枚举类型字段缺少注释。\n", + "no_example": "### 不能被判定为\"枚举类型字段缺少注释\"的例子\n<例子1>\npublic enum Status {\n /**\n * 活跃状态\n */\n ACTIVE,\n /**\n * 非活跃状态\n */\n INACTIVE\n}\n这段代码中的枚举类型字段有注释,所以不能被判定为枚举类型字段缺少注释。\n" + }, + { + "id": 48, + "text": "finally 块必须对资源对象、流对象进行关闭", + "detail": "缺陷类型:资源对象、流对象未在 finally 块中关闭;修复方案:在 finally 块中对资源对象、流对象进行关闭,有异常也要做 try-catch。", + "language": "Java", + "yes_example": "### 被判定为\"资源对象、流对象未在 finally 块中关闭\"的例子\n<例子1>\npublic void readFile() {\n FileInputStream fis = null;\n try {\n fis = new FileInputStream(\"file.txt\");\n // 读取文件内容\n } catch (IOException e) {\n e.printStackTrace();\n }\n}\n这段代码中的 FileInputStream 对象没有在 finally 块中关闭,所以被判定为资源对象、流对象未在 finally 块中关闭。\n", + "no_example": "### 不能被判定为\"资源对象、流对象未在 finally 块中关闭\"的例子\n<例子1>\npublic void readFile() {\n FileInputStream fis = null;\n try {\n fis = new FileInputStream(\"file.txt\");\n // 读取文件内容\n } catch (IOException e) {\n e.printStackTrace();\n } finally {\n if (fis != null) {\n try {\n fis.close();\n } catch (IOException e) {\n e.printStackTrace();\n }\n }\n }\n}\n这段代码中的 FileInputStream 对象在 finally 块中关闭,所以不能被判定为资源对象、流对象未在 finally 块中关闭。\n" + }, + { + "id": 49, + "text": "常量命名应该全部大写,单词间用下划线隔开", + "detail": "缺陷类型:常量命名不规范;修复方案:常量命名应该全部大写,单词间用下划线隔开,力求语义表达完整清楚,不要嫌名字长。", + "language": "Java", + "yes_example": "### 被判定为\"常量命名应该全部大写,单词间用下划线隔开\"的例子\n<例子1>\npublic static final int maxCount = 100;\n", + "no_example": "### 不能被判定为\"常量命名应该全部大写,单词间用下划线隔开\"的例子\n<例子1>\npublic static final int MAX_COUNT = 100;\n" + }, + { + "id": 50, + "text": "任何二目、三目运算符的左右两边都需要加一个空格", + "detail": "缺陷类型:运算符两边缺少空格;修复方案:任何二目、三目运算符的左右两边都需要加一个空格。", + "language": "Java", + "yes_example": "### 被判定为\"任何二目、三目运算符的左右两边都需要加一个空格\"的例子\n<例子1>\nint a=b+c;\n", + "no_example": "### 不能被判定为\"任何二目、三目运算符的左右两边都需要加一个空格\"的例子\n<例子1>\nint a = b + c;\n" + }, + { + "id": 51, + "text": "避免使用from import *", + "detail": "缺陷类型:避免使用from import *,导入所有内容会造成命名冲突;修复方案:每个使用到的子依赖需分别导入。", + "language": "Python", + "yes_example": "### 被判定为\"避免使用from import *\"的例子\n<例子1>from math import * \n", + "no_example": "### 不能被判定为\"避免使用from import *\"的例子\n<例子1>from math import sqrt, pi \n" + }, + { + "id": 52, + "text": "避免使用__import__()函数动态导入模块", + "detail": "缺陷类型:避免使用__import__()函数动态导入模块;修复方案:使用标准的import语句。", + "language": "Python", + "yes_example": "### 被判定为\"使用__import__()函数动态导入模块\"的例子\n<例子1>module = __import__('math') \n", + "no_example": "### 不能被判定为\"使用__import__()函数动态导入模块\"的例子\n<例子1>import math \n" + }, + { + "id": 53, + "text": "导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组", + "detail": "缺陷类型:导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组;修复方案:按顺序分组导入语句。", + "language": "Python", + "yes_example": "### 被判定为'导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组'的例子\n<例子1>\nimport numpy as np\nimport os\nimport sys\nfrom my_local_module import my_function\n在这个样例中,先导入了第三方库,然后导入了标准库。\n\n<例子2>\nfrom my_project import my_local_function\nimport datetime\nimport requests\n在这个样例中,先导入了本地模块,然后导入了标准库。\n\n<例子3>\nimport os\nfrom my_project.local_module import some_function\nimport pandas as pd\nimport sys\nfrom another_local_module import another_function\nimport math\n在这个样例中,导入语句完全混乱,没有遵循任何顺序。\n\n<例子4>\nimport os\nimport requests\nimport sys\nimport numpy as np\nfrom local_package import 
local_module\n在这个样例中,导入标准库和第三方库交替进行。\n", + "no_example": "### 不能被判定为'导入语句未按标准库导入、相关第三方导入、本地应用/库特定导入的顺序分组'的例子\n<例子1>import os \n\n import requests \n\n import mymodule \n" + }, + { + "id": 54, + "text": "避免未使用的函数形参", + "detail": "缺陷类型:避免未使用的函数形参;修复方案:移除未使用的函数形参。", + "language": "Python", + "yes_example": "### 被判定为'避免未使用的函数形参'的例子\n<例子1>def func(a, b): \n return a\n<例子2>def start_game(unused_param): \npuzzle = Puzzle() \npuzzle.solve()\n<例子3>def make_move(self, board):\npass \n\n<例子4>def move(self, direction):\npass \n", + "no_example": "### 不能被判定为'避免未使用的函数形参'的例子\n<例子1>def func(a): \n return a" + }, + { + "id": 55, + "text": "使用is not None来检查一个变量是否不是None", + "detail": "缺陷类型:未使用is not None来检查一个变量是否不是None;修复方案:使用is not None来检查。", + "language": "Python", + "yes_example": "### 被判定为'未使用is not None来检查一个变量是否不是None'的例子\n<例子1>if variable != None:\n pass", + "no_example": "### 不能被判定为'未使用is not None来检查一个变量是否不是None'的例子\n<例子1>if variable is not None:\n pass" + }, + { + "id": 56, + "text": "避免使用==或!=来比较对象实例的等价性", + "detail": "缺陷类型:使用==或!=来比较对象实例的等价性;修复方案:应使用equals比较。", + "language": "Python", + "yes_example": "### 被判定为'使用==或!=来比较对象实例的等价性'的例子\n<例子1>obj1 = MyClass() \n obj2 = MyClass() if obj1 == obj2: \n pass\n", + "no_example": "### 不能被判定为'使用==或!=来比较对象实例的等价性'的例子\n<例子1>obj1 = MyClass() \n obj2 = MyClass() if obj1.equals(obj2): \n pass\n\n<例子2>obj1 = 21 \n obj2 = 22 \n if obj1.equals(obj2):\n pass" + }, + { + "id": 57, + "text": "避免使用单字母变量名,使用描述性变量名", + "detail": "缺陷类型:避免使用单字母变量名,使用描述性变量名;修复方案:使用描述性变量名。", + "language": "Python", + "yes_example": "### 被判定为'避免使用单字母变量名,使用描述性变量名'的例子\n<例子1>x = 10 \n\n<例子2>y = 10 \n", + "no_example": "### 不能被判定为'避免使用单字母变量名,使用描述性变量名'的例子\n<例子1>count = 10 \n" + }, + { + "id": 58, + "text": "常量命名使用全大写字母,并用下划线分隔", + "detail": "缺陷类型:常量命名未使用全大写字母或未用下划线分隔;修复方案:常量命名使用全大写字母,并用下划线分隔。", + "language": "Python", + "yes_example": "### 被判定为'常量命名未使用全大写字母,并用下划线分隔'的例子\n<例子1>pi = 3.14159", + "no_example": "### 不能被判定为'常量命名未使用全大写字母,并用下划线分隔'的例子\n<例子1>PI = 3.14159\n<例子2>max_size = 1 \n max_size += 1" + }, + { + "id": 59, + "text": "类名应使用驼峰式命名(CamelCase)", + "detail": "缺陷类型:类名未使用驼峰式命名;修复方案:类名使用驼峰式命名。", + "language": "Python", + "yes_example": "### 被判定为'类名未使用驼峰式命名(CamelCase)'的例子\n<例子1>class my_class: \n pass\n<例子2>class my_class: \n def solve(self):\n pass", + "no_example": "### 不能被判定为'类名未使用驼峰式命名(CamelCase)'的例子\n<例子1>class MyClass: \n pass" + }, + { + "id": 60, + "text": "尽量使用with语句来管理资源", + "detail": "缺陷类型:未使用with语句来管理资源;修复方案:使用with语句来管理资源。", + "language": "Python", + "yes_example": "### 被判定为'未使用with语句来管理资源'的例子\n<例子1>file = open('file.txt', 'r') \n content = file.read() \n file.close()", + "no_example": "### 不能被判定为'未使用with语句来管理资源'的例子\n<例子1>with open('file.txt', 'r') as file: \n content = file.read()" + }, + { + "id": 61, + "text": "避免使用except 或 通用的Exception来捕获所有异常,应该指定异常类型", + "detail": "缺陷类型:捕获所有异常;修复方案:指定具体的异常类型。", + "language": "Python", + "yes_example": "### 被判定为'使用except:来捕获所有异常'的例子\n<例子1>try: \n # some code \n except: \n handle_error()\n### 被判定为'抛出通用的Exception异常'的例子\n<例子2>\n try:\n process_data(data) \n except: \n raise Exception('An error occurred') \n ", + "no_example": "### 不能被判定为'使用except:来捕获所有异常'的例子\n<例子1>try: \n # some code \n except ValueError: \n handle_value_error()" + }, + { + "id": 62, + "text": "尽量避免手动拼接字符串", + "detail": "缺陷类型:手动拼接字符串;修复方案:使用格式化字符串或join方法。", + "language": "Python", + "yes_example": "### 被判定为'手动拼接字符串'的例子\n<例子1>\n name = 'John' \n greeting = 'Hello, ' + name + '!' 
\n \n <例子2>greeting = '2048' + 'game' \n \n <例子3>pygame.display.set_caption('贪吃蛇' + '游戏')", + "no_example": "### 不能被判定为'手动拼接字符串'的例子\n<例子1>\n name = 'John' \n greeting = f'Hello, {name}!' \n" + }, + { + "id": 63, + "text": "避免出现魔法字符和数字,应声明为常量", + "detail": "缺陷类型:使用魔法字符和数字;修复方案:将其声明为常量。", + "language": "Python", + "yes_example": "### 被判定为'出现魔法字符和数字'的例子\n<例子1>\n if status == 1: \n print('Active')' \n\n<例子2>\n self.board = [[0] * 4 for _ in range(4)] \n self.score = 0\n<例子3>\ndef __init__(self, width=10, height=10, mines=15):\n\n<例子4>\nx, y = event.x // 20, event.y // 20\n\n<例子5>\nraise ValueError(\"余额不足\")\n\n<例子6>\ntransfer(bank, \"123\", \"456\", 200)\n\n<例子7>\nbank.add_account(Account(\"123\", 1000))\n", + "no_example": "### 不能被判定为'出现魔法字符和数字'的例子\n<例子1>\n ACTIVE_STATUS = 1 \n if status == ACTIVE_STATUS:\n print(ACTIVE_STATUS)' \n" + }, + { + "id": 64, + "text": "boolean变量判断无需显式比较", + "detail": "缺陷类型:显式比较boolean变量;修复方案:直接使用boolean变量进行判断。", + "language": "Python", + "yes_example": "### 被判定为'显式比较boolean变量'的例子\n<例子1>flag = True \n if flag == True: \n print('Flag is true')\n<例子2>if self.game.is_game_over() == True: \n return<例子3>if self.canvas.drawings ==True:", + "no_example": "### 不能被判定为'显式比较boolean变量'的例子\n<例子1>flag = True \n if flag: \n print('Flag is true') \n" + }, + { + "id": 65, + "text": "避免使用type()检查对象类型", + "detail": "缺陷类型:避免使用type()检查对象类型;修复方案:使用isinstance()函数。", + "language": "Python", + "yes_example": "### 被判定为'避免使用type()检查对象类型'的例子\n<例子1>\n if type(obj) == list: \n print('obj is a list')", + "no_example": "### 不能被判定为'避免使用type()检查对象类型'的例子\n<例子1>\n if isinstance(obj, list): \n print('obj is a list') \n" + }, + { + "id": 66, + "text": "避免使用os.system()来调用外部命令", + "detail": "缺陷类型:使用os.system()调用外部命令;修复方案:使用subprocess模块。", + "language": "Python", + "yes_example": "### 被判定为'使用os.system()来调用外部命令'的例子\n<例子1>os.system('ls -l')\n<例子2>os.system('ls -l')", + "no_example": "### 不能被判定为'使用os.system()来调用外部命令'的例子\n<例子1>import subprocess \n subprocess.run(['ls', '-l'])" + }, + { + "id": 67, + "text": "只使用@property装饰器创建只读属性,而非修改属性", + "detail": "缺陷类型:使用@property装饰器创建可修改属性;修复方案:只使用@property装饰器创建只读属性。", + "language": "Python", + "yes_example": "### 被判定为'使用@property装饰器来创建可修改属性'的例子\n<例子1>@property \n def value(self, new_value): \n self._value = new_value\n<例子2>@property \n def game_over(self): \n return self._is_game_over() \n def _is_game_over(self): \n pass", + "no_example": "### 不能被判定为'使用@property装饰器来创建可修改属性'的例子\n<例子1>@property \n def value(self): \n return self._value\n<例子2>@property \n def __str__(self): \n return 'Maze Game State'" + }, + { + "id": 68, + "text": "在使用索引或切片时,不要在方括号或冒号内加空格", + "detail": "缺陷类型:在索引或切片的方括号或冒号内加空格;修复方案:去掉方括号或冒号内的空格。", + "language": "Python", + "yes_example": "### 被判定为'在使用索引或切片时,在方括号或冒号内加空格'的例子\n<例子1>list = [1, 2, 3, 4] \n sublist = list[ 1 : 3 ]\n<例子2>start_point = self.canvas.drawings[ -1] \n<例子3>if head[ 0] < 0 or head[ 0] >= GRID_WIDTH or head[ 1] < 0 or head[ 1] >= GRID_HEIGHT:\n<例子4>for segment in self.snake[ 1:]:", + "no_example": "### 不能被判定为'在使用索引或切片时,在方括号或冒号内加空格'的例子\n<例子1>list = [1, 2, 3, 4] \n sublist = list[1:3]" + }, + { + "id": 69, + "text": "在逗号、分号或冒号前不要加空格,但在它们之后要加空格", + "detail": "缺陷类型:在逗号、分号或冒号前加空格或在它们之后不加空格;修复方案:在逗号、分号或冒号前不要加空格,但在它们之后要加空格。", + "language": "Python", + "yes_example": "### 被判定为'在逗号、分号或冒号前加空格,或没在它们之后加空格'的例子\n<例子1>if x == 4 : \n print(x , y)\n<例子2>if event.keysym == 'Up' or event.keysym == 'Down' or event.keysym == 'Left' or event.keysym == 'Right' :\n<例子3>x ,y = 1 ,2\n<例子4>def on_key_press(self , event) :\n<例子5>elif event.keysym == 'Down' ; 
\n<例子6>def update_status(self ,message: str) : \n pass ", + "no_example": "### 不能被判定为'在逗号、分号或冒号前加空格,或没在它们之后加空格'的例子\n<例子1>if x == 4: \n print(x, y)" + }, + { + "id": 70, + "text": "对于二元操作符,两边都应有空格", + "detail": "缺陷类型:二元操作符两边没有空格;修复方案:在二元操作符两边加空格", + "language": "Python", + "yes_example": "### 被判定为'二元操作符两边没有空格'的例子\n<例子1>a=b+1", + "no_example": "### 不能被判定为'二元操作符两边没有空格'的例子\n<例子1>a = b + 1\n<例子2>label = tk.Label(self.root, text=str(cell), bg='white')\n<例子3>label.grid(row=i, column=j)" + }, + { + "id": 71, + "text": "避免使用Python关键字作为变量名或函数名", + "detail": "缺陷类型:使用Python关键字作为变量名或函数名;修复方案:使用非关键字的名称。", + "language": "Python", + "yes_example": "### 被判定为'使用Python关键字作为变量名或函数名'的例子\n<例子1>def class(): \n pass\n<例子2>for = 5\n<例子3>def if(self): ", + "no_example": "### 不能被判定为'使用Python关键字作为变量名或函数名'的例子\n<例子1>def my_function(): \n pass\n<例子2>number = 5" + }, + { + "id": 72, + "text": "避免使用特殊字符作为变量名/方法名/类名,例如$或@", + "detail": "缺陷类型:使用特殊字符作为变量名/方法名/类名;修复方案:使用合法的变量名。", + "language": "Python", + "yes_example": "### 被判定为'使用特殊字符作为变量名/方法名/类名,例如$或@'的例子\n<例子1>my$var = 10\n<例子2>@var = 20\n<例子3>def add_score@(self, points): \n self.score += points\n<例子4>class @MyClass: \n pass\n<例子5>def mine@(self):", + "no_example": "### 不能被判定为'使用特殊字符作为变量名/方法名/类名,例如$或@'的例子\n<例子1>my_var = 10\n<例子2>var_20 = 20" + }, + { + "id": 73, + "text": "避免使用raise来重新抛出当前的异常,这会丢失原始的栈跟踪", + "detail": "缺陷类型:使用raise重新抛出当前异常;修复方案:使用raise ... from ...语法。", + "language": "Python", + "yes_example": "### 被判定为'避免使用raise来重新抛出当前的异常,这会丢失原始的栈跟踪'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError: \n raise SomeException('新的异常信息')\n\n<例子2>\ntry:\n db.get_data()\nexcept ValueError as e:\n raise ValueError(\"Something went wrong!\")\n\n<例子3>\ntry:\n\traise Exception(\"形状添加失败\")\nexcept Exception as e:\n\tpass\n", + "no_example": "### 不能被判定为'避免使用raise来重新抛出当前的异常,这会丢失原始的栈跟踪'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError as e: \n raise RuntimeError('Error occurred') from e \n\n<例子2>\n try: \n 1 / 0 \n except ZeroDivisionError as e: \n\tlogger.error(e)\n raise \n" + }, + { + "id": 74, + "text": "避免在except块中使用pass,这会捕获并忽略异常", + "detail": "缺陷类型:在except块中使用pass;修复方案:处理异常或记录日志。", + "language": "Python", + "yes_example": "### 被判定为'在except块中使用pass'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError: \n pass \n \n<例子2>\n try: \n 1 / 0 \n except ZeroDivisionError: \n pass \n", + "no_example": "### 不能被判定为'在except块中使用pass'的例子\n<例子1>\n try: \n 1 / 0 \n except ZeroDivisionError as e: \n logging.error('Error occurred: %s', e) \n" + }, + { + "id": 75, + "text": "避免使用assert语句来执行重要的运行时检查", + "detail": "缺陷类型:使用assert语句执行重要的运行时检查;修复方案:使用显式的条件检查和异常处理。", + "language": "Python", + "yes_example": "### 被判定为'使用assert语句来执行重要的运行时检查'的例子\n<例子1>\n def divide(a, b): \n assert b != 0 \n return a / b \n", + "no_example": "### 不能被判定为'使用assert语句来执行重要的运行时检查'的例子\n<例子1>\n def divide(a, b): \n if b == 0: \n raise ValueError('b cannot be zero') \n return a / b \n" + }, + { + "id": 76, + "text": "避免使用eval()和exec(),这些函数可能会带来安全风险", + "detail": "缺陷类型:使用eval()和exec()函数;修复方案:使用安全的替代方案。", + "language": "Python", + "yes_example": "### 被判定为'使用eval()和exec()'的例子\n<例子1>\n eval('print(1)') \n\n<例子2> \n exec('a = 1') \n", + "no_example": "### 不能被判定为'使用eval()和exec()'的例子\n<例子1>\n compiled_code = compile('print(1)', '', 'exec') \n exec(compiled_code) \n" + }, + { + "id": 77, + "text": "避免使用sys.exit(),应使用异常来控制程序的退出", + "detail": "缺陷类型:避免使用sys.exit(),应使用异常来控制程序的退出;修复方案:使用异常来控制程序的退出。", + "language": "Python", + "yes_example": "### 被判定为'避免使用sys.exit(),应使用异常来控制程序的退出'的例子\n<例子1>\n import sys\nsys.exit(1)\n\n<例子2>\n import 
sys \n sys.exit()\n\n<例子3>\nif event.type == pygame.QUIT:\n\tpygame.quit()\n\texit()\n\n<例子4>\n import sys \n sys.exit('退出程序'))\n", + "no_example": "### 不能被判定为'避免使用sys.exit(),应使用异常来控制程序的退出'的例子\n<例子1>\n raise SystemExit(1)\n" + }, + { + "id": 78, + "text": "避免使用time.sleep()进行线程同步,应使用同步原语,如锁或事件", + "detail": "缺陷类型:使用time.sleep()进行线程同步;修复方案:使用同步原语。", + "language": "Python", + "yes_example": "### 被判定为'使用time.sleep()进行线程同步'的例子\n<例子1>\n import time \n\n def worker(): \n time.sleep(1) \n\n<例子2>\n import time \n\n time.sleep(1) \n", + "no_example": "### 不能被判定为'使用time.sleep()进行线程同步'的例子\n<例子1>\n import threading \n\n event = threading.Event() \n\n def worker(): \n event.wait()\n" + }, + { + "id": 79, + "text": "每行代码避免超过79个字符", + "detail": "缺陷类型:每行代码避免超过79个字符;修复方案:将长行代码格式化为多行。", + "language": "Python", + "yes_example": "### 被判定为'每行代码避免超过79个字符'的例子\n<例子1>\n print('This is a very long line of code that exceeds the 79 characters limit........') \n", + "no_example": "### 不能被判定为'每行代码避免超过79个字符'的例子\n<例子1>\n print('This is a very long line of code that exceeds the 79 characters limit' + \n ' but it is split into two lines')\n" + }, + { + "id": 80, + "text": "模块级别的函数和类定义之间用两个空行分隔,类内部的方法定义之间用一个空行分隔", + "detail": "缺陷类型:模块级别的函数和类定义之间没有用两个空行分隔,类内部的方法定义之间没有用一个空行分隔;修复方案:按照规范添加空行。", + "language": "Python", + "yes_example": "### 被判定为'模块级别的函数和类定义之间没用两个空行分隔,类内部的方法定义之间没用一个空行分隔'的例子\n<例子1>\n def func1(): \n pass \n def func2(): \n pass \n\n<例子2>\n class MyClass: \n def method1(self): \n pass \n def method2(self): \n pass \n", + "no_example": "### 不能被判定为'模块级别的函数和类定义之间没用两个空行分隔,类内部的方法定义之间没用一个空行分隔'的例子\n<例子1>\n def func1(): \n pass \n\n\n def func2(): \n pass \n\n<例子2>\n class MyClass: \n def method1(self): \n pass \n\n def method2(self): \n pass \n" + }, + { + "id": 81, + "text": "使用小写字母和下划线分隔的方式命名变量和函数名", + "detail": "缺陷类型:变量和函数命名不符合小写字母和下划线分隔的方式;修复方案:使用小写字母和下划线分隔的方式命名。", + "language": "Python", + "yes_example": "### 被判定为'未使用小写字母和下划线分隔的方式命名变量和函数'的例子\n<例子1>\n def myFunction(): \n pass \n\n<例子2>\n myVariable = 10 \n\n<例子3>\n def Calculatesquareroot(self, x): \n return 1 \n", + "no_example": "### 不能被判定为'未使用小写字母和下划线分隔的方式命名变量和函数'的例子\n<例子1>\n def my_function(): \n pass \n\n<例子2>\n my_variable = 10 \n" + }, + { + "id": 82, + "text": "不允许使用print()函数来记录日志,使用logging模块等来记录日志", + "detail": "缺陷类型:使用print()函数记录日志;修复方案:使用logging模块记录日志。", + "language": "Python", + "yes_example": "### 被判定为'使用print()函数来记录日志'的例子\n<例子1>\n print('Error occurred') \n\n<例子2>\n print('打印的日志字符串内容') \n\n<例子3>\n task = 'xxx' \n print(task) \n\n<例子4>\n print(1)\n", + "no_example": "### 不能被判定为'使用print()函数来记录日志'的例子\n<例子1>\n import logging \n logging.error('Error occurred') \n" + } +] diff --git a/metagpt/ext/cr/utils/__init__.py b/metagpt/ext/cr/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/cr/utils/__pycache__/__init__.cpython-310.pyc b/metagpt/ext/cr/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e67dccce8a432296cb07ccb8d74c97c2d6a900 Binary files /dev/null and b/metagpt/ext/cr/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/metagpt/ext/cr/utils/__pycache__/__init__.cpython-39.pyc b/metagpt/ext/cr/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c9ac04ff31c3a5629f624b06f65b898b67f05b8 Binary files /dev/null and b/metagpt/ext/cr/utils/__pycache__/__init__.cpython-39.pyc differ diff --git 
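The JSON list above completes the review-rule catalogue that the `cr` extension consumes, and the `Point` model added in `metagpt/ext/cr/utils/schema.py` further below mirrors its fields one-to-one (`id`, `text`, `detail`, `language`, `yes_example`, `no_example`). A minimal loading sketch, assuming the catalogue is stored as a standalone JSON file — the `points.json` path and both helper names are illustrative, not part of this patch:

```python
import json
from pathlib import Path

from metagpt.ext.cr.utils.schema import Point  # model defined in schema.py below


def load_points(path: str) -> list[Point]:
    """Validate the raw rule JSON into Point models (the path is hypothetical)."""
    raw = json.loads(Path(path).read_text(encoding="utf-8"))
    return [Point(**item) for item in raw]


def points_for(points: list[Point], language: str) -> list[Point]:
    # Each rule is tagged "Python" or "Java"; filter before building prompts.
    return [p for p in points if p.language == language]


if __name__ == "__main__":
    for point in points_for(load_points("points.json"), "Python"):
        print(point.id, point.rag_key())  # rag_key() returns the rule text
```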
a/metagpt/ext/cr/utils/__pycache__/cleaner.cpython-310.pyc b/metagpt/ext/cr/utils/__pycache__/cleaner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dba8f37cfe025da1c6d22c0595ffb6e9da1817e3 Binary files /dev/null and b/metagpt/ext/cr/utils/__pycache__/cleaner.cpython-310.pyc differ diff --git a/metagpt/ext/cr/utils/__pycache__/cleaner.cpython-39.pyc b/metagpt/ext/cr/utils/__pycache__/cleaner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeab255e7bb260795a57a16fa45ddac64e1df93f Binary files /dev/null and b/metagpt/ext/cr/utils/__pycache__/cleaner.cpython-39.pyc differ diff --git a/metagpt/ext/cr/utils/__pycache__/schema.cpython-310.pyc b/metagpt/ext/cr/utils/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c45cb7d8a847334097502f797921c16226dc803d Binary files /dev/null and b/metagpt/ext/cr/utils/__pycache__/schema.cpython-310.pyc differ diff --git a/metagpt/ext/cr/utils/__pycache__/schema.cpython-39.pyc b/metagpt/ext/cr/utils/__pycache__/schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b165068623846c4f51bb4f71036027ae196c9b43 Binary files /dev/null and b/metagpt/ext/cr/utils/__pycache__/schema.cpython-39.pyc differ diff --git a/metagpt/ext/cr/utils/cleaner.py b/metagpt/ext/cr/utils/cleaner.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc0b798ca39e795f852fe783002f6a58f1c3ae8 --- /dev/null +++ b/metagpt/ext/cr/utils/cleaner.py @@ -0,0 +1,68 @@ +"""Cleaner.""" + +from unidiff import Hunk, PatchedFile, PatchSet + +from metagpt.logs import logger + + +def rm_patch_useless_part(patch: PatchSet, used_suffix: list[str] = ["java", "py"]) -> PatchSet: + new_patch = PatchSet("") + useless_files = [] + for pfile in patch: + suffix = str(pfile.target_file).split(".")[-1] + if suffix not in used_suffix or pfile.is_removed_file: + useless_files.append(pfile.path) + continue + new_patch.append(pfile) + logger.info(f"total file num: {len(patch)}, used file num: {len(new_patch)}, useless_files: {useless_files}") + return new_patch + + +def add_line_num_on_patch(patch: PatchSet, start_line_num: int = 1) -> PatchSet: + new_patch = PatchSet("") + lineno = start_line_num + for pfile in patch: + new_pfile = PatchedFile( + source=pfile.source_file, + target=pfile.target_file, + source_timestamp=pfile.source_timestamp, + target_timestamp=pfile.target_timestamp, + ) + for hunk in pfile: + arr = [str(line) for line in hunk] + new_hunk = Hunk( + src_start=hunk.source_start, + src_len=hunk.source_length, + tgt_start=hunk.target_start, + tgt_len=hunk.target_length, + section_header=hunk.section_header, + ) + + for line in arr: + # if len(line) > 0 and line[0] in ["+", "-"]: + # line = f"{lineno} {line}" + # lineno += 1 + line = f"{lineno} {line}" + lineno += 1 + new_hunk.append(line) + new_pfile.append(new_hunk) + new_patch.append(new_pfile) + return new_patch + + +def get_code_block_from_patch(patch: PatchSet, code_start_line: str, code_end_line: str) -> str: + line_arr = str(patch).split("\n") + code_arr = [] + add_line_tag = False + for line in line_arr: + if line.startswith(f"{code_start_line} "): + add_line_tag = True + + if add_line_tag: + new_line = " ".join(line.split(" ")[1:]) # rm line-no tag + code_arr.append(new_line) + + if line.startswith(f"{code_end_line} "): + add_line_tag = False + + return "\n".join(code_arr) diff --git a/metagpt/ext/cr/utils/schema.py b/metagpt/ext/cr/utils/schema.py new file mode 100644 
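That completes `cleaner.py`: the three helpers compose as parse → filter → number → extract, which is presumably how the CR pipeline quotes exact code ranges back out of a patch. A usage sketch against a toy diff (the diff text and the `"1"`/`"3"` line tags are invented for illustration):

```python
from unidiff import PatchSet

from metagpt.ext.cr.utils.cleaner import (
    add_line_num_on_patch,
    get_code_block_from_patch,
    rm_patch_useless_part,
)

# Toy unified diff, invented purely for this example.
DIFF_TEXT = """\
diff --git a/example.py b/example.py
--- a/example.py
+++ b/example.py
@@ -1,2 +1,3 @@
 import os
+import sys
 print(os.name)
"""

patch = PatchSet(DIFF_TEXT)           # unidiff accepts the raw diff string
patch = rm_patch_useless_part(patch)  # keeps .py/.java files, drops deleted files
patch = add_line_num_on_patch(patch)  # prefixes every hunk line: "1 ", "2 ", ...
# Pull hunk lines 1-3 back out, with the number tags stripped again.
print(get_code_block_from_patch(patch, "1", "3"))
```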
index 0000000000000000000000000000000000000000..beb27a07f9e4606175eb97fbd983aa707198b764 --- /dev/null +++ b/metagpt/ext/cr/utils/schema.py @@ -0,0 +1,20 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class Point(BaseModel): + id: int = Field(default=0, description="ID of the point.") + text: str = Field(default="", description="Content of the point.") + language: Literal["Python", "Java"] = Field( + default="Python", description="The programming language that the point corresponds to." + ) + file_path: str = Field(default="", description="The file that the points come from.") + start_line: int = Field(default=0, description="The starting line number that the point refers to.") + end_line: int = Field(default=0, description="The ending line number that the point refers to.") + detail: str = Field(default="", description="File content from start_line to end_line.") + yes_example: str = Field(default="", description="yes of point examples") + no_example: str = Field(default="", description="no of point examples") + + def rag_key(self) -> str: + return self.text diff --git a/metagpt/ext/sela/.DS_Store b/metagpt/ext/sela/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..cfa820e329ec9bfd67d4e22e80713bc7ac0b4404 Binary files /dev/null and b/metagpt/ext/sela/.DS_Store differ diff --git a/metagpt/ext/sela/README.md b/metagpt/ext/sela/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6fb47b42cdc20469f0b4e4df182b59b2fd72d622 --- /dev/null +++ b/metagpt/ext/sela/README.md @@ -0,0 +1,106 @@ +# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning + + +Official implementation for paper [SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning](https://arxiv.org/abs/2410.17238). + + +SELA is an innovative system that enhances Automated Machine Learning (AutoML) by integrating Monte Carlo Tree Search (MCTS) with LLM-based agents. Traditional AutoML methods often generate low-diversity and suboptimal code, limiting their effectiveness in model selection and ensembling. SELA addresses these challenges by representing pipeline configurations as trees, enabling agents to intelligently explore the solution space and iteratively refine their strategies based on experimental feedback. + +## 1. Data Preparation + +You can either download the datasets from the link or prepare the datasets from scratch. +- **Download Datasets:** [Dataset Link](https://drive.google.com/drive/folders/151FIZoLygkRfeJgSI9fNMiLsixh1mK0r?usp=sharing) +- **Download and prepare datasets from scratch:** + ```bash + cd data + python dataset.py --save_analysis_pool + python hf_data.py --save_analysis_pool + ``` + +## 2. Configurations + +### Data Config + +- **`datasets.yaml`:** Provide base prompts, metrics, and target columns for respective datasets. +- **`data.yaml`:** Modify `datasets_dir` to the base directory of all prepared datasets. + +### LLM Config + +```yaml +llm: + api_type: 'openai' + model: deepseek-coder + base_url: "https://your_base_url" + api_key: sk-xxx + temperature: 0.5 +``` + + +## 3. SELA + +### Run SELA + +#### Setup + +```bash +pip install -e . + +cd metagpt/ext/sela + +pip install -r requirements.txt +``` + +#### Running Experiments + +- **Examples:** + ```bash + python run_experiment.py --exp_mode mcts --task titanic --rollouts 10 + python run_experiment.py --exp_mode mcts --task house-prices --rollouts 10 --low_is_better + ``` + +#### Parameters + +- **`--rollouts`:** The number of rollouts. 
+- **`--use_fixed_insights`:** Include fixed insights saved in `expo/insights/fixed_insights.json`. +- **`--low_is_better`:** Use this if the dataset has a regression metric. +- **`--from_scratch`:** Generate a new insight pool based on the dataset before running MCTS. +- **`--role_timeout`:** Limits the duration of a single simulation (e.g., `10 rollouts with timeout 1,000` = max 10,000s). +- **`--max_depth`:** Set the maximum depth of MCTS (default is 4). +- **`--load_tree`:** Load an existing MCTS tree if the previous experiment was interrupted. + - Example: + ```bash + python run_experiment.py --exp_mode mcts --task titanic --rollouts 10 + ``` + - To resume: + ```bash + python run_experiment.py --exp_mode mcts --task titanic --rollouts 7 --load_tree + ``` + +### Ablation Study + +**RandomSearch** + +- **Use a single insight:** + ```bash + python run_experiment.py --exp_mode rs --task titanic --rs_mode single + ``` + +- **Use a set of insights:** + ```bash + python run_experiment.py --exp_mode rs --task titanic --rs_mode set + ``` + +## 4. Citation +Please cite our paper if you use SELA or find it cool or useful! + +```bibtex +@misc{chi2024selatreesearchenhancedllm, + title={SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning}, + author={Yizhou Chi and Yizhang Lin and Sirui Hong and Duyi Pan and Yaying Fei and Guanghao Mei and Bangbang Liu and Tianqi Pang and Jacky Kwok and Ceyao Zhang and Bang Liu and Chenglin Wu}, + year={2024}, + eprint={2410.17238}, + archivePrefix={arXiv}, + primaryClass={cs.AI}, + url={https://arxiv.org/abs/2410.17238}, +} +``` diff --git a/metagpt/ext/sela/data.yaml b/metagpt/ext/sela/data.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7da5dbb3cd73a69708ce533eba014f77fa937850 --- /dev/null +++ b/metagpt/ext/sela/data.yaml @@ -0,0 +1,3 @@ +datasets_dir: "path/to/datasets" # path to the datasets directory +work_dir: ../../../workspace # path to the workspace directory +role_dir: storage/SELA # path to the role directory \ No newline at end of file diff --git a/metagpt/ext/sela/data/custom_task.py b/metagpt/ext/sela/data/custom_task.py new file mode 100644 index 0000000000000000000000000000000000000000..08a7cbabb08f14dcb3e815616e6db2d2e1577c5b --- /dev/null +++ b/metagpt/ext/sela/data/custom_task.py @@ -0,0 +1,74 @@ +import os + +from metagpt.ext.sela.data.dataset import SPECIAL_INSTRUCTIONS +from metagpt.ext.sela.runner.mle_bench.instructions import ( + ADDITIONAL_NOTES, + INSTRUCTIONS, + INSTRUCTIONS_OBFUSCATED, +) + +MLE_BENCH_FILES = ["description.md", "description_obfuscated.md"] + + +MLE_REQUIREMENTS = """ +{instructions} + +{additonal_notes} + +COMPETITION INSTRUCTIONS +------ + +{task_description} + +## More Instructions +- You should split the training data into train and dev set with a seed of 42. +- You should use the dev set to improve your model. Print the final dev set score after training. +- output_dir: {output_dir} +- Besides `submission.csv`, you should also save your `test_predictions.csv` and `dev_predictions.csv` in the output directory. +- Note that `test_predictions.csv` should be identical to `submission.csv`. +- Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. 
{special_instruction} +**Do not make any plots or visualizations.** +""" + + +def get_mle_task_id(dataset_dir): + return dataset_dir.split("/")[-3] + + +def get_mle_is_lower_better(task): + from mlebench.data import get_leaderboard + from mlebench.registry import registry + + competition = registry.get_competition(task) + competition_leaderboard = get_leaderboard(competition) + return competition.grader.is_lower_better(competition_leaderboard) + + +def get_mle_bench_requirements(dataset_dir, data_config, special_instruction, obfuscated=False): + work_dir = data_config["work_dir"] + task = get_mle_task_id(dataset_dir) + output_dir = f"{work_dir}/{task}" + final_output_dir = f"{work_dir}/submission" + os.makedirs(output_dir, exist_ok=True) + if special_instruction: + special_instruction = SPECIAL_INSTRUCTIONS[special_instruction] + else: + special_instruction = "" + if obfuscated: + instructions = INSTRUCTIONS_OBFUSCATED.format(dataset_dir=dataset_dir, output_dir=final_output_dir) + task_file = "description_obfuscated.md" + else: + instructions = INSTRUCTIONS.format(dataset_dir=dataset_dir, output_dir=output_dir) + task_file = "description.md" + + with open(os.path.join(dataset_dir, task_file), encoding="utf-8") as f: + task_description = f.read() + mle_requirement = MLE_REQUIREMENTS.format( + instructions=instructions, + additonal_notes=ADDITIONAL_NOTES, + task_description=task_description, + output_dir=output_dir, + special_instruction=special_instruction, + ) + print(mle_requirement) + return mle_requirement diff --git a/metagpt/ext/sela/data/dataset.py b/metagpt/ext/sela/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ef41790117395455ff5e649f4cfa25bd2da31769 --- /dev/null +++ b/metagpt/ext/sela/data/dataset.py @@ -0,0 +1,395 @@ +import argparse +import asyncio +import json +import os +from pathlib import Path + +import openml +import pandas as pd +import yaml +from sklearn.model_selection import train_test_split + +from metagpt.ext.sela.insights.solution_designer import SolutionDesigner +from metagpt.ext.sela.utils import DATA_CONFIG + +BASE_USER_REQUIREMENT = """ +This is a {datasetname} dataset. Your goal is to predict the target column `{target_col}`. +Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. +Report {metric} on the eval data. Do not plot or make any visualizations. +""" + +USE_AG = """ +- Please use autogluon for model training with presets='medium_quality', time_limit=None, give dev dataset to tuning_data, and use right eval_metric. +""" + +TEXT_MODALITY = """ +- You could use models from transformers library for this text dataset. +- Use gpu if available for faster training. +""" + +IMAGE_MODALITY = """ +- You could use models from transformers/torchvision library for this image dataset. +- Use gpu if available for faster training. +""" + +STACKING = """ +- To avoid overfitting, train a weighted ensemble model such as StackingClassifier or StackingRegressor. +- You could do some quick model prototyping to see which models work best and then use them in the ensemble. +""" + + +SPECIAL_INSTRUCTIONS = {"ag": USE_AG, "stacking": STACKING, "text": TEXT_MODALITY, "image": IMAGE_MODALITY} + +DI_INSTRUCTION = """ +## Attention +1. Please do not leak the target label in any form during training. +2. Test set does not have the target column. +3. When conducting data exploration or analysis, print out the results of your findings. +4. 
You should perform transformations on train, dev, and test sets at the same time (it's a good idea to define functions for this and avoid code repetition). +5. When scaling or transforming features, make sure the target column is not included. +6. You could utilize dev set to validate and improve model training. {special_instruction} + +## Saving Dev and Test Predictions +1. Save the prediction results of BOTH the dev set and test set in `dev_predictions.csv` and `test_predictions.csv` respectively in the output directory. +- Both files should contain a single column named `target` with the predicted values. +2. Make sure the prediction results are in the same format as the target column in the original training set. +- For instance, if the original target column is a list of string, the prediction results should also be strings. + +## Output Performance +Print the train and dev set performance in the last step. + +# Output dir +{output_dir} +""" + +TASK_PROMPT = """ +# User requirement +{user_requirement} +{additional_instruction} +# Data dir +train set (with labels): {train_path} +dev set (with labels): {dev_path} +test set (without labels): {test_path} +dataset description: {data_info_path} (During EDA, you can use this file to get additional information about the dataset) +""" + + +SEED = 100 +TRAIN_TEST_SPLIT = 0.8 +TRAIN_DEV_SPLIT = 0.75 + +OPENML_DATASET_IDS = [ + # reg + 41021, + 42727, + 41980, + 42225, + 531, + # cls + 41143, + 31, + 42733, + 41162, + 1067, + # multi cls + 40498, + 40982, + 12, + 40984, + 4538, +] + +CUSTOM_DATASETS = [ + ("04_titanic", "Survived"), + ("05_house-prices-advanced-regression-techniques", "SalePrice"), + ("06_santander-customer-transaction-prediction", "target"), + ("07_icr-identify-age-related-conditions", "Class"), +] + +DSAGENT_DATASETS = [("concrete-strength", "Strength"), ("smoker-status", "smoking"), ("software-defects", "defects")] + + +def get_split_dataset_path(dataset_name, config): + datasets_dir = config["datasets_dir"] + if dataset_name in config["datasets"]: + dataset = config["datasets"][dataset_name] + data_path = os.path.join(datasets_dir, dataset["dataset"]) + split_datasets = { + "train": os.path.join(data_path, "split_train.csv"), + "dev": os.path.join(data_path, "split_dev.csv"), + "dev_wo_target": os.path.join(data_path, "split_dev_wo_target.csv"), + "dev_target": os.path.join(data_path, "split_dev_target.csv"), + "test": os.path.join(data_path, "split_test.csv"), + "test_wo_target": os.path.join(data_path, "split_test_wo_target.csv"), + "test_target": os.path.join(data_path, "split_test_target.csv"), + } + return split_datasets + else: + raise ValueError( + f"Dataset {dataset_name} not found in config file. Available datasets: {config['datasets'].keys()}" + ) + + +def get_user_requirement(task_name, config): + # datasets_dir = config["datasets_dir"] + if task_name in config["datasets"]: + dataset = config["datasets"][task_name] + # data_path = os.path.join(datasets_dir, dataset["dataset"]) + user_requirement = dataset["user_requirement"] + return user_requirement + else: + raise ValueError( + f"Dataset {task_name} not found in config file. 
Available datasets: {config['datasets'].keys()}" + ) + + +def save_datasets_dict_to_yaml(datasets_dict, name="datasets.yaml"): + with open(name, "w") as file: + yaml.dump(datasets_dict, file) + + +def create_dataset_dict(dataset): + dataset_dict = { + "dataset": dataset.name, + "user_requirement": dataset.create_base_requirement(), + "metric": dataset.get_metric(), + "target_col": dataset.target_col, + } + return dataset_dict + + +def generate_di_instruction(output_dir, special_instruction): + if special_instruction: + special_instruction_prompt = SPECIAL_INSTRUCTIONS[special_instruction] + else: + special_instruction_prompt = "" + additional_instruction = DI_INSTRUCTION.format( + output_dir=output_dir, special_instruction=special_instruction_prompt + ) + return additional_instruction + + +def generate_task_requirement(task_name, data_config, is_di=True, special_instruction=None): + user_requirement = get_user_requirement(task_name, data_config) + split_dataset_path = get_split_dataset_path(task_name, data_config) + train_path = split_dataset_path["train"] + dev_path = split_dataset_path["dev"] + test_path = split_dataset_path["test_wo_target"] + work_dir = data_config["work_dir"] + output_dir = f"{work_dir}/{task_name}" + datasets_dir = data_config["datasets_dir"] + data_info_path = f"{datasets_dir}/{task_name}/dataset_info.json" + if is_di: + additional_instruction = generate_di_instruction(output_dir, special_instruction) + else: + additional_instruction = "" + user_requirement = TASK_PROMPT.format( + user_requirement=user_requirement, + train_path=train_path, + dev_path=dev_path, + test_path=test_path, + additional_instruction=additional_instruction, + data_info_path=data_info_path, + ) + print(user_requirement) + return user_requirement + + +class ExpDataset: + description: str = None + metadata: dict = None + dataset_dir: str = None + target_col: str = None + name: str = None + + def __init__(self, name, dataset_dir, **kwargs): + self.name = name + self.dataset_dir = dataset_dir + self.target_col = kwargs.get("target_col", None) + self.force_update = kwargs.get("force_update", False) + self.save_dataset(target_col=self.target_col) + + def check_dataset_exists(self): + fnames = [ + "split_train.csv", + "split_dev.csv", + "split_test.csv", + "split_dev_wo_target.csv", + "split_dev_target.csv", + "split_test_wo_target.csv", + "split_test_target.csv", + ] + for fname in fnames: + if not os.path.exists(Path(self.dataset_dir, self.name, fname)): + return False + return True + + def check_datasetinfo_exists(self): + return os.path.exists(Path(self.dataset_dir, self.name, "dataset_info.json")) + + def get_raw_dataset(self): + raw_dir = Path(self.dataset_dir, self.name, "raw") + train_df = None + test_df = None + if not os.path.exists(Path(raw_dir, "train.csv")): + raise FileNotFoundError(f"Raw dataset `train.csv` not found in {raw_dir}") + else: + train_df = pd.read_csv(Path(raw_dir, "train.csv")) + if os.path.exists(Path(raw_dir, "test.csv")): + test_df = pd.read_csv(Path(raw_dir, "test.csv")) + return train_df, test_df + + def get_dataset_info(self): + raw_df = pd.read_csv(Path(self.dataset_dir, self.name, "raw", "train.csv")) + metadata = { + "NumberOfClasses": raw_df[self.target_col].nunique(), + "NumberOfFeatures": raw_df.shape[1], + "NumberOfInstances": raw_df.shape[0], + "NumberOfInstancesWithMissingValues": int(raw_df.isnull().any(axis=1).sum()), + "NumberOfMissingValues": int(raw_df.isnull().sum().sum()), + "NumberOfNumericFeatures": 
raw_df.select_dtypes(include=["number"]).shape[1], + "NumberOfSymbolicFeatures": raw_df.select_dtypes(include=["object"]).shape[1], + } + + df_head_text = self.get_df_head(raw_df) + + dataset_info = { + "name": self.name, + "description": "", + "target_col": self.target_col, + "metadata": metadata, + "df_head": df_head_text, + } + return dataset_info + + def get_df_head(self, raw_df): + return raw_df.head().to_string(index=False) + + def get_metric(self): + dataset_info = self.get_dataset_info() + num_classes = dataset_info["metadata"]["NumberOfClasses"] + if num_classes == 2: + metric = "f1 binary" + elif 2 < num_classes <= 200: + metric = "f1 weighted" + elif num_classes > 200 or num_classes == 0: + metric = "rmse" + else: + raise ValueError(f"Number of classes {num_classes} not supported") + return metric + + def create_base_requirement(self): + metric = self.get_metric() + req = BASE_USER_REQUIREMENT.format(datasetname=self.name, target_col=self.target_col, metric=metric) + return req + + def save_dataset(self, target_col): + df, test_df = self.get_raw_dataset() + if not self.check_dataset_exists() or self.force_update: + print(f"Saving Dataset {self.name} in {self.dataset_dir}") + self.split_and_save(df, target_col, test_df=test_df) + else: + print(f"Dataset {self.name} already exists") + if not self.check_datasetinfo_exists() or self.force_update: + print(f"Saving Dataset info for {self.name}") + dataset_info = self.get_dataset_info() + self.save_datasetinfo(dataset_info) + else: + print(f"Dataset info for {self.name} already exists") + + def save_datasetinfo(self, dataset_info): + with open(Path(self.dataset_dir, self.name, "dataset_info.json"), "w", encoding="utf-8") as file: + # utf-8 encoding is required + json.dump(dataset_info, file, indent=4, ensure_ascii=False) + + def save_split_datasets(self, df, split, target_col=None): + path = Path(self.dataset_dir, self.name) + df.to_csv(Path(path, f"split_{split}.csv"), index=False) + if target_col: + df_wo_target = df.drop(columns=[target_col]) + df_wo_target.to_csv(Path(path, f"split_{split}_wo_target.csv"), index=False) + df_target = df[[target_col]].copy() + if target_col != "target": + df_target["target"] = df_target[target_col] + df_target = df_target.drop(columns=[target_col]) + df_target.to_csv(Path(path, f"split_{split}_target.csv"), index=False) + + def split_and_save(self, df, target_col, test_df=None): + if not target_col: + raise ValueError("Target column not provided") + if test_df is None: + train, test = train_test_split(df, test_size=1 - TRAIN_TEST_SPLIT, random_state=SEED) + else: + train = df + test = test_df + train, dev = train_test_split(train, test_size=1 - TRAIN_DEV_SPLIT, random_state=SEED) + self.save_split_datasets(train, "train") + self.save_split_datasets(dev, "dev", target_col) + self.save_split_datasets(test, "test", target_col) + + +class OpenMLExpDataset(ExpDataset): + def __init__(self, name, dataset_dir, dataset_id, **kwargs): + self.dataset_id = dataset_id + self.dataset = openml.datasets.get_dataset( + self.dataset_id, download_data=False, download_qualities=False, download_features_meta_data=True + ) + self.name = self.dataset.name + self.target_col = self.dataset.default_target_attribute + super().__init__(self.name, dataset_dir, target_col=self.target_col, **kwargs) + + def get_raw_dataset(self): + dataset = self.dataset + dataset_df, *_ = dataset.get_data() + raw_dir = Path(self.dataset_dir, self.name, "raw") + os.makedirs(raw_dir, exist_ok=True) + dataset_df.to_csv(Path(raw_dir, "train.csv"), 
index=False) + return dataset_df, None + + def get_dataset_info(self): + dataset_info = super().get_dataset_info() + dataset = self.dataset + dataset_info["name"] = dataset.name + dataset_info["description"] = dataset.description + dataset_info["metadata"].update(dataset.qualities) + return dataset_info + + +async def process_dataset(dataset, solution_designer: SolutionDesigner, save_analysis_pool, datasets_dict): + if save_analysis_pool: + await solution_designer.generate_solutions(dataset.get_dataset_info(), dataset.name) + dataset_dict = create_dataset_dict(dataset) + datasets_dict["datasets"][dataset.name] = dataset_dict + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--force_update", action="store_true", help="Force update datasets") + parser.add_argument("--save_analysis_pool", action="store_true", help="Save analysis pool") + parser.add_argument( + "--no_save_analysis_pool", dest="save_analysis_pool", action="store_false", help="Do not save analysis pool" + ) + parser.set_defaults(save_analysis_pool=True) + return parser.parse_args() + + +if __name__ == "__main__": + datasets_dir = DATA_CONFIG["datasets_dir"] + args = parse_args() + force_update = args.force_update + save_analysis_pool = args.save_analysis_pool + datasets_dict = {"datasets": {}} + solution_designer = SolutionDesigner() + for dataset_id in OPENML_DATASET_IDS: + openml_dataset = OpenMLExpDataset("", datasets_dir, dataset_id, force_update=force_update) + asyncio.run(process_dataset(openml_dataset, solution_designer, save_analysis_pool, datasets_dict)) + + for dataset_name, target_col in CUSTOM_DATASETS: + custom_dataset = ExpDataset(dataset_name, datasets_dir, target_col=target_col, force_update=force_update) + asyncio.run(process_dataset(custom_dataset, solution_designer, save_analysis_pool, datasets_dict)) + + for dataset_name, target_col in DSAGENT_DATASETS: + custom_dataset = ExpDataset(dataset_name, datasets_dir, target_col=target_col, force_update=force_update) + asyncio.run(process_dataset(custom_dataset, solution_designer, save_analysis_pool, datasets_dict)) + + save_datasets_dict_to_yaml(datasets_dict) diff --git a/metagpt/ext/sela/data/hf_data.py b/metagpt/ext/sela/data/hf_data.py new file mode 100644 index 0000000000000000000000000000000000000000..9645796af5ff0d587024697364ca34de0bbcb882 --- /dev/null +++ b/metagpt/ext/sela/data/hf_data.py @@ -0,0 +1,140 @@ +import asyncio +import io +import os +from pathlib import Path + +import pandas as pd +from datasets import load_dataset +from PIL import Image + +from metagpt.ext.sela.data.dataset import ( + ExpDataset, + parse_args, + process_dataset, + save_datasets_dict_to_yaml, +) +from metagpt.ext.sela.insights.solution_designer import SolutionDesigner +from metagpt.ext.sela.utils import DATA_CONFIG + +HFDATSETS = [ + {"name": "sms_spam", "dataset_name": "ucirvine/sms_spam", "target_col": "label", "modality": "text"}, + {"name": "banking77", "dataset_name": "PolyAI/banking77", "target_col": "label", "modality": "text"}, + {"name": "gnad10", "dataset_name": "community-datasets/gnad10", "target_col": "label", "modality": "text"}, + { + "name": "oxford-iiit-pet", + "dataset_name": "timm/oxford-iiit-pet", + "image_col": "image", + "target_col": "label", + "modality": "image", + }, + { + "name": "stanford_cars", + "dataset_name": "tanganke/stanford_cars", + "image_col": "image", + "target_col": "label", + "modality": "image", + }, + { + "name": "fashion_mnist", + "dataset_name": "zalando-datasets/fashion_mnist", + "image_col": 
"image", + "target_col": "label", + "modality": "image", + }, +] + + +class HFExpDataset(ExpDataset): + train_ratio = 0.6 + dev_ratio = 0.2 + test_ratio = 0.2 + + def __init__(self, name, dataset_dir, dataset_name, **kwargs): + self.name = name + self.dataset_dir = dataset_dir + self.dataset_name = dataset_name + self.modality = kwargs.get("modality", "") + self.target_col = kwargs.get("target_col", "label") + self.image_col = kwargs.get("image_col", "image") + self.dataset = load_dataset(self.dataset_name, trust_remote_code=True) + super().__init__(self.name, dataset_dir, **kwargs) + + def get_raw_dataset(self): + raw_dir = Path(self.dataset_dir, self.name, "raw") + raw_dir.mkdir(parents=True, exist_ok=True) + + if os.path.exists(Path(raw_dir, "train.csv")): + df = pd.read_csv(Path(raw_dir, "train.csv"), encoding="utf-8") + else: + df = self.dataset["train"].to_pandas() + + if self.modality == "image": + df = self.save_images_and_update_df(df, raw_dir, "train") + + df.to_csv(Path(raw_dir, "train.csv"), index=False, encoding="utf-8") + + if os.path.exists(Path(raw_dir, "test.csv")): + test_df = pd.read_csv(Path(raw_dir, "test.csv"), encoding="utf-8") + else: + if self.dataset and "test" in self.dataset: + test_df = self.dataset["test"].to_pandas() + + if self.modality == "image": + test_df = self.save_images_and_update_df(test_df, raw_dir, "test") + + test_df.to_csv(Path(raw_dir, "test.csv"), index=False, encoding="utf-8") + else: + test_df = None + + return df, test_df + + def save_images_and_update_df(self, df, raw_dir, split): + abs_image_dir = Path(raw_dir, f"{split}_images") + rel_image_dir = f"raw/{split}_images" + abs_image_dir.mkdir(parents=True, exist_ok=True) + + def process_image(idx, row): + image_bytes = row[self.image_col]["bytes"] + image = Image.open(io.BytesIO(image_bytes)) + if image.mode == "RGBA": + image = image.convert("RGB") + img_path = Path(abs_image_dir, f"{idx}.jpg") + rel_img_path = f"{rel_image_dir}/{idx}.jpg" + image.save(img_path) + return rel_img_path + + df["image"] = df.apply(lambda row: process_image(row.name, row), axis=1) + return df + + def get_df_head(self, raw_df): + examples = [] + for i in range(5): + examples.append(raw_df.iloc[i].to_dict()) + return examples + + def get_dataset_info(self): + dataset_info = super().get_dataset_info() + dataset = self.dataset + dataset_info["description"] = dataset["train"].info.description + return dataset_info + + +if __name__ == "__main__": + dataset_dir = DATA_CONFIG["datasets_dir"] + args = parse_args() + force_update = args.force_update + save_analysis_pool = args.save_analysis_pool + datasets_dict = {"datasets": {}} + solution_designer = SolutionDesigner() + for dataset_meta in HFDATSETS: + hf_dataset = HFExpDataset( + dataset_meta["name"], + dataset_dir, + dataset_meta["dataset_name"], + target_col=dataset_meta["target_col"], + image_col=dataset_meta.get("image_col", ""), + force_update=force_update, + modality=dataset_meta["modality"], + ) + asyncio.run(process_dataset(hf_dataset, solution_designer, save_analysis_pool, datasets_dict)) + save_datasets_dict_to_yaml(datasets_dict, "hf_datasets.yaml") diff --git a/metagpt/ext/sela/datasets.yaml b/metagpt/ext/sela/datasets.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d02951d4d2e2ec73bc7b46e417479a3c7a52998 --- /dev/null +++ b/metagpt/ext/sela/datasets.yaml @@ -0,0 +1,225 @@ +datasets: + titanic: + dataset: 04_titanic + metric: f1 + target_col: Survived + user_requirement: "This is a 04_titanic dataset. 
Your goal is to predict the target\ + \ column `Survived`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ + \ or make any visualizations.\n" + house-prices: + dataset: 05_house-prices-advanced-regression-techniques + metric: rmse + target_col: SalePrice + user_requirement: "This is a 05_house-prices-advanced-regression-techniques dataset.\ + \ Your goal is to predict the target column `SalePrice`.\nPerform data analysis,\ + \ data preprocessing, feature engineering, and modeling to predict the target.\ + \ \nReport rmse on the eval data. Do not plot or make any visualizations.\n" + santander-customer: + dataset: 06_santander-customer-transaction-prediction + metric: f1 + target_col: target + user_requirement: "This is a 06_santander-customer-transaction-prediction dataset.\ + \ Your goal is to predict the target column `target`.\nPerform data analysis,\ + \ data preprocessing, feature engineering, and modeling to predict the target.\ + \ \nReport f1 on the eval data. Do not plot or make any visualizations.\n" + icr: + dataset: 07_icr-identify-age-related-conditions + metric: f1 + target_col: Class + user_requirement: "This is a 07_icr-identify-age-related-conditions dataset. Your\ + \ goal is to predict the target column `Class`.\nPerform data analysis, data\ + \ preprocessing, feature engineering, and modeling to predict the target. \n\ + Report f1 on the eval data. Do not plot or make any visualizations.\n" + Click_prediction_small: + dataset: Click_prediction_small + metric: f1 + target_col: click + user_requirement: "This is a Click_prediction_small dataset. Your goal is to predict\ + \ the target column `click`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ + \ Do not plot or make any visualizations.\n" + GesturePhaseSegmentationProcessed: + dataset: GesturePhaseSegmentationProcessed + metric: f1 weighted + target_col: Phase + user_requirement: "This is a GesturePhaseSegmentationProcessed dataset. Your goal\ + \ is to predict the target column `Phase`.\nPerform data analysis, data preprocessing,\ + \ feature engineering, and modeling to predict the target. \nReport f1 weighted\ + \ on the eval data. Do not plot or make any visualizations.\n" + Moneyball: + dataset: Moneyball + metric: rmse + target_col: RS + user_requirement: "This is a Moneyball dataset. Your goal is to predict the target\ + \ column `RS`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ + \ plot or make any visualizations.\n" + SAT11-HAND-runtime-regression: + dataset: SAT11-HAND-runtime-regression + metric: rmse + target_col: runtime + user_requirement: "This is a SAT11-HAND-runtime-regression dataset. Your goal\ + \ is to predict the target column `runtime`.\nPerform data analysis, data preprocessing,\ + \ feature engineering, and modeling to predict the target. \nReport rmse on\ + \ the eval data. Do not plot or make any visualizations.\n" + boston: + dataset: boston + metric: rmse + target_col: MEDV + user_requirement: "This is a boston dataset. Your goal is to predict the target\ + \ column `MEDV`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport rmse on the eval data. 
Do not\ + \ plot or make any visualizations.\n" + colleges: + dataset: colleges + metric: rmse + target_col: percent_pell_grant + user_requirement: "This is a colleges dataset. Your goal is to predict the target\ + \ column `percent_pell_grant`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport rmse on the eval\ + \ data. Do not plot or make any visualizations.\n" + concrete-strength: + dataset: concrete-strength + metric: rmse + target_col: Strength + user_requirement: "This is a concrete-strength dataset. Your goal is to predict\ + \ the target column `Strength`.\nPerform data analysis, data preprocessing,\ + \ feature engineering, and modeling to predict the target. \nReport rmse on\ + \ the eval data. Do not plot or make any visualizations.\n" + credit-g: + dataset: credit-g + metric: f1 + target_col: class + user_requirement: "This is a credit-g dataset. Your goal is to predict the target\ + \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ + \ or make any visualizations.\n" + diamonds: + dataset: diamonds + metric: rmse + target_col: price + user_requirement: "This is a diamonds dataset. Your goal is to predict the target\ + \ column `price`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport rmse on the eval data. Do not\ + \ plot or make any visualizations.\n" + jasmine: + dataset: jasmine + metric: f1 + target_col: class + user_requirement: "This is a jasmine dataset. Your goal is to predict the target\ + \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ + \ or make any visualizations.\n" + kc1: + dataset: kc1 + metric: f1 + target_col: defects + user_requirement: "This is a kc1 dataset. Your goal is to predict the target column\ + \ `defects`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ + \ or make any visualizations.\n" + kick: + dataset: kick + metric: f1 + target_col: IsBadBuy + user_requirement: "This is a kick dataset. Your goal is to predict the target\ + \ column `IsBadBuy`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ + \ or make any visualizations.\n" + mfeat-factors: + dataset: mfeat-factors + metric: f1 weighted + target_col: class + user_requirement: "This is a mfeat-factors dataset. Your goal is to predict the\ + \ target column `class`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ + \ eval data. Do not plot or make any visualizations.\n" + segment: + dataset: segment + metric: f1 weighted + target_col: class + user_requirement: "This is a segment dataset. Your goal is to predict the target\ + \ column `class`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 weighted on the eval data.\ + \ Do not plot or make any visualizations.\n" + smoker-status: + dataset: smoker-status + metric: f1 + target_col: smoking + user_requirement: "This is a smoker-status dataset. 
Your goal is to predict the\ + \ target column `smoking`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ + \ Do not plot or make any visualizations.\n" + software-defects: + dataset: software-defects + metric: f1 + target_col: defects + user_requirement: "This is a software-defects dataset. Your goal is to predict\ + \ the target column `defects`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 on the eval data.\ + \ Do not plot or make any visualizations.\n" + steel-plates-fault: + dataset: steel-plates-fault + metric: f1 weighted + target_col: target + user_requirement: "This is a steel-plates-fault dataset. Your goal is to predict\ + \ the target column `target`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ + \ eval data. Do not plot or make any visualizations.\n" + wine-quality-white: + dataset: wine-quality-white + metric: f1 weighted + target_col: Class + user_requirement: "This is a wine-quality-white dataset. Your goal is to predict\ + \ the target column `Class`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ + \ eval data. Do not plot or make any visualizations.\n" + banking77: + dataset: banking77 + metric: f1 weighted + target_col: label + user_requirement: "This is a banking77 dataset. Your goal is to predict the target\ + \ column `label`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 weighted on the eval data.\ + \ Do not plot or make any visualizations.\n" + fashion_mnist: + dataset: fashion_mnist + metric: f1 weighted + target_col: label + user_requirement: "This is a fashion_mnist dataset. Your goal is to predict the\ + \ target column `label`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ + \ eval data. Do not plot or make any visualizations.\n" + gnad10: + dataset: gnad10 + metric: f1 weighted + target_col: label + user_requirement: "This is a gnad10 dataset. Your goal is to predict the target\ + \ column `label`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 weighted on the eval data.\ + \ Do not plot or make any visualizations.\n" + oxford-iiit-pet: + dataset: oxford-iiit-pet + metric: f1 weighted + target_col: label + user_requirement: "This is a oxford-iiit-pet dataset. Your goal is to predict\ + \ the target column `label`.\nPerform data analysis, data preprocessing,\ + \ feature engineering, and modeling to predict the target. \nReport f1 weighted on the\ + \ eval data. Do not plot or make any visualizations.\n" + sms_spam: + dataset: sms_spam + metric: f1 + target_col: label + user_requirement: "This is a sms_spam dataset. Your goal is to predict the target\ + \ column `label`.\nPerform data analysis, data preprocessing, feature engineering,\ + \ and modeling to predict the target. \nReport f1 on the eval data. Do not plot\ + \ or make any visualizations.\n" + stanford_cars: + dataset: stanford_cars + metric: f1 weighted + target_col: label + user_requirement: "This is a stanford_cars dataset. 
Your goal is to predict the\ + \ target column `label`.\nPerform data analysis, data preprocessing, feature\ + \ engineering, and modeling to predict the target. \nReport f1 weighted on the\ + \ eval data. Do not plot or make any visualizations.\n" diff --git a/metagpt/ext/sela/evaluation/evaluation.py b/metagpt/ext/sela/evaluation/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..1e58e1725b455562d7d873cae1311531bc20f32e --- /dev/null +++ b/metagpt/ext/sela/evaluation/evaluation.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import numpy as np +from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, roc_auc_score + + +def evaluate_score(pred, gt, metric): + if metric == "accuracy": + return accuracy_score(gt, pred) + elif metric == "f1": + unique_classes = sorted(list(np.unique(gt))) + if 1 in unique_classes and 0 in unique_classes: + pos_label = 1 + else: + pos_label = unique_classes[0] if len(unique_classes) == 2 else None + return f1_score(gt, pred, pos_label=pos_label) + elif metric == "f1 weighted": + return f1_score(gt, pred, average="weighted") + elif metric == "roc_auc": + return roc_auc_score(gt, pred) + elif metric == "rmse": + return mean_squared_error(gt, pred, squared=False) + elif metric == "log rmse": + return mean_squared_error(np.log1p(gt), np.log1p(pred), squared=False) + else: + raise ValueError(f"Metric {metric} not supported") + + +def node_evaluate_score_sela(node): + preds = node.get_and_move_predictions("test")["target"] + gt = node.get_gt("test")["target"] + metric = node.state["dataset_config"]["metric"] + return evaluate_score(preds, gt, metric) + + +def node_evaluate_score_mlebench(node): + # TODO + from mlebench.grade import grade_csv + from mlebench.registry import registry + + competition_id = node.state["task"] + data_dir = Path(node.state["custom_dataset_dir"]).parent.parent.parent # prepared/public/../../../ + pred_path = node.get_predictions_path("test") + new_registry = registry.set_data_dir(data_dir) + competition = new_registry.get_competition(competition_id) + submission = Path(pred_path) + report = grade_csv(submission, competition).to_dict() + report["submission_path"] = str(submission) + return report diff --git a/metagpt/ext/sela/evaluation/visualize_mcts.py b/metagpt/ext/sela/evaluation/visualize_mcts.py new file mode 100644 index 0000000000000000000000000000000000000000..6f803a91cb896e36ce8f3e029aa5f6191ca4b129 --- /dev/null +++ b/metagpt/ext/sela/evaluation/visualize_mcts.py @@ -0,0 +1,163 @@ +import textwrap + +import matplotlib.pyplot as plt +import networkx as nx + +from metagpt.ext.sela.search.tree_search import Node + +NODE_TEMPLATE = """\ +[Node {id}] +Plans: +{plans} +Simulated: {simulated} +Score: {score}, Visits: {num_visits} + +""" + +NODE_SIZE = 12000 +NODE_FONT_SIZE = 18 + + +def get_role_plans(role): + plans = role.planner.plan.tasks + instruct_plans = [f"{i+1}. {task.instruction}" for i, task in enumerate(plans)] + return instruct_plans + + +def get_tree_text(node: Node): + role_dict = {} + code_set = set() + + def load_role(node): + if node.id not in role_dict: + role_dict[node.id] = node.load_role() + return role_dict[node.id] + + def visualize_node(node: Node, previous_plans=None): + role = load_role(node) + node_id = node.id + plans = role.planner.plan.tasks + instruct_plans = [f"{i+1}. 
{task.instruction}" for i, task in enumerate(plans)] + if previous_plans is not None: + instruct_plans = [plan for plan, prev_plan in zip(instruct_plans, previous_plans) if plan != prev_plan] + instruct_plans_text = "\n".join(instruct_plans) + simulated = role.state_saved + score = f"avg score: {node.avg_value()}, simulated score: {node.raw_reward}" + num_visits = node.visited + return NODE_TEMPLATE.format( + id=node_id, plans=instruct_plans_text, simulated=simulated, score=score, num_visits=num_visits + ) + + def visualize_tree_text(node, depth=0, previous_plans=None): + text = "" + if node is not None: + text += visualize_node(node, previous_plans) + role = load_role(node) + code_set.update({task.instruction for task in role.planner.plan.tasks}) + previous_plans = get_role_plans(role) + for child in node.children: + text += textwrap.indent(visualize_tree_text(child, depth + 1, previous_plans), "\t") + return text + + num_simulations = node.visited + text = f"Number of simulations: {num_simulations}\n" + text += visualize_tree_text(node) + return text, len(code_set) + + +def get_node_color(node): + if node["visits"] == 0: + return "#D3D3D3" + else: + # The higher the avg_value, the more intense the color + # avg_value is between 0 and 1 + avg_value = node["avg_value"] + # Convert avg_value to a color ranging from red (low) to green (high) + red = int(255 * (1 - avg_value)) + green = int(255 * avg_value) + return f"#{red:02X}{green:02X}00" + + +def visualize_tree(graph, show_instructions=False, save_path=""): + # Use a hierarchical layout for tree-like visualization + pos = nx.spring_layout(graph, k=0.9, iterations=50) + + plt.figure(figsize=(30, 20)) # Further increase figure size for better visibility + + # Calculate node levels + root = "0" + levels = nx.single_source_shortest_path_length(graph, root) + max_level = max(levels.values()) + + # Adjust y-coordinates based on levels and x-coordinates to prevent overlap + nodes_by_level = {} + for node, level in levels.items(): + if level not in nodes_by_level: + nodes_by_level[level] = [] + nodes_by_level[level].append(node) + + for level, nodes in nodes_by_level.items(): + y = 1 - level / max_level + x_step = 1.0 / (len(nodes) + 1) + for i, node in enumerate(sorted(nodes)): + pos[node] = ((i + 1) * x_step, y) + + # Draw edges + nx.draw_networkx_edges(graph, pos, edge_color="gray", arrows=True, arrowsize=40, width=3) + + # Draw nodes + node_colors = [get_node_color(graph.nodes[node]) for node in graph.nodes] + nx.draw_networkx_nodes(graph, pos, node_size=NODE_SIZE, node_color=node_colors) + + # Add labels to nodes + labels = nx.get_node_attributes(graph, "label") + nx.draw_networkx_labels(graph, pos, labels, font_size=NODE_FONT_SIZE) + + if show_instructions: + # Add instructions to the right side of nodes + instructions = nx.get_node_attributes(graph, "instruction") + for node, (x, y) in pos.items(): + wrapped_text = textwrap.fill(instructions[node], width=30) # Adjust width as needed + plt.text(x + 0.05, y, wrapped_text, fontsize=15, ha="left", va="center") + + plt.title("MCTS Tree Visualization", fontsize=40) + plt.axis("off") # Turn off axis + plt.tight_layout() + if save_path: + plt.savefig(save_path) + plt.show() + + +def build_tree_recursive(graph, parent_id, node, node_order, start_task_id=2): + """ + Recursively builds the entire tree starting from the root node. + Adds nodes and edges to the NetworkX graph. 
+ """ + role = node.load_role() + depth = node.get_depth() + if depth == 0: + instruction = "\n\n".join([role.planner.plan.tasks[i].instruction for i in range(start_task_id)]) + else: + instruction = role.planner.plan.tasks[depth + start_task_id - 1].instruction + print(instruction) + # Add the current node with attributes to the graph + dev_score = node.raw_reward.get("dev_score", 0) * 100 + avg_score = node.avg_value() * 100 + order = node_order.index(node.id) if node.id in node_order else "" + graph.add_node( + parent_id, + label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}\nOrder: {order}", + avg_value=node.avg_value(), + dev_score=dev_score, + visits=node.visited, + instruction=instruction, + ) + # Stopping condition: if the node has no children, return + if not node.children: + return + + # Recursively create all child nodes + for i, child in enumerate(node.children): + child_id = f"{parent_id}-{i}" + graph.add_edge(parent_id, child_id) + build_tree_recursive(graph, child_id, child, node_order) diff --git a/metagpt/ext/sela/experimenter.py b/metagpt/ext/sela/experimenter.py new file mode 100644 index 0000000000000000000000000000000000000000..b05ea2fc36c866682948fa526ce058229f92f3cd --- /dev/null +++ b/metagpt/ext/sela/experimenter.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import asyncio +import json +import os + +from pydantic import model_validator + +from metagpt.actions.di.write_analysis_code import WriteAnalysisCode +from metagpt.const import SERDESER_PATH +from metagpt.ext.sela.utils import mcts_logger, save_notebook +from metagpt.roles.di.data_interpreter import DataInterpreter +from metagpt.schema import Message, Task, TaskResult +from metagpt.utils.common import CodeParser, write_json_file + +CODE_BLOCK_RESULT = """ +## Code: +{code} + +## Execution Result: +{result} +""" + +EXTRACT_SCORE_PROMPT = """ +# Code Blocks +{code_block} +# Instruction: +Based on the code and execution result, please extract the **final scores** and return it as a dictionary. +If you cannot find the scores, please still return a dictionary with the keys 'train_score', 'dev_score', and 'test_score', and set the values to -1. 
+
+# Format:
+```json
+{{
+    "train_score": x.x,
+    "dev_score": x.x,
+    "test_score": x.x
+}}
+```
+"""
+
+
+class TimeoutException(Exception):
+    pass
+
+
+def async_timeout():
+    def decorator(func):
+        async def wrapper(self, *args, **kwargs):
+            try:
+                result = await asyncio.wait_for(func(self, *args, **kwargs), timeout=self.role_timeout)
+            except asyncio.TimeoutError:
+                text = f"Function timed out after {self.role_timeout} seconds"
+                mcts_logger.error(text)
+                self.save_state()
+                raise TimeoutException(text)
+            return result
+
+        return wrapper
+
+    return decorator
+
+
+class Experimenter(DataInterpreter):
+    node_id: str = "0"
+    start_task_id: int = 1
+    state_saved: bool = False
+    role_dir: str = SERDESER_PATH.joinpath("team", "environment", "roles", "Experimenter")
+    role_timeout: int = 1000
+
+    def get_node_name(self):
+        return f"Node-{self.node_id}"
+
+    def get_next_instruction(self):
+        return self.planner.plan.tasks[self.start_task_id].instruction
+
+    def change_next_instruction(self, new_instruction):
+        if new_instruction is not None:
+            self.planner.plan.task_map[str(self.start_task_id)].instruction = new_instruction
+            self.remap_tasks()
+
+    def update_til_start_task(self, role: Experimenter, backward: bool = True):
+        if backward:
+            # make sure the previous task instructions are matched
+            assert (
+                self.start_task_id == role.start_task_id - 1
+            ), f"start_task_id: {self.start_task_id}, role.start_task_id: {role.start_task_id}"
+            for i in range(self.start_task_id):
+                if (
+                    self.planner.plan.task_map[str(self.start_task_id)].instruction
+                    != role.planner.plan.task_map[str(self.start_task_id)].instruction
+                ):
+                    mcts_logger.info("Previous task instructions not matched")
+                    self.remap_tasks()
+                    return
+            # copy new role's task (self.start_task_id) to current role
+            self.planner.plan.task_map[str(self.start_task_id)] = role.planner.plan.task_map[
+                str(self.start_task_id)
+            ].model_copy()
+            self.remap_tasks()
+
+        else:
+            assert (
+                self.start_task_id == role.start_task_id + 1
+            ), f"start_task_id: {self.start_task_id}, role.start_task_id: {role.start_task_id}"
+            if int(role.planner.plan.current_task_id) > self.start_task_id:
+                for i in range(role.start_task_id):
+                    self.planner.plan.task_map[str(i)] = role.planner.plan.task_map[str(i)].model_copy()
+                self.remap_tasks()
+
+    async def get_score(self):
+        score_dict = await self.llm_extract_score()
+        score_dict["score"] = score_dict["dev_score"]
+        return score_dict
+
+    async def llm_extract_score(self):
+        # result_text = self.planner.plan.task_map[str(len(self.planner.plan.task_map))].result
+        # code_text = self.planner.plan.task_map[str(len(self.planner.plan.task_map))].code
+        num_tasks = len(self.planner.plan.task_map)
+        task_map = self.planner.plan.task_map
+        code_block = "\n".join(
+            [
+                CODE_BLOCK_RESULT.format(code=task_map[str(i + 1)].code, result=task_map[str(i + 1)].result)
+                for i in range(num_tasks)
+            ]
+        )
+        rsp = await self.llm.aask(EXTRACT_SCORE_PROMPT.format(code_block=code_block))
+        json_block = CodeParser.parse_code(block=None, text=rsp)
+        score_dict = json.loads(json_block)
+        return score_dict
+
+    @model_validator(mode="after")
+    def set_plan_and_tool(self) -> "Experimenter":
+        if self.planner.plan.goal != "":
+            self.set_actions([WriteAnalysisCode])
+            self._set_state(0)
+            print("Plan already exists, skipping initialization.")
+            return self
+        print("Initializing plan and tool...")
+        return super().set_plan_and_tool()
+
+    async def _act_on_task(self, current_task: Task) -> TaskResult:
+        """Useful in 'plan_and_act' mode. 
Wrap the output in a TaskResult for review and confirmation.""" + mcts_logger.info(f"The current_task is: {current_task}") + code, result, is_success = await self._write_and_exec_code() + task_result = TaskResult(code=code, result=result, is_success=is_success) + if int(current_task.task_id) == self.start_task_id + 1: + # fe_id = current_task.dependent_task_ids + self.save_state() + save_notebook(role=self, save_dir=self.role_dir, name=self.get_node_name(), save_to_depth=True) + else: + save_notebook(role=self, save_dir=self.role_dir, name=self.get_node_name()) + return task_result + + def get_solution(self): + codes = [task.code for task in self.planner.plan.tasks] + results = [task.result for task in self.planner.plan.tasks] + return {"codes": codes, "results": results} + + def save_state(self, static_save=False): + """ + attribute: + state_saved - the state has been saved + input: + static_save - saving the state without changing the state_saved flag - used when a new role is created + """ + if self.state_saved and not static_save: + return + if not static_save: + self.state_saved = True + mcts_logger.log("MCTS", f"Saving state at task {self.start_task_id}") + else: + mcts_logger.log("MCTS", "Static Saving") + stg_path = self.role_dir + name = self.get_node_name() + role_path = os.path.join(stg_path, f"{name}.json") + # save state as json file + write_json_file(role_path, self.model_dump()) + + def remap_tasks(self): + self.planner.plan.tasks = [ + self.planner.plan.task_map[task_id] for task_id in sorted(self.planner.plan.task_map.keys()) + ] + + @async_timeout() + async def run(self, with_message=None) -> Message | None: + """Observe, and think and act based on the results of the observation""" + if with_message == "continue": + mcts_logger.info("Continue to run") + self.rc.working_memory.clear() + self.working_memory.clear() + rsp = await self.react() + self.set_todo(None) + self.publish_message(rsp) + return rsp + return await super().run(with_message) diff --git a/metagpt/ext/sela/insights/fixed_insights.json b/metagpt/ext/sela/insights/fixed_insights.json new file mode 100644 index 0000000000000000000000000000000000000000..4f42b9db164c27a8f0aa2d554a2d9d00f76f56c7 --- /dev/null +++ b/metagpt/ext/sela/insights/fixed_insights.json @@ -0,0 +1,22 @@ +[ +{ + "Analysis": "Use early stopping, hyperparameter tuning, and cross-validation to avoid overfitting and improve robustness of the model.", + "Category": "Model Training", + "task_id": 4 +}, +{ + "Analysis": "use k-fold bagging and early stopping", + "Category": "Model Training", + "task_id": 4 +}, +{ + "Analysis": "To avoid overfitting, train a weighted ensemble model such as StackingClassifier or StackingRegressor; You could do some quick model prototyping to see which models work best and then use them in the ensemble.", + "Category": "Model Training", + "task_id": 4 +}, +{ + "Analysis": "Please use autogluon for model training with presets='medium_quality', time_limit=None, give dev dataset to tuning_data, and use right eval_metric.", + "Category": "Model Training", + "task_id": 4 +} +] \ No newline at end of file diff --git a/metagpt/ext/sela/insights/instruction_generator.py b/metagpt/ext/sela/insights/instruction_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d24c74de72ad0a20de3ce03cdbd622398807ba --- /dev/null +++ b/metagpt/ext/sela/insights/instruction_generator.py @@ -0,0 +1,169 @@ +import json +import os +import random +from difflib import SequenceMatcher + +from 
metagpt.ext.sela.insights.solution_designer import SolutionDesigner +from metagpt.ext.sela.utils import clean_json_from_rsp, load_data_config, mcts_logger +from metagpt.llm import LLM +from metagpt.schema import Message + +REFLECTION_SYSTEM_MSG = "As a Kaggle Grandmaster competing in a challenge, your task is to suggest potential evolutionary improvements that could enhance the performance of the baseline code." + +CHANGE_INSTRUCTION = """ +# Original instruction +{instruction} + +# Insights +{insights} + +Rewrite the original instruction according to the insights +(If the original instruction involves splitting the data, ensure that your insights are integrated with the data split instructions, +rather than replacing them.) + +# Expected Output Hard Format +```json +{{ + "Original Instruction": "original instruction", + "New Instruction": "new instruction" +}} +``` +""" + +DATA_CONFIG = load_data_config() + + +class InstructionGenerator: + data_config = DATA_CONFIG + + def __init__(self, state, use_fixed_insights, from_scratch): + self.state = state + self.file_path = state["exp_pool_path"] + if state["custom_dataset_dir"]: + with open(f"{state['custom_dataset_dir']}/description.md", "r", encoding="utf-8") as file: + self.dataset_info = file.read() + else: + dataset_info_path = ( + f"{self.data_config['datasets_dir']}/{state['dataset_config']['dataset']}/dataset_info.json" + ) + with open(dataset_info_path, "r") as file: + self.dataset_info = json.load(file) + self.use_fixed_insights = use_fixed_insights + self.proposer = SolutionDesigner() + if self.file_path is None: + self.from_scratch = True + else: + self.from_scratch = from_scratch + + async def initialize(self): + if self.from_scratch: + self.insight_pool = await self.generate_solutions_from_scratch(self.dataset_info, self.state["task"]) + else: + self.insight_pool = self.load_insight_pool(self.file_path, self.use_fixed_insights) + + @staticmethod + def load_json_data(json_dir): + with open(json_dir, "r") as file: + json_data = json.load(file) + return json_data + + @staticmethod + def _random_sample(analysis, num_samples): + return random.sample(analysis, num_samples) + + @staticmethod + def sample_instruction_set(data): + data_dict = {} + for item in data: + task_id = item["task_id"] + if task_id not in data_dict: + data_dict[task_id] = [] + data_dict[task_id].append(item) + instruction_set = [] + for task_id in sorted(data_dict.keys()): + instruction_set.append(random.choice(data_dict[task_id])) + return instruction_set + + @staticmethod + def format_output(rsp): + rsp_list = [] + new_data = [] + rsp_list.append(rsp) + for item in rsp_list: + item_dict = json.loads(item) + data = { + "Insights": item_dict, + } + new_data.append(data) + return new_data + + @staticmethod + def load_insight_pool(file_path, use_fixed_insights, task_id=None): + data = InstructionGenerator.load_json_data(file_path) + if use_fixed_insights: + current_directory = os.path.dirname(__file__) + fixed_insights = InstructionGenerator.load_json_data(f"{current_directory}/fixed_insights.json") + data.extend(fixed_insights) + for item in data: + if "task_id" not in item: + raise ValueError("task_id is not found in the insight_pool") + + if task_id: + data = [item for item in data if int(item["task_id"]) == int(task_id)] + return data + + async def generate_new_instructions(self, task_id, original_instruction, max_num, ext_info=None): + data = self.insight_pool + new_instructions = [] + if len(data) == 0: + mcts_logger.log("MCTS", f"No insights available for task 
{task_id}") + # return [original_instruction] # Return the original instruction if no insights are available + for i in range(max_num): + if len(data) == 0: + insights = "No insights available" + else: + item = data[i] + insights = item["Analysis"] + new_instruction = await InstructionGenerator.generate_new_instruction( + original_instruction, insights, ext_info + ) + new_instructions.append(new_instruction) + return new_instructions + + async def propose_new_insights(self, solution, score): + new_insights = await self.proposer.propose_insights(solution, score) + added_insights = self.add_insight(new_insights) + return added_insights + + async def generate_solutions_from_scratch(self, dataset_info, dataset_name): + insight_pool = await self.proposer.generate_solutions(dataset_info, dataset_name, save_analysis_pool=False) + return insight_pool + + def add_insight(self, new_insights): + added_insights = [] + for new_insight in new_insights: + if not self.is_similar_to_existing(new_insight): + added_insights.append(new_insight) + self.insight_pool.append(new_insight) + return added_insights + + def is_similar_to_existing(self, new_insight, similarity_threshold=0.8): + for existing_insight in self.insight_pool: + similarity = self.calculate_similarity(new_insight["Analysis"], existing_insight["Analysis"]) + if similarity > similarity_threshold: + return True + return False + + @staticmethod + def calculate_similarity(text1, text2): + return SequenceMatcher(None, text1, text2).ratio() + + @staticmethod + async def generate_new_instruction(original_instruction, insights, ext_info): + prompt = CHANGE_INSTRUCTION.format(instruction=original_instruction, insights=insights) + llm = LLM() + context = llm.format_msg([Message(content=prompt, role="user")]) + llm_response = await llm.aask(context, system_msgs=[REFLECTION_SYSTEM_MSG]) + rsp = clean_json_from_rsp(llm_response) + new_instruction = json.loads(rsp)["New Instruction"] + return new_instruction diff --git a/metagpt/ext/sela/insights/solution_designer.py b/metagpt/ext/sela/insights/solution_designer.py new file mode 100644 index 0000000000000000000000000000000000000000..1b61c2141ae0403fda8f06727dfbcddf1a2b329f --- /dev/null +++ b/metagpt/ext/sela/insights/solution_designer.py @@ -0,0 +1,183 @@ +import json + +from metagpt.ext.sela.utils import clean_json_from_rsp, load_data_config +from metagpt.llm import LLM + +DATA_CONFIG = load_data_config() + + +DATASET_DESCRIPTION_SELA_PROMPT = """ +# Dataset Description +{dataset} + +# Dataset Metadata +{metadata} + +# Dataset Head +{head} +""" + +DATASET_DESCRIPTION_CUSTOM_PROMPT = """ +# Dataset Description +{dataset_description} +""" + +DATASET_INSIGHT_PROMPT = """ +{description} + +# Instruction +Propose insights to help improve the performance of the model on this dataset. +The insights should be proposed based on the dataset description with different task types. +Each task type should have at least 5 insights. +Make sure each method is diverse enough and can be implemented separately. +Be specific about models' choices, ensemble and tuning techniques, and preprocessing & feature engineering techniques. +Your model choices should be advanced enough to be helpful. + +# Format +```json +[ + {{ + "task_type": "EDA", + "insights": [ + "insight1", + "insight2", + "insight3", + ... + "insightN" + ] + }}, + {{ + "task_type": "Data Preprocessing", + "insights": [ + "insight1", + "insight2", + "insight3", + ... 
+ "insightN" + ] + }}, + {{ + "task_type": "Feature Engineering", + "insights": [ + "insight1", + "insight2", + "insight3", + ... + "insightN" + ] + }}, + {{ + "task_type": "Model Training", + "insights": [ + "insight1", + "insight2", + "insight3", + ... + "insightN" + ] + }} +] +``` +""" + + +INSIGHT_PROPOSAL_PROMPT = """ +You are an AI assistant tasked with analyzing a machine learning solution and proposing new insights to improve its performance. Given the current solution code and development score, suggest innovative approaches to enhance the model. + +Current Solution Code: +{solution_code} + +Development Score: {dev_score} + +Based on this information, propose 3-5 new insights across different aspects of the machine learning pipeline (Data Preprocessing, Feature Engineering, and Model Training). Your insights should be specific, actionable, and have the potential to improve the model's performance. + +Please format your response as a JSON array with the following structure: +[ + + {{ + "task_type": "Data Preprocessing", + "insights": [ + "insight1", + "insight2" + ] + }}, + {{ + "task_type": "Feature Engineering", + "insights": [ + "insight1", + "insight2" + ] + }}, + {{ + "task_type": "Model Training", + "insights": [ + "insight1", + "insight2" + ] + }} +] +""" + + +KEY_DATASET_FEATURES = [ + "NumberOfClasses", + "NumberOfFeatures", + "NumberOfInstances", + "NumberOfInstancesWithMissingValues", + "NumberOfMissingValues", + "NumberOfNumericFeatures", + "NumberOfSymbolicFeatures", +] + +TASK_TO_ID = {"EDA": 1, "Data Preprocessing": 2, "Feature Engineering": 3, "Model Training": 4, "Model Evaluation": 5} + + +class SolutionDesigner: + data_dir: str = DATA_CONFIG["datasets_dir"] + + async def generate_solutions(self, dataset_info, dataset_name, save_analysis_pool=True): + llm = LLM() + if type(dataset_info) == dict: + description_prompt = DATASET_DESCRIPTION_SELA_PROMPT.format( + dataset=dataset_info["description"], + metadata=self.metadata_builder(dataset_info["metadata"]), + head=dataset_info["df_head"], + ) + else: + description_prompt = DATASET_DESCRIPTION_CUSTOM_PROMPT.format(dataset_description=dataset_info) + context = DATASET_INSIGHT_PROMPT.format(description=description_prompt) + rsp = await llm.aask(context) + rsp = clean_json_from_rsp(rsp) + analysis_pool = self.process_analysis_pool(json.loads(rsp)) + if save_analysis_pool: + dataset_path = f"{self.data_dir}/{dataset_name}" + self.save_analysis_pool(dataset_path, analysis_pool) + return analysis_pool + + async def propose_new_insights(self, solution, score): + llm = LLM() + context = INSIGHT_PROPOSAL_PROMPT.format(solution_code=solution, dev_score=score) + rsp = await llm.aask(context) + rsp = clean_json_from_rsp(rsp) + new_insights = self.process_analysis_pool(json.loads(rsp)) + return new_insights + + def process_analysis_pool(self, insights_rsp): + analysis_pool = [] + for task_type_insights in insights_rsp: + task_type = task_type_insights["task_type"] + for insight in task_type_insights["insights"]: + analysis_pool.append({"Analysis": insight, "Category": task_type, "task_id": TASK_TO_ID[task_type]}) + return analysis_pool + + def metadata_builder(self, qualities): + metadata = {} + for key in KEY_DATASET_FEATURES: + metadata[key] = qualities.get(key, "N/A") + metadata_text = json.dumps(metadata, indent=4) + return metadata_text + + def save_analysis_pool(self, dataset_path, analysis_pool): + fpath = f"{dataset_path}/ds_analysis_pool.json" + with open(fpath, "w") as file: + json.dump(analysis_pool, file, indent=4) 
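For reference, `process_analysis_pool` above flattens the per-task-type JSON returned by the LLM into the flat list of insight records that the rest of SELA consumes (the same shape as the entries in `fixed_insights.json`). A minimal, self-contained sketch of that transformation; the input mimics the structure requested by `DATASET_INSIGHT_PROMPT`, and the two insight strings are invented for illustration:

```python
# Sketch of SolutionDesigner.process_analysis_pool's flattening step.
TASK_TO_ID = {"EDA": 1, "Data Preprocessing": 2, "Feature Engineering": 3, "Model Training": 4, "Model Evaluation": 5}

insights_rsp = [
    {"task_type": "Data Preprocessing", "insights": ["Impute missing numeric values with the median."]},
    {"task_type": "Model Training", "insights": ["Train gradient-boosted trees with early stopping."]},
]

analysis_pool = [
    {"Analysis": insight, "Category": block["task_type"], "task_id": TASK_TO_ID[block["task_type"]]}
    for block in insights_rsp
    for insight in block["insights"]
]
print(analysis_pool)  # each record: {"Analysis": ..., "Category": ..., "task_id": ...}
```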
diff --git a/metagpt/ext/sela/requirements.txt b/metagpt/ext/sela/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e85818bbea10a40f27d86c80fe5ff1efa212b4c9 --- /dev/null +++ b/metagpt/ext/sela/requirements.txt @@ -0,0 +1,6 @@ +# expo +openml==0.14.2 +# ml module to run in DI +xgboost +catboost +lightgbm diff --git a/metagpt/ext/sela/run_experiment.py b/metagpt/ext/sela/run_experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..32130a6fb4d0f84205d5676b52b4260bd6a5b5c1 --- /dev/null +++ b/metagpt/ext/sela/run_experiment.py @@ -0,0 +1,99 @@ +import argparse +import asyncio + +from metagpt.ext.sela.data.custom_task import get_mle_is_lower_better, get_mle_task_id +from metagpt.ext.sela.runner.autogluon import GluonRunner +from metagpt.ext.sela.runner.autosklearn import AutoSklearnRunner +from metagpt.ext.sela.runner.custom import CustomRunner +from metagpt.ext.sela.runner.mcts import MCTSRunner +from metagpt.ext.sela.runner.random_search import RandomSearchRunner +from metagpt.ext.sela.runner.runner import Runner + + +def get_args(cmd=True): + parser = argparse.ArgumentParser() + parser.add_argument("--name", type=str, default="") + parser.add_argument( + "--exp_mode", + type=str, + default="mcts", + choices=["mcts", "rs", "base", "custom", "greedy", "autogluon", "random", "autosklearn"], + ) + parser.add_argument("--role_timeout", type=int, default=1000) + get_di_args(parser) + get_mcts_args(parser) + get_rs_exp_args(parser) + if cmd: + args = parser.parse_args() + else: + args = parser.parse_args("") + + if args.custom_dataset_dir: + args.external_eval = False + args.eval_func = "mlebench" + args.from_scratch = True + args.task = get_mle_task_id(args.custom_dataset_dir) + args.low_is_better = get_mle_is_lower_better(args.task) + return args + + +def get_mcts_args(parser): + parser.add_argument("--load_tree", dest="load_tree", action="store_true") + parser.add_argument("--no_load_tree", dest="load_tree", action="store_false") + parser.set_defaults(load_tree=False) + parser.add_argument("--rollouts", type=int, default=5) + parser.add_argument("--use_fixed_insights", dest="use_fixed_insights", action="store_true") + parser.set_defaults(use_fixed_insights=False) + parser.add_argument("--start_task_id", type=int, default=2) + parser.add_argument( + "--from_scratch", dest="from_scratch", action="store_true", help="Generate solutions from scratch" + ) + parser.set_defaults(from_scratch=False) + parser.add_argument("--no_external_eval", dest="external_eval", action="store_false") + parser.set_defaults(external_eval=True) + parser.add_argument("--eval_func", type=str, default="sela", choices=["sela", "mlebench"]) + parser.add_argument("--custom_dataset_dir", type=str, default=None) + parser.add_argument("--max_depth", type=int, default=4) + + +def get_rs_exp_args(parser): + parser.add_argument("--rs_mode", type=str, default="single", choices=["single", "set"]) + parser.add_argument("--is_multimodal", action="store_true", help="Specify if the model is multi-modal") + + +def get_di_args(parser): + parser.add_argument("--task", type=str, default="titanic") + parser.add_argument("--low_is_better", dest="low_is_better", action="store_true") + parser.set_defaults(low_is_better=False) + parser.add_argument("--reflection", dest="reflection", action="store_true") + parser.add_argument("--no_reflection", dest="reflection", action="store_false") + parser.add_argument("--num_experiments", type=int, default=1) + 
parser.add_argument("--special_instruction", type=str, default=None, choices=["ag", "stacking", "text", "image"]) + parser.set_defaults(reflection=True) + + +async def main(args): + if args.exp_mode == "mcts": + runner = MCTSRunner(args) + elif args.exp_mode == "greedy": + runner = MCTSRunner(args, tree_mode="greedy") + elif args.exp_mode == "random": + runner = MCTSRunner(args, tree_mode="random") + elif args.exp_mode == "rs": + runner = RandomSearchRunner(args) + elif args.exp_mode == "base": + runner = Runner(args) + elif args.exp_mode == "autogluon": + runner = GluonRunner(args) + elif args.exp_mode == "custom": + runner = CustomRunner(args) + elif args.exp_mode == "autosklearn": + runner = AutoSklearnRunner(args) + else: + raise ValueError(f"Invalid exp_mode: {args.exp_mode}") + await runner.run_experiment() + + +if __name__ == "__main__": + args = get_args() + asyncio.run(main(args)) diff --git a/metagpt/ext/sela/runner/README.md b/metagpt/ext/sela/runner/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4867aa4f09ea88822a38e8a3f080696a70852114 --- /dev/null +++ b/metagpt/ext/sela/runner/README.md @@ -0,0 +1,168 @@ +# SELA: Tree-Search Enhanced LLM Agents for Automated Machine Learning + +This document provides instructions for running baseline models. To start with, ensure that you prepare the datasets as instructed in `sela/README.md`. + +## Baselines + +### 1. AIDE + +#### Setup + +We use the AIDE version from September 30, 2024. Clone the repository and check out the specified commit: + +```bash +git clone https://github.com/WecoAI/aideml.git +git checkout 77953247ea0a5dc1bd502dd10939dd6d7fdcc5cc +``` + + +Modify `aideml/aide/utils/config.yaml` to set the following parameters: + +```yaml +# agent hyperparams +agent: + steps: 10 # Number of improvement iterations + k_fold_validation: 1 # Set to 1 to disable cross-validation + code: + model: deepseek-coder + temp: 0.5 + feedback: + model: deepseek-coder + temp: 0.5 + search: + max_debug_depth: 3 + debug_prob: 0.5 + num_drafts: 5 +``` + +Update your OpenAI API credentials in the environment: + +```bash +export OPENAI_API_KEY="your api key" +export OPENAI_BASE_URL="your own url" +``` + +Modify `aideml/aide/backend/__init__.py` (line 30 and below): + +```python +model_kwargs = model_kwargs | { + "model": model, + "temperature": temperature, + "max_tokens": max_tokens, + } +if "claude-" in model: + query_func = backend_anthropic.query +else: + query_func = backend_openai.query +``` + +Since Deepseek V2.5 no longer supports system messages using function calls, modify `aideml/aide/agent.py` (line 312): + +```python +response = cast( + dict, + query( + system_message=None, + user_message=prompt, + func_spec=review_func_spec, + model=self.acfg.feedback.model, + temperature=self.acfg.feedback.temp, + ), +) +``` + +Finally, install AIDE: + +```bash +cd aideml +pip install -e . +``` + +#### Run + +Execute the following script to generate results. A `log` folder (containing experimental configurations) and a `workspace` folder (storing final results) will be created: + +```bash +python runner/aide.py +``` + +--- + +### 2. 
Autogluon + +#### Setup + +Install Autogluon: + +```bash +pip install -U pip +pip install -U setuptools wheel +pip install autogluon==1.1.1 +``` + +#### Run + +For Tabular data: + +```bash +python run_experiment.py --exp_mode autogluon --task {task_name} +``` + +For Multimodal data: + +```bash +python run_experiment.py --exp_mode autogluon --task {task_name} --is_multimodal +``` + +Replace `{task_name}` with the specific task you want to run. + +--- + +### 3. AutoSklearn + +**Note:** +AutoSklearn requires: +- Linux operating system (e.g., Ubuntu) +- Python (>=3.7) +- C++ compiler (with C++11 support) + +If installing on a system without wheel files for the `pyrfr` package, you also need: + +- [SWIG](https://www.swig.org/survey.html) + +Refer to the [Windows/macOS compatibility](https://automl.github.io/auto-sklearn/master/installation.html#windows-macos-compatibility) section for further details. + +#### Setup + +Install AutoSklearn: + +```bash +pip install auto-sklearn==0.15.0 +``` + +#### Run + +Execute the following command for the Titanic task: + +```bash +python run_experiment.py --exp_mode autosklearn --task titanic +``` + +--- + +### 4. Base Data Interpreter + +Run the following command for the Titanic task: + +```bash +python run_experiment.py --exp_mode base --task titanic --num_experiments 10 +``` + +--- + +### 5. Custom Baselines + +To run additional baselines: + +- Each baseline must produce `dev_predictions.csv` and `test_predictions.csv` with a `target` column. +- Use the `evaluate_score` function for evaluation. \ No newline at end of file diff --git a/metagpt/ext/sela/runner/__init__.py b/metagpt/ext/sela/runner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/sela/runner/aide.py b/metagpt/ext/sela/runner/aide.py new file mode 100644 index 0000000000000000000000000000000000000000..50fae94c140563c1e0cb9ba9fec259ec7b6bf8eb --- /dev/null +++ b/metagpt/ext/sela/runner/aide.py @@ -0,0 +1,35 @@ +import os +import time + +import aide + +os.environ["OPENAI_API_KEY"] = "sk-xxx" +os.environ["OPENAI_BASE_URL"] = "your url" + +start_time = time.time() + +data_dir = "xxx/data/titanic" + +goal = f""" +# User requirement +({data_dir}, 'This is a 04_titanic dataset. Your goal is to predict the target column `Survived`.\nPerform data analysis, data preprocessing, feature engineering, and modeling to predict the target. \nReport f1 on the eval data. 
Do not plot or make any visualizations.\n') + +# Data dir +training (with labels): train.csv +testing (without labels): test.csv +dataset description: dataset_info.json (You can use this file to get additional information about the dataset)""" + +exp = aide.Experiment( + data_dir=data_dir, # replace this with your own directory + goal=goal, + eval="f1", # replace with your own evaluation metric +) + +best_solution = exp.run(steps=10) + +print(f"Best solution has validation metric: {best_solution.valid_metric}") +print(f"Best solution code: {best_solution.code}") +end_time = time.time() +execution_time = end_time - start_time + +print(f"run time : {execution_time} seconds") diff --git a/metagpt/ext/sela/runner/autogluon.py b/metagpt/ext/sela/runner/autogluon.py new file mode 100644 index 0000000000000000000000000000000000000000..48737da045c796501f6bade010fa45ed07712bba --- /dev/null +++ b/metagpt/ext/sela/runner/autogluon.py @@ -0,0 +1,128 @@ +import os +from datetime import datetime + +import pandas as pd + +from metagpt.ext.sela.runner.custom import CustomRunner + + +class AGRunner: + def __init__(self, state=None): + self.state = state + self.datasets = self.state["datasets_dir"] + + def run(self): + from autogluon.tabular import TabularDataset, TabularPredictor + + train_path = self.datasets["train"] + dev_path = self.datasets["dev"] + dev_wo_target_path = self.datasets["dev_wo_target"] + test_wo_target_path = self.datasets["test_wo_target"] + target_col = self.state["dataset_config"]["target_col"] + train_data = TabularDataset(train_path) + dev_data = TabularDataset(dev_path) + dev_wo_target_data = TabularDataset(dev_wo_target_path) + test_data = TabularDataset(test_wo_target_path) + eval_metric = self.state["dataset_config"]["metric"].replace(" ", "_") + predictor = TabularPredictor( + label=target_col, + eval_metric=eval_metric, + path="AutogluonModels/ag-{}-{}".format(self.state["task"], datetime.now().strftime("%y%m%d_%H%M")), + ).fit(train_data=train_data, tuning_data=dev_data, num_gpus=1) + dev_preds = predictor.predict(dev_wo_target_data) + test_preds = predictor.predict(test_data) + return {"test_preds": test_preds, "dev_preds": dev_preds} + + def run_multimodal(self): + from autogluon.multimodal import MultiModalPredictor + + target_col = self.state["dataset_config"]["target_col"] + train_path = self.datasets["train"] + dev_path = self.datasets["dev"] + dev_wo_target_path = self.datasets["dev_wo_target"] # Updated variable name + test_wo_target_path = self.datasets["test_wo_target"] + eval_metric = self.state["dataset_config"]["metric"].replace(" ", "_") + + # Load the datasets + train_data, dev_data, dev_wo_target_data, test_data = self.load_split_dataset( + train_path, dev_path, dev_wo_target_path, test_wo_target_path + ) + + # Create and fit the predictor + predictor = MultiModalPredictor( + label=target_col, + eval_metric=eval_metric, + path="AutogluonModels/ag-{}-{}".format(self.state["task"], datetime.now().strftime("%y%m%d_%H%M")), + ).fit(train_data=train_data, tuning_data=dev_data) + + # Make predictions on dev and test datasets + dev_preds = predictor.predict(dev_wo_target_data) + test_preds = predictor.predict(test_data) + + # Return predictions for dev and test datasets + return {"dev_preds": dev_preds, "test_preds": test_preds} + + def load_split_dataset(self, train_path, dev_path, dev_wo_target_path, test_wo_target_path): + """ + Loads training, dev, and test datasets from given file paths + + Args: + train_path (str): Path to the training dataset. 
+ dev_path (str): Path to the dev dataset with target labels. + dev_wo_target_path (str): Path to the dev dataset without target labels. + test_wo_target_path (str): Path to the test dataset without target labels. + + Returns: + train_data (pd.DataFrame): Loaded training dataset with updated image paths. + dev_data (pd.DataFrame): Loaded dev dataset with updated image paths. + dev_wo_target_data (pd.DataFrame): Loaded dev dataset without target labels and updated image paths. + test_data (pd.DataFrame): Loaded test dataset with updated image paths. + """ + + # Define the root path to append + root_folder = os.path.join("F:/Download/Dataset/", self.state["task"]) + + # Load the datasets + train_data = pd.read_csv(train_path) + dev_data = pd.read_csv(dev_path) # Load dev dataset with target labels + dev_wo_target_data = pd.read_csv(dev_wo_target_path) # Load dev dataset without target labels + test_data = pd.read_csv(test_wo_target_path) + + # Get the name of the first column (assuming it's the image path column) + image_column = train_data.columns[0] + + # Append root folder path to the image column in each dataset + train_data[image_column] = train_data[image_column].apply(lambda x: os.path.join(root_folder, x)) + dev_data[image_column] = dev_data[image_column].apply(lambda x: os.path.join(root_folder, x)) + dev_wo_target_data[image_column] = dev_wo_target_data[image_column].apply( + lambda x: os.path.join(root_folder, x) + ) + test_data[image_column] = test_data[image_column].apply(lambda x: os.path.join(root_folder, x)) + + return train_data, dev_data, dev_wo_target_data, test_data + + +class GluonRunner(CustomRunner): + result_path: str = "results/autogluon" + + def __init__(self, args, **kwargs): + super().__init__(args, **kwargs) + self.framework = AGRunner(self.state) + self.is_multimodal = args.is_multimodal if hasattr(args, "is_multimodal") else False + + async def run_experiment(self): + if not self.is_multimodal: + result = self.framework.run() + else: + result = self.framework.run_multimodal() + + assert result is not None + user_requirement = self.state["requirement"] + dev_preds = result["dev_preds"] + test_preds = result["test_preds"] + score_dict = { + "dev_score": self.evaluate_predictions(dev_preds, "dev"), + "test_score": self.evaluate_predictions(test_preds, "test"), + } + results = [0, {"score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)}] + self.save_result(results) diff --git a/metagpt/ext/sela/runner/autosklearn.py b/metagpt/ext/sela/runner/autosklearn.py new file mode 100644 index 0000000000000000000000000000000000000000..7d0eb364e55a2e17710d6a84e07d710eaeaa3c2c --- /dev/null +++ b/metagpt/ext/sela/runner/autosklearn.py @@ -0,0 +1,96 @@ +from datetime import datetime +from functools import partial + +import pandas as pd + +from metagpt.ext.sela.evaluation.evaluation import evaluate_score +from metagpt.ext.sela.runner.custom import CustomRunner + + +def custom_scorer(y_true, y_pred, metric_name): + return evaluate_score(y_pred, y_true, metric_name) + + +class ASRunner: + time_limit = 600 + + def __init__(self, state=None): + self.state = state + self.datasets = self.state["datasets_dir"] + + def create_autosklearn_scorer(self, metric_name): + from autosklearn.metrics import make_scorer + + return make_scorer(name=metric_name, score_func=partial(custom_scorer, metric_name=metric_name)) + + def run(self): + import autosklearn.classification + import autosklearn.regression + + train_path = self.datasets["train"] + dev_wo_target_path = 
self.datasets["dev_wo_target"] + test_wo_target_path = self.datasets["test_wo_target"] + target_col = self.state["dataset_config"]["target_col"] + + train_data = pd.read_csv(train_path) + dev_data = pd.read_csv(dev_wo_target_path) + test_data = pd.read_csv(test_wo_target_path) + eval_metric = self.state["dataset_config"]["metric"] + X_train = train_data.drop(columns=[target_col]) + y_train = train_data[target_col] + + if eval_metric == "rmse": + automl = autosklearn.regression.AutoSklearnRegressor( + time_left_for_this_task=self.time_limit, + metric=self.create_autosklearn_scorer(eval_metric), + memory_limit=8192, + tmp_folder="AutosklearnModels/as-{}-{}".format( + self.state["task"], datetime.now().strftime("%y%m%d_%H%M") + ), + n_jobs=-1, + ) + elif eval_metric in ["f1", "f1 weighted"]: + automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=self.time_limit, + metric=self.create_autosklearn_scorer(eval_metric), + memory_limit=8192, + tmp_folder="AutosklearnModels/as-{}-{}".format( + self.state["task"], datetime.now().strftime("%y%m%d_%H%M") + ), + n_jobs=-1, + ) + else: + raise ValueError(f"Unsupported metric: {eval_metric}") + automl.fit(X_train, y_train) + + dev_preds = automl.predict(dev_data) + test_preds = automl.predict(test_data) + + return {"test_preds": test_preds, "dev_preds": dev_preds} + + +class AutoSklearnRunner(CustomRunner): + result_path: str = "results/autosklearn" + + def __init__(self, args, **kwargs): + super().__init__(args, **kwargs) + self.framework = ASRunner(self.state) + + async def run_experiment(self): + result = self.framework.run() + user_requirement = self.state["requirement"] + dev_preds = result["dev_preds"] + test_preds = result["test_preds"] + score_dict = { + "dev_score": self.evaluate_predictions(dev_preds, "dev"), + "test_score": self.evaluate_predictions(test_preds, "test"), + } + results = [ + 0, + { + "score_dict": score_dict, + "user_requirement": user_requirement, + "args": vars(self.args), + }, + ] + self.save_result(results) diff --git a/metagpt/ext/sela/runner/custom.py b/metagpt/ext/sela/runner/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a8ee276f72524f248c66a84bca87c4e09ed690 --- /dev/null +++ b/metagpt/ext/sela/runner/custom.py @@ -0,0 +1,62 @@ +import os + +import pandas as pd + +from metagpt.ext.sela.evaluation.evaluation import evaluate_score +from metagpt.ext.sela.runner.runner import Runner +from metagpt.ext.sela.search.tree_search import create_initial_state + + +class CustomRunner(Runner): + result_path: str = "results/custom" + + def __init__(self, args, **kwargs): + super().__init__(args, **kwargs) + self.framework = kwargs.get("framework", None) # todo + self.task = kwargs.get("task", self.args.task) + self.low_is_better = kwargs.get("low_is_better", self.args.low_is_better) + self.name = kwargs.get("name", "") + self.result_path = f"results/custom_{self.name}" + self.state = create_initial_state( + self.task, + start_task_id=1, + data_config=self.data_config, + args=self.args, + ) + + def run_experiment(self): + user_requirement = self.state["requirement"] + preds = self.framework.run(user_requirement) + test_preds = preds["test_preds"] + dev_preds = preds["dev_preds"] + score_dict = { + "dev_score": self.evaluate_predictions(dev_preds, "dev"), + "test_score": self.evaluate_predictions(test_preds, "test"), + } + results = {"score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)} + self.save_result(results) + + def 
evaluate_pred_files(self, dev_pred_path, test_pred_path):
+        dev_preds = pd.read_csv(dev_pred_path)["target"]
+        test_preds = pd.read_csv(test_pred_path)["target"]
+        score_dict = {
+            "dev_score": self.evaluate_predictions(dev_preds, "dev"),
+            "test_score": self.evaluate_predictions(test_preds, "test"),
+        }
+        return score_dict
+
+    def evaluate_predictions(self, preds, split):
+        metric = self.state["dataset_config"]["metric"]
+        gt_path = os.path.join(self.state["datasets_dir"][f"{split}_target"])
+        gt = pd.read_csv(gt_path)["target"]
+        score = evaluate_score(preds, gt, metric)
+        return score
+
+    def load_datasets(self):
+        train_path = self.state["datasets_dir"]["train"]
+        dev_path = self.state["datasets_dir"]["dev"]
+        test_path = self.state["datasets_dir"]["test"]
+        train = pd.read_csv(train_path)
+        dev = pd.read_csv(dev_path)
+        test = pd.read_csv(test_path)
+        return train, dev, test diff --git a/metagpt/ext/sela/runner/mcts.py b/metagpt/ext/sela/runner/mcts.py new file mode 100644 index 0000000000000000000000000000000000000000..8b6c1410025956d10fd3d5e5c589a3f5e57fb1db --- /dev/null +++ b/metagpt/ext/sela/runner/mcts.py @@ -0,0 +1,80 @@
+import shutil
+
+from metagpt.ext.sela.evaluation.evaluation import (
+    node_evaluate_score_mlebench,
+    node_evaluate_score_sela,
+)
+from metagpt.ext.sela.evaluation.visualize_mcts import get_tree_text
+from metagpt.ext.sela.runner.runner import Runner
+from metagpt.ext.sela.search.search_algorithm import MCTS, Greedy, Random
+
+
+class MCTSRunner(Runner):
+    result_path: str = "results/mcts"
+
+    def __init__(self, args, tree_mode=None, **kwargs):
+        if args.special_instruction == "image":
+            self.start_task_id = 1  # start from data preprocessing if it is an image task
+        else:
+            self.start_task_id = args.start_task_id
+
+        if args.eval_func == "sela":
+            self.eval_func = node_evaluate_score_sela
+        elif args.eval_func == "mlebench":
+            self.eval_func = node_evaluate_score_mlebench
+
+        super().__init__(args, **kwargs)
+        self.tree_mode = tree_mode
+
+    async def run_experiment(self):
+        use_fixed_insights = self.args.use_fixed_insights
+        depth = self.args.max_depth
+        if self.tree_mode == "greedy":
+            mcts = Greedy(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights)
+        elif self.tree_mode == "random":
+            mcts = Random(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights)
+        else:
+            mcts = MCTS(root_node=None, max_depth=depth, use_fixed_insights=use_fixed_insights)
+        best_nodes = await mcts.search(state=self.state, args=self.args)
+        best_node = best_nodes["global_best"]
+        dev_best_node = best_nodes["dev_best"]
+        score_dict = best_nodes["scores"]
+        additional_scores = {"grader": self.eval_func(dev_best_node)}
+
+        text, num_generated_codes = get_tree_text(mcts.root_node)
+        text += f"Generated {num_generated_codes} unique codes.\n"
+        text += f"Best node: {best_node.id}, score: {best_node.raw_reward}\n"
+        text += f"Dev best node: {dev_best_node.id}, score: {dev_best_node.raw_reward}\n"
+        text += f"Grader score: {additional_scores['grader']}\n"
+        print(text)
+        results = [
+            {
+                "best_node": best_node.id,
+                "best_node_score": best_node.raw_reward,
+                "dev_best_node": dev_best_node.id,
+                "dev_best_node_score": dev_best_node.raw_reward,
+                "num_generated_codes": num_generated_codes,
+                "user_requirement": best_node.state["requirement"],
+                "tree_text": text,
+                "args": vars(self.args),
+                "scores": score_dict,
+                "additional_scores": additional_scores,
+            }
+        ]
+        self.save_result(results)
+        self.copy_notebook(best_node, "best")
+        self.copy_notebook(dev_best_node, 
"dev_best") + self.save_tree(text) + + def copy_notebook(self, node, name): + node_dir = node.get_node_dir() + node_nb_dir = f"{node_dir}/Node-{node.id}.ipynb" + save_name = self.get_save_name() + copy_nb_dir = f"{self.result_path}/{save_name}_{name}.ipynb" + shutil.copy(node_nb_dir, copy_nb_dir) + + def save_tree(self, tree_text): + save_name = self.get_save_name() + fpath = f"{self.result_path}/{save_name}_tree.txt" + with open(fpath, "w") as f: + f.write(tree_text) diff --git a/metagpt/ext/sela/runner/mle_bench/instructions.py b/metagpt/ext/sela/runner/mle_bench/instructions.py new file mode 100644 index 0000000000000000000000000000000000000000..136726bcfe688556ef47ec28284aac302d7f3f7e --- /dev/null +++ b/metagpt/ext/sela/runner/mle_bench/instructions.py @@ -0,0 +1,48 @@ +SUBMISSION_INSTURCTION = """ +- **Submission checker**: We provide a tool for checking the validity of your submission file, in the form of a server endpoint at `http://localhost:5000/validate`. If you post a submission file to the server (e.g. `curl -X POST -F "file=@${SUBMISSION_FILE}" ${SERVER_URL}`), you will receive a response message telling you if your submission is valid or not (note that this does NOT give you a score). For your convenience, we provide a script `/home/validate_submission.sh` which can help you make this submission to the validation server. +""" +INSTRUCTIONS = """ +You are participating in *MLE-bench*, an offline version of Kaggle competitions adapted to test the coding ability of agents like yourself. + +You will be given a machine learning task. You must solve the task by training a model and running the model on the test set to produce a submission file. You may not hand-label the data to produce a submission file. + +BENCHMARK INSTRUCTIONS +------ +First, please read the instructions below which are specific to our benchmark setup. Note, these instructions can be found again in `runner/mle_bench/instructions.py`. + +- **Competition-specific instructions**: The instructions for the competition are available in `{dataset_dir}/description.md`. Read this file to understand the problem setup for the competition you are currently solving. +- **Dataset**: The dataset files for the competition are available in `{dataset_dir}`. +- **Submission**: We have simplified all competition types to use a simple CSV submission format. Regardless of what the competition instructions say, you MUST produce a submission file at `{output_dir}/submission.csv`. We will only evaluate submissions found at this path; failure to submit a `submission.csv` at this path will result in a failed attempt. +- **Test set**: In general, since we do not have access to the original test labels, we have constructed our own train/test split for each competition. As such, the dataset you see in `{dataset_dir}` may differ slightly from what is described in the competition instructions. Nevertheless the files available to you should otherwise be similar to the original and you should proceed with the competition task on the train/test splits available to you. +- **Cheating**: You are not allowed to view others' solutions to this competition. If we detect that your submission is plagiarized from others, the submission will be disqualified. + +Your submission will be graded based on the metric in the competition-specific instructions. Your goal is to achieve the best score possible on the competition task. 
+
+Wherever there is a conflict between the points above and the competition-specific instructions in `{dataset_dir}/description.md`, the above points should always take priority since these are always true for our offline benchmark setup.
+"""
+
+INSTRUCTIONS_OBFUSCATED = """
+You are participating in *MLE-bench*, a benchmark designed to test the coding ability of agents like yourself.
+
+You will be given a machine learning task. You must solve the task by training a model and running the model on the test set to produce a submission file. You may not hand-label the data to produce a submission file.
+
+BENCHMARK INSTRUCTIONS
+------
+First, please read the instructions below which are specific to our benchmark setup. Note, these instructions can be found again in `runner/mle_bench/instructions.py`.
+
+- **Task-specific instructions**: The instructions for the task are available in `{dataset_dir}/description.md`. Read this file to understand the problem setup for the task you are currently solving.
+- **Dataset**: The dataset files for the task are available in `{dataset_dir}/`.
+- **Submission**: Regardless of what the task instructions say, you MUST produce a submission file at `{output_dir}/submission.csv`. We will only evaluate submissions found at this path; failure to submit a `submission.csv` at this path will result in a failed attempt.
+- **Test set**: In general, since we do not have access to the original test labels, we have constructed our own train/test split for each task. As such, the dataset you see in `{dataset_dir}` may differ slightly from what is described in the task instructions. Nevertheless, the files available to you should otherwise be similar to the original and you should proceed with the task on the train/test splits available to you.
+- **Cheating**: You are not allowed to view others' solutions to this task. If we detect that your submission is plagiarized from others, the submission will be disqualified.
+
+Your submission will be graded based on the metric in the task-specific instructions. Your goal is to achieve the best score possible on the task.
+
+Wherever there is a conflict between the points above and the task-specific instructions in `{dataset_dir}/description.md`, the above points should always take priority since these are always true for our offline benchmark setup.
+"""
+
+ADDITIONAL_NOTES = """
+ADDITIONAL NOTES
+------
+- **Compute**: You have access to a GPU with the appropriate drivers installed. 
+""" diff --git a/metagpt/ext/sela/runner/random_search.py b/metagpt/ext/sela/runner/random_search.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f43ac0c9016b8726c5b4efceae9becf3313e55 --- /dev/null +++ b/metagpt/ext/sela/runner/random_search.py @@ -0,0 +1,53 @@ +from metagpt.ext.sela.experimenter import Experimenter +from metagpt.ext.sela.insights.instruction_generator import InstructionGenerator +from metagpt.ext.sela.runner.runner import Runner +from metagpt.ext.sela.utils import get_exp_pool_path + +EXPS_PROMPT = """ +When doing the tasks, you can refer to the insights below: +{experience} + +""" + + +class RandomSearchRunner(Runner): + result_path: str = "results/random_search" + + async def run_experiment(self): + # state = create_initial_state(self.args.task, start_task_id=1, data_config=self.data_config, low_is_better=self.args.low_is_better, name="") + user_requirement = self.state["requirement"] + exp_pool_path = get_exp_pool_path(self.args.task, self.data_config, pool_name="ds_analysis_pool") + exp_pool = InstructionGenerator.load_insight_pool( + exp_pool_path, use_fixed_insights=self.args.use_fixed_insights + ) + if self.args.rs_mode == "single": + exps = InstructionGenerator._random_sample(exp_pool, self.args.num_experiments) + exps = [exp["Analysis"] for exp in exps] + elif self.args.rs_mode == "set": + exps = [] + for i in range(self.args.num_experiments): + exp_set = InstructionGenerator.sample_instruction_set(exp_pool) + exp_set_text = "\n".join([f"{exp['task_id']}: {exp['Analysis']}" for exp in exp_set]) + exps.append(exp_set_text) + else: + raise ValueError(f"Invalid mode: {self.args.rs_mode}") + + results = [] + for i in range(self.args.num_experiments): + di = Experimenter(node_id=str(i), use_reflection=self.args.reflection, role_timeout=self.args.role_timeout) + di.role_dir = f"{di.role_dir}_{self.args.task}" + requirement = user_requirement + EXPS_PROMPT.format(experience=exps[i]) + print(requirement) + score_dict = await self.run_di(di, requirement, run_idx=i) + results.append( + { + "idx": i, + "score_dict": score_dict, + "rs_mode": self.args.rs_mode, + "insights": exps[i], + "user_requirement": requirement, + "args": vars(self.args), + } + ) + results = self.summarize_results(results) + self.save_result(results) diff --git a/metagpt/ext/sela/runner/runner.py b/metagpt/ext/sela/runner/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5504e096a194869235b2cb97bd33986edd5e64 --- /dev/null +++ b/metagpt/ext/sela/runner/runner.py @@ -0,0 +1,133 @@ +import datetime +import json +import os + +import numpy as np +import pandas as pd + +from metagpt.ext.sela.evaluation.evaluation import evaluate_score +from metagpt.ext.sela.experimenter import Experimenter +from metagpt.ext.sela.search.tree_search import create_initial_state +from metagpt.ext.sela.utils import DATA_CONFIG, save_notebook + + +class Runner: + result_path: str = "results/base" + data_config = DATA_CONFIG + start_task_id = 1 + + def __init__(self, args, **kwargs): + self.args = args + self.start_time_raw = datetime.datetime.now() + self.start_time = self.start_time_raw.strftime("%Y%m%d%H%M") + self.state = create_initial_state( + self.args.task, + start_task_id=self.start_task_id, + data_config=self.data_config, + args=self.args, + ) + + async def run_di(self, di, user_requirement, run_idx): + max_retries = 3 + num_runs = 1 + run_finished = False + while num_runs <= max_retries and not run_finished: + try: + await di.run(user_requirement) + score_dict = 
await di.get_score() + score_dict = self.evaluate(score_dict, self.state) + run_finished = True + except Exception as e: + print(f"Error: {e}") + num_runs += 1 + # save_notebook(role=di, save_dir=self.result_path, name=f"{self.args.task}_{self.start_time}_{run_idx}") + save_name = self.get_save_name() + save_notebook(role=di, save_dir=self.result_path, name=f"{save_name}_{run_idx}") + + if not run_finished: + score_dict = {"train_score": -1, "dev_score": -1, "test_score": -1, "score": -1} + return score_dict + + def summarize_results(self, results): + dev_scores = [result["score_dict"]["dev_score"] for result in results] + best_dev_score = ( + max(dev_scores) + if not self.args.low_is_better + else min([score for score in dev_scores if score != -1] + [np.inf]) + ) + best_score_idx = dev_scores.index(best_dev_score) + + test_scores = [result["score_dict"]["test_score"] for result in results] + avg_score = sum(test_scores) / len(test_scores) + global_best_score = ( + max(test_scores) + if not self.args.low_is_better + else min([score for i, score in enumerate(test_scores) if dev_scores[i] != -1] + [np.inf]) + ) + + results.insert( + 0, + { + "best_dev_score": best_dev_score, + "best_dev_score_idx": best_score_idx, + "best_dev_test_score": test_scores[best_score_idx], + "avg_test_score": avg_score, + "global_best_test_score": global_best_score, + }, + ) + return results + + async def run_experiment(self): + state = self.state + user_requirement = state["requirement"] + results = [] + + for i in range(self.args.num_experiments): + di = Experimenter(node_id="0", use_reflection=self.args.reflection, role_timeout=self.args.role_timeout) + score_dict = await self.run_di(di, user_requirement, run_idx=i) + results.append( + {"idx": i, "score_dict": score_dict, "user_requirement": user_requirement, "args": vars(self.args)} + ) + self.save_result(results) # save intermediate results + results = self.summarize_results(results) + + self.save_result(results) + + def evaluate_prediction(self, split, state): + pred_path = os.path.join(state["work_dir"], state["task"], f"{split}_predictions.csv") + os.makedirs(state["node_dir"], exist_ok=True) + pred_node_path = os.path.join(state["node_dir"], f"{self.start_time}-{split}_predictions.csv") + gt_path = os.path.join(state["datasets_dir"][f"{split}_target"]) + preds = pd.read_csv(pred_path) + preds = preds[preds.columns.tolist()[-1]] + preds.to_csv(pred_node_path, index=False) + gt = pd.read_csv(gt_path)["target"] + metric = state["dataset_config"]["metric"] + os.remove(pred_path) + return evaluate_score(preds, gt, metric) + + def evaluate(self, score_dict, state): + scores = { + "dev_score": self.evaluate_prediction("dev", state), + "test_score": self.evaluate_prediction("test", state), + } + score_dict.update(scores) + return score_dict + + def get_save_name(self): + return f"{self.args.exp_mode}-{self.args.task}_{self.start_time}" + + def save_result(self, result): + end_time_raw = datetime.datetime.now() + end_time = end_time_raw.strftime("%Y%m%d%H%M") + time_info = { + "start_time": self.start_time, + "end_time": end_time, + "duration (seconds)": (end_time_raw - self.start_time_raw).seconds, + } + result = result.copy() + result.insert(0, time_info) + save_name = self.get_save_name() + os.makedirs(self.result_path, exist_ok=True) + with open(f"{self.result_path}/{save_name}.json", "w") as f: + json.dump(result, f, indent=4) diff --git a/metagpt/ext/sela/scripts/run_cls.sh b/metagpt/ext/sela/scripts/run_cls.sh new file mode 100644 index 
0000000000000000000000000000000000000000..f0ee5ddcf1daf9f03bcf25ad405b4c94e33eafc3 --- /dev/null +++ b/metagpt/ext/sela/scripts/run_cls.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +tasks=("smoker-status" "software-defects" "jasmine" "credit-g" "Click_prediction_small" "kick" "kc1" "titanic" "icr" "wine-quality-white" "mfeat-factors" "segment" "GesturePhaseSegmentationProcessed") + + +for i in {1..3} +do + for task in "${tasks[@]}"; do + echo "Running experiment for task: $task" + python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --special_instruction stacking + echo "Experiment for task $task completed." + done +done + +echo "All experiments completed." diff --git a/metagpt/ext/sela/scripts/run_cls_mod.sh b/metagpt/ext/sela/scripts/run_cls_mod.sh new file mode 100644 index 0000000000000000000000000000000000000000..ae3622b7a91de93374694aebe32f230fc27fab7e --- /dev/null +++ b/metagpt/ext/sela/scripts/run_cls_mod.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +tasks=("banking77" "gnad10" "sms_spam" "oxford-iiit-pet" "stanford_cars" "fashion_mnist" ) + +for i in {1..3} +do + for task in "${tasks[@]}"; do + echo "Running experiment for task: $task" + python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 + echo "Experiment for task $task completed." + done +done +echo "All experiments completed." diff --git a/metagpt/ext/sela/scripts/run_reg.sh b/metagpt/ext/sela/scripts/run_reg.sh new file mode 100644 index 0000000000000000000000000000000000000000..f8a7428864eb657d0b67a92eda8903f83d61b82e --- /dev/null +++ b/metagpt/ext/sela/scripts/run_reg.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +tasks=("concrete-strength" "Moneyball" "colleges" "SAT11-HAND-runtime-regression" "diamonds" "boston" "house-prices") + +for i in {1..3} +do + for task in "${tasks[@]}"; do + echo "Running experiment for task: $task" + python run_experiment.py --exp_mode mcts --task "$task" --rollouts 10 --low_is_better --special_instruction stacking + echo "Experiment for task $task completed." + done +done + +echo "All experiments completed." 
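Note that `run_reg.sh` passes `--low_is_better`, since the regression tasks report RMSE-style metrics where smaller is better. To make the selection logic in `Runner.summarize_results` (in `runner/runner.py` above) concrete, here is a small self-contained sketch with invented dev scores; `-1` marks a failed run, mirroring the fallback `score_dict` in `run_di`:

```python
import numpy as np

dev_scores = [0.52, -1, 0.47, 0.49]  # invented RMSE values; -1 = failed run
low_is_better = True  # as set by --low_is_better in run_reg.sh

if not low_is_better:
    best_dev_score = max(dev_scores)
else:
    # Failed runs are filtered out; np.inf guards against all runs failing,
    # matching the expression used in Runner.summarize_results.
    best_dev_score = min([s for s in dev_scores if s != -1] + [np.inf])

best_score_idx = dev_scores.index(best_dev_score)
print(best_dev_score, best_score_idx)  # -> 0.47 2
```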
diff --git a/metagpt/ext/sela/scripts/visualize_experiment.py b/metagpt/ext/sela/scripts/visualize_experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d980d11830a4b50e86c02ef49c4c25c6bc09ff --- /dev/null +++ b/metagpt/ext/sela/scripts/visualize_experiment.py @@ -0,0 +1,28 @@ +import networkx as nx + +from metagpt.ext.sela.evaluation.visualize_mcts import ( + build_tree_recursive, + visualize_tree, +) +from metagpt.ext.sela.run_experiment import get_args +from metagpt.ext.sela.search.search_algorithm import MCTS +from metagpt.ext.sela.search.tree_search import create_initial_state, initialize_di_root_node +from metagpt.ext.sela.utils import DATA_CONFIG + +if __name__ == "__main__": + args = get_args() + data_config = DATA_CONFIG + state = create_initial_state(args.task, 0, data_config, args=args) + role, node = initialize_di_root_node(state) + mcts = MCTS( + root_node=node, + max_depth=5, + use_fixed_insights=False, + ) + + mcts.load_tree() + mcts.load_node_order() + root = mcts.root_node + node_order = mcts.node_order + G = nx.DiGraph() + build_tree_recursive(G, "0", root, node_order) + visualize_tree(G, save_path=f"results/{args.task}-tree.png") diff --git a/metagpt/ext/sela/search/search_algorithm.py b/metagpt/ext/sela/search/search_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..ca47d8cf6c51e23730318443dcef07236ce1ec3d --- /dev/null +++ b/metagpt/ext/sela/search/search_algorithm.py @@ -0,0 +1,32 @@ +import numpy as np + +from metagpt.ext.sela.search.tree_search import BaseTreeSearch, Node + + +class Greedy(BaseTreeSearch): + def best_child(self): + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return max(all_children, key=lambda x: x.normalized_reward.get("dev_score", 0)) + + +class Random(BaseTreeSearch): + def best_child(self): + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return np.random.choice(all_children) + + +class MCTS(BaseTreeSearch): + def best_child(self): + def uct(node: Node): + n_visits = node.visited if node.visited else self.c_unvisited + avg_value = node.avg_value() if node.visited else node.value / self.c_unvisited + return avg_value + self.c_explore * np.sqrt(np.log(node.parent.visited) / n_visits) + + if len(self.children) == 0: + return self.root_node + all_children = [child for children in self.children.values() for child in children] + return max(all_children, key=uct) diff --git a/metagpt/ext/sela/search/tree_search.py b/metagpt/ext/sela/search/tree_search.py new file mode 100644 index 0000000000000000000000000000000000000000..eac26c86ca7af852bc256023c7de89c34bdb8395 --- /dev/null +++ b/metagpt/ext/sela/search/tree_search.py @@ -0,0 +1,492 @@ +import json +import os +import pickle +import shutil + +import numpy as np +import pandas as pd + +from metagpt.ext.sela.data.custom_task import ( + get_mle_bench_requirements, + get_mle_task_id, +) +from metagpt.ext.sela.data.dataset import ( + generate_task_requirement, + get_split_dataset_path, +) +from metagpt.ext.sela.evaluation.evaluation import evaluate_score +from metagpt.ext.sela.experimenter import Experimenter, TimeoutException +from metagpt.ext.sela.insights.instruction_generator import InstructionGenerator +from metagpt.ext.sela.utils import get_exp_pool_path, load_execute_notebook, mcts_logger +from metagpt.tools.tool_recommend import ToolRecommender +from metagpt.utils.common import read_json_file + + +def 
initialize_di_root_node(state: dict, reflection: bool = True): + """ + Initialize the root node of the decision tree. + + Args: + state (dict): The initial state of the tree, containing: + - task (str): The task to be performed (e.g., "titanic"). + - work_dir (str): The working directory. + - node_dir (str): The directory for the node. + - dataset_config (dict): The configuration of the dataset. + - datasets_dir (str): The directory of the datasets. + - exp_pool_path (str): The path to the experiment pool. + - requirement (str): The requirement for the task. + - has_run (bool): Whether the task has run. + - start_task_id (int): The ID of the starting task. + - low_is_better (bool): Whether a lower score is better. + - role_timeout (int): The timeout for the role. + - external_eval (bool): Whether to use external evaluation. + - custom_dataset_dir (str): The directory of the custom dataset. + reflection (bool, optional): Whether to use reflection. Defaults to True. + + Returns: + tuple: A tuple containing the Experimenter role and the root Node. + """ + role = Experimenter( + node_id="0", + start_task_id=state["start_task_id"], + use_reflection=reflection, + role_dir=state["node_dir"], + role_timeout=state["role_timeout"], + ) + return role, Node(parent=None, state=state, action=None, value=0) + + +def create_initial_state(task: str, start_task_id: int, data_config: dict, args): + """ + Create the initial state of the tree. + + Args: + task (str): The task to be performed. + start_task_id (int): The ID of the starting task. + data_config (dict): The configuration of the data. + Expected keys: 'datasets', 'work_dir', 'role_dir'. + args (Namespace): The arguments passed to the program. + Expected attributes: 'external_eval', 'custom_dataset_dir', 'special_instruction', 'name', 'low_is_better', 'role_timeout'. + + Returns: + dict: The initial state of the tree. 
+ """ + external_eval = args.external_eval + + if args.custom_dataset_dir: + dataset_config = None + datasets_dir = args.custom_dataset_dir + requirement = get_mle_bench_requirements( + args.custom_dataset_dir, data_config, special_instruction=args.special_instruction + ) + exp_pool_path = None + # external_eval = False # make sure external eval is false if custom dataset is used + task = get_mle_task_id(args.custom_dataset_dir) + else: + dataset_config = data_config["datasets"][task] + if dataset_config["metric"] == "rmse": + args.low_is_better = True + datasets_dir = get_split_dataset_path(task, data_config) + requirement = generate_task_requirement( + task, data_config, is_di=True, special_instruction=args.special_instruction + ) + exp_pool_path = get_exp_pool_path(task, data_config, pool_name="ds_analysis_pool") + + initial_state = { + "task": task, + "work_dir": data_config["work_dir"], + "node_dir": os.path.join(data_config["work_dir"], data_config["role_dir"], f"{task}{args.name}"), + "dataset_config": dataset_config, + "datasets_dir": datasets_dir, # won't be used if external eval is used + "exp_pool_path": exp_pool_path, + "requirement": requirement, + "has_run": False, + "start_task_id": start_task_id, + "low_is_better": args.low_is_better, + "role_timeout": args.role_timeout, + "external_eval": external_eval, + "custom_dataset_dir": args.custom_dataset_dir, + } + os.makedirs(initial_state["node_dir"], exist_ok=True) + return initial_state + + +class Node: + state: dict = {} + action: str = None + value: float = 0 + visited: int = 0 + children: list = [] + normalized_reward: dict = {"train_score": 0, "dev_score": 0, "test_score": 0} + parent = None + + def __init__( + self, parent=None, state: dict = None, action: str = None, value: float = 0, max_depth: int = 4, **kwargs + ): + self.state = state + self.action = action + self.value = value + self.raw_value = 0 + self.raw_reward = dict() + self.parent = parent + self.children = [] + self.max_depth = max_depth + self.depth = self.generate_depth() + self.id = self.generate_id() + if self.parent is not None: + self.save_node() + + def avg_value(self): + if self.visited == 0: + return 0 + return self.value / self.visited + + def __hash__(self): + return hash(self.id) + + def save_node(self): + os.makedirs(self.state["node_dir"], exist_ok=True) + with open(os.path.join(self.state["node_dir"], f"Node-{self.id}.pkl"), "wb") as f: + pickle.dump(self, f) + + def load_node(self): + with open(os.path.join(self.state["node_dir"], f"Node-{self.id}.pkl"), "rb") as f: + return pickle.load(f) + + def get_depth(self): + return self.depth + + def get_node_dir(self): + return self.state["node_dir"] + + def generate_depth(self): + if self.parent is None: + return 0 + else: + return self.parent.depth + 1 + + def generate_id(self): + if self.parent is None: + return "0" + else: + num_sibling = len(self.parent.children) + return f"{self.parent.id}-{num_sibling}" + + def is_terminal(self): + return int(self.state["start_task_id"]) == self.max_depth + 1 # TODO: Check if this is correct or +1 + + def is_fully_expanded(self): + return len(self.children) > 0 + + def add_child(self, child_node): + self.children.append(child_node) + + def update(self, reward: dict, child_node=None): + if child_node is not None: + child_role = child_node.load_role() + role = self.load_role() + role.update_til_start_task(child_role) + role.save_state() + else: + self.raw_value = reward["test_score"] + self.value += reward["score"] + self.visited += 1 + self.save_node() + + def 
get_role_path(self): + fname = f"Node-{self.id}.json" + role_path = os.path.join(self.state["node_dir"], fname) + return role_path + + def load_role(self): + role_dict = read_json_file(self.get_role_path()) + if role_dict.get("tool_recommender") is None: + role_dict["tool_recommender"] = ToolRecommender() + elif isinstance(role_dict.get("tool_recommender", {}).get("tools"), dict): + role_dict["tool_recommender"]["tools"] = list(role_dict["tool_recommender"]["tools"].keys()) + role = Experimenter(**role_dict) + if self.parent is not None: # TODO: Check this + parent_role = self.parent.load_role() + role.update_til_start_task(parent_role, backward=False) + role.remap_tasks() + return role + + def save_new_role(self, role: Experimenter): + role.node_id = self.id + role.start_task_id = self.state["start_task_id"] + role.state_saved = False + role.change_next_instruction(self.action) + mcts_logger.log("MCTS", f"Saving new role: {role.node_id}") + role = role.model_copy() + role.save_state(static_save=True) + + async def expand(self, max_children: int, instruction_generator: InstructionGenerator): + if self.is_fully_expanded(): + return + role = self.load_role() + original_instruction = role.get_next_instruction() + insights = await instruction_generator.generate_new_instructions( + task_id=role.start_task_id + 1, + original_instruction=original_instruction, + max_num=max_children, + ) + new_state = self.state.copy() + new_state["start_task_id"] += 1 + for insight in insights: + new_role = role.model_copy() + node = Node(parent=self, state=new_state, action=insight, value=0) + node.save_new_role(new_role) + self.add_child(node) + + def get_predictions_path(self, split): + return os.path.join(self.state["node_dir"], f"Node-{self.id}-{split}_predictions.csv") + + def get_and_move_predictions(self, split): + if not os.path.exists(self.get_predictions_path(split)): + pred_path = os.path.join(self.state["work_dir"], self.state["task"], f"{split}_predictions.csv") + shutil.copy(pred_path, self.get_predictions_path(split)) + os.remove(pred_path) + return pd.read_csv(self.get_predictions_path(split)) + + def get_gt(self, split): + gt_path = os.path.join(self.state["datasets_dir"][f"{split}_target"]) + return pd.read_csv(gt_path) + + def evaluate_prediction(self, split): + preds = self.get_and_move_predictions(split)["target"] + gt = self.get_gt(split)["target"] + metric = self.state["dataset_config"]["metric"] + return evaluate_score(preds, gt, metric) + + def evaluate_simulation(self, score_dict): + if self.state["external_eval"]: # use external evaluation + scores = {"dev_score": self.evaluate_prediction("dev"), "test_score": self.evaluate_prediction("test")} + scores["score"] = scores["dev_score"] + score_dict.update(scores) + else: + self.get_and_move_predictions("dev") + self.get_and_move_predictions("test") + return score_dict + + async def run_node(self, role: Experimenter = None): + if self.is_terminal() and role is not None: + if role.state_saved: + return self.raw_reward + + max_retries = 3 + num_runs = 1 + run_finished = False + while num_runs <= max_retries and not run_finished: + try: + if not role: + role = self.load_role() + await load_execute_notebook(role) # execute previous notebook's code + await role.run(with_message="continue") + else: + await role.run(with_message=self.state["requirement"]) + score_dict = await role.get_score() + score_dict = self.evaluate_simulation(score_dict) + self.raw_reward = score_dict + run_finished = True + except TimeoutException as e: + 
mcts_logger.log("MCTS", f"Role-level timeout: {e}") + break + except Exception as e: + mcts_logger.log("MCTS", f"Error in running the role: {e}") + num_runs += 1 + + if not run_finished: + mcts_logger.log("MCTS", f"Role {role.node_id} failed to run") + if self.state["low_is_better"]: + score_dict = {"test_score": np.inf, "dev_score": np.inf, "score": np.inf} + else: + score_dict = {"test_score": 0, "dev_score": 0, "score": 0} + self.raw_reward = score_dict + if self.state["low_is_better"]: + # normalized the score to be between 0 and 1, and higher is better + def normalize_score(score): + if score == -1: + return 0 + return 1 / (1 + score) + + score_dict = {k: normalize_score(v) for k, v in score_dict.items()} + self.normalized_reward = score_dict + result_dict = role.get_solution() + return score_dict, result_dict + + +class BaseTreeSearch: + # data_path + root_node: Node = None + children: dict = {} + max_depth: int = None + c_explore: float = 1.4 + c_unvisited: float = 0.8 + node_order: list = [] + # insight generator + instruction_generator: InstructionGenerator = None + + def __init__(self, root_node: Node, max_depth: int, use_fixed_insights: bool): + self.root_node = root_node + self.max_depth = max_depth + self.use_fixed_insights = use_fixed_insights + + def select(self, node: Node): + node = self.best_child() + mcts_logger.log("MCTS", f"Selected node id: {node.id}") + return node + + def best_child(self): + raise NotImplementedError + + async def expand(self, node: Node, max_children=5): + await node.expand(max_children, self.instruction_generator) + if node not in self.children or not self.children[node]: + self.children[node] = node.children + return node.children + + async def simulate(self, node: Node, role=None): + "Returns the reward for a random simulation (to completion) of `node`" + mcts_logger.log("MCTS", f"Start simulating node {node.id}:") + while node.children: + node = np.random.choice(node.children) + reward, result_dict = await node.run_node(role) + mcts_logger.log("MCTS", f"Simulated node's reward: {reward}") + # TODO: add new insights + return reward + + def backpropagate(self, node: Node, reward: dict): + child_node = node + node.update(reward) + node = node.parent + while node is not None: + node.update(reward, child_node) + node, child_node = node.parent, node + + def best_path(self, root: Node): + best_child = root + global_best_score = root.normalized_reward["test_score"] + dev_best_score = root.normalized_reward["dev_score"] + + def bfs(node: Node, best_score: float, best_child: Node, split: str): + assert split in ["test_score", "dev_score"] + if node not in self.children: + return best_score, best_child + for child in self.children[node]: + score = child.normalized_reward[split] + print(child.id, split, score) + if score > best_score: + best_score = score + best_child = child + best_score, best_child = bfs(child, best_score, best_child, split) + return best_score, best_child + + _, global_best_child = bfs(root, global_best_score, best_child, "test_score") + _, dev_best_child = bfs(root, dev_best_score, best_child, "dev_score") + + return {"dev_best": dev_best_child, "global_best": global_best_child, "scores": self.get_score_order_dict()} + + def get_num_simulations(self): + return self.root_node.visited + + def save_node_order(self, node_id: str): + self.node_order.append(node_id) + with open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "w") as f: + json.dump(self.node_order, f) + + def load_node_order(self): + with 
open(os.path.join(self.root_node.state["node_dir"], "node_order.json"), "r") as f: + self.node_order = json.load(f) + + def get_score_order_dict(self): + scores = {"dev": [], "test": [], "dev_raw": [], "test_raw": []} + for node_id in self.node_order: + node = Node(parent=None, state=self.root_node.state, action=None, value=0) + node.id = node_id + node = node.load_node() + scores["dev"].append(node.normalized_reward["dev_score"]) + scores["test"].append(node.normalized_reward["test_score"]) + scores["dev_raw"].append(node.raw_reward["dev_score"]) + scores["test_raw"].append(node.raw_reward["test_score"]) + return scores + + async def search(self, state: dict, args): + reflection = args.reflection + load_tree = args.load_tree + rollouts = args.rollouts + from_scratch = args.from_scratch + role, root = initialize_di_root_node(state, reflection=reflection) + self.root_node = root + self.instruction_generator = InstructionGenerator( + state=state, use_fixed_insights=self.use_fixed_insights, from_scratch=from_scratch + ) + await self.instruction_generator.initialize() + + tree_loaded = False + if load_tree: + tree_loaded = self.load_tree() + mcts_logger.log("MCTS", f"Number of simulations: {self.get_num_simulations()}") + mcts_logger.log("MCTS", f"Tree loaded: {tree_loaded}") + + if not tree_loaded: + rollouts -= 2 # 2 rollouts for the initial tree + if rollouts < 0: + raise ValueError("Rollouts must be greater than 2 if there is no tree to load") + self.children[root] = [] + reward = await self.simulate(root, role) + self.backpropagate(root, reward) + node, reward = await self.expand_and_simulate(root) + # self.backpropagate(node, reward) + self.save_node_order(root.id) + self.save_node_order(node.id) + else: + root = self.root_node + self.load_node_order() + + for _ in range(rollouts): # number of rollouts + mcts_logger.log("MCTS", f"Start the next rollout {_+1}") + node = self.select(root) + if node.is_terminal(): + if node.raw_value == 0: + reward = await self.simulate(node) + else: + reward = {"test_score": node.raw_value, "score": node.raw_reward["score"]} + mcts_logger.log("MCTS", f"Terminal node's reward: {reward}") + self.backpropagate(node, reward) + else: + node, reward = await self.expand_and_simulate(node) + # self.backpropagate(node, reward) + self.save_node_order(node.id) + return self.best_path(root) + + async def expand_and_simulate(self, node: Node): + # Expand and randomly select a child node, then simulate it + if node.visited > 0: + children = await self.expand(node) + node = np.random.choice(children) + reward = await self.simulate(node) + self.backpropagate(node, reward) + return node, reward + + def load_tree(self): + def load_children_node(node: Node): + mcts_logger.log("MCTS", f"Load node {node.id}'s child: {node.children}") + if node.is_terminal() or not node.children: + return + for child in node.children: + child.load_node() + self.children[child] = child.children + load_children_node(child) + + # Load all pkl files in the node_dir + all_pkl_files = os.listdir(self.root_node.state["node_dir"]) + all_pkl_files = [f for f in all_pkl_files if f.endswith(".pkl")] + if os.path.exists(os.path.join(self.root_node.state["node_dir"], "Node-0.pkl")): + with open(os.path.join(self.root_node.state["node_dir"], "Node-0.pkl"), "rb") as f: + self.root_node = pickle.load(f) + self.children[self.root_node] = self.root_node.children + load_children_node(self.root_node) + + if self.children: + return True + return False diff --git a/metagpt/ext/sela/utils.py 
b/metagpt/ext/sela/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..21b311e7f399cf53157dffe1e43810af1f066e81 --- /dev/null +++ b/metagpt/ext/sela/utils.py @@ -0,0 +1,130 @@ +import os +import re +from datetime import datetime +from pathlib import Path + +import nbformat +import yaml +from loguru import logger as _logger +from nbclient import NotebookClient +from nbformat.notebooknode import NotebookNode + +from metagpt.roles.role import Role + + +def load_data_config(file_path="data.yaml"): + with open(file_path, "r") as stream: + data_config = yaml.safe_load(stream) + return data_config + + +DATASET_CONFIG = load_data_config("datasets.yaml") +DATA_CONFIG = load_data_config() +DATA_CONFIG["datasets"] = DATASET_CONFIG["datasets"] + + +def get_mcts_logger(): + logfile_level = "DEBUG" + name: str = None + current_date = datetime.now() + formatted_date = current_date.strftime("%Y%m%d") + log_name = f"{name}_{formatted_date}" if name else formatted_date # name a log with prefix name + + # _logger.remove() + _logger.level("MCTS", color="", no=25) + # _logger.add(sys.stderr, level=print_level) + _logger.add(Path(DATA_CONFIG["work_dir"]) / DATA_CONFIG["role_dir"] / f"{log_name}.txt", level=logfile_level) + _logger.propagate = False + return _logger + + +mcts_logger = get_mcts_logger() + + +def get_exp_pool_path(task_name, data_config, pool_name="analysis_pool"): + datasets_dir = data_config["datasets_dir"] + if task_name in data_config["datasets"]: + dataset = data_config["datasets"][task_name] + data_path = os.path.join(datasets_dir, dataset["dataset"]) + else: + raise ValueError( + f"Dataset {task_name} not found in config file. Available datasets: {data_config['datasets'].keys()}" + ) + exp_pool_path = os.path.join(data_path, f"{pool_name}.json") + if not os.path.exists(exp_pool_path): + return None + return exp_pool_path + + +def change_plan(role, plan): + print(f"Change next plan to: {plan}") + tasks = role.planner.plan.tasks + finished = True + for i, task in enumerate(tasks): + if not task.code: + finished = False + break + if not finished: + tasks[i].plan = plan + return finished + + +def is_cell_to_delete(cell: NotebookNode) -> bool: + if "outputs" in cell: + for output in cell["outputs"]: + if output and "traceback" in output: + return True + return False + + +def process_cells(nb: NotebookNode) -> NotebookNode: + new_cells = [] + i = 1 + for cell in nb["cells"]: + if cell["cell_type"] == "code" and not is_cell_to_delete(cell): + cell["execution_count"] = i + new_cells.append(cell) + i = i + 1 + nb["cells"] = new_cells + return nb + + +def save_notebook(role: Role, save_dir: str = "", name: str = "", save_to_depth=False): + save_dir = Path(save_dir) + tasks = role.planner.plan.tasks + nb = process_cells(role.execute_code.nb) + os.makedirs(save_dir, exist_ok=True) + file_path = save_dir / f"{name}.ipynb" + nbformat.write(nb, file_path) + + if save_to_depth: + clean_file_path = save_dir / f"{name}_clean.ipynb" + codes = [task.code for task in tasks if task.code] + clean_nb = nbformat.v4.new_notebook() + for code in codes: + clean_nb.cells.append(nbformat.v4.new_code_cell(code)) + nbformat.write(clean_nb, clean_file_path) + + +async def load_execute_notebook(role): + tasks = role.planner.plan.tasks + codes = [task.code for task in tasks if task.code] + executor = role.execute_code + executor.nb = nbformat.v4.new_notebook() + executor.nb_client = NotebookClient(executor.nb, timeout=role.role_timeout) + # await executor.build() + for code in codes: + outputs, 
success = await executor.run(code) + print(f"Execution success: {success}, Output: {outputs}") + print("Finish executing the loaded notebook") + return executor + + +def clean_json_from_rsp(text): + pattern = r"```json(.*?)```" + matches = re.findall(pattern, text, re.DOTALL) + if matches: + json_str = "\n".join(matches) + return json_str + else: + return "" diff --git a/metagpt/ext/spo/.DS_Store b/metagpt/ext/spo/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..cd9418154c07550f8a545ee660ad133ef35f156d Binary files /dev/null and b/metagpt/ext/spo/.DS_Store differ diff --git a/metagpt/ext/spo/__init__.py b/metagpt/ext/spo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/spo/app.py b/metagpt/ext/spo/app.py new file mode 100644 index 0000000000000000000000000000000000000000..20895a420aa0046221cc845c62993d5755917bc9 --- /dev/null +++ b/metagpt/ext/spo/app.py @@ -0,0 +1,301 @@ +import asyncio +from pathlib import Path +from typing import Dict + +import streamlit as st +import yaml +from loguru import logger as _logger + +from metagpt.const import METAGPT_ROOT +from metagpt.ext.spo.components.optimizer import PromptOptimizer +from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType + + +def load_yaml_template(template_path: Path) -> Dict: + if template_path.exists(): + with open(template_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) + return {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]} + + +def save_yaml_template(template_path: Path, data: Dict) -> None: + template_format = { + "prompt": str(data.get("prompt", "")), + "requirements": str(data.get("requirements", "")), + "count": data.get("count"), + "qa": [ + {"question": str(qa.get("question", "")).strip(), "answer": str(qa.get("answer", "")).strip()} + for qa in data.get("qa", []) + ], + } + + template_path.parent.mkdir(parents=True, exist_ok=True) + + with open(template_path, "w", encoding="utf-8") as f: + yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2) + + +def display_optimization_results(result_data): + for result in result_data: + round_num = result["round"] + success = result["succeed"] + prompt = result["prompt"] + + with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"): + st.markdown("**Prompt:**") + st.code(prompt, language="text") + st.markdown("
", unsafe_allow_html=True) + + col1, col2 = st.columns(2) + with col1: + st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}") + with col2: + st.markdown(f"**Tokens:** {result['tokens']}") + + st.markdown("**Answers:**") + for idx, answer in enumerate(result["answers"]): + st.markdown(f"**Question {idx + 1}:**") + st.text(answer["question"]) + st.markdown("**Answer:**") + st.text(answer["answer"]) + st.markdown("---") + + # Summary + success_count = sum(1 for r in result_data if r["succeed"]) + total_rounds = len(result_data) + + st.markdown("### Summary") + col1, col2 = st.columns(2) + with col1: + st.metric("Total Rounds", total_rounds) + with col2: + st.metric("Successful Rounds", success_count) + + +def main(): + if "optimization_results" not in st.session_state: + st.session_state.optimization_results = [] + + st.markdown( + """ +
+        <div style="text-align: center;">
+            <h1>SPO | Self-Supervised Prompt Optimization 🤖</h1>
+            <p>
+                <a href="#">Paper</a>
+                <a href="#">GitHub</a>
+            </p>
+            <p>A framework for self-supervised prompt optimization</p>
+        </div>
+ """, + unsafe_allow_html=True, + ) + + # Sidebar for configurations + with st.sidebar: + st.header("Configuration") + + # Template Selection/Creation + settings_path = Path("metagpt/ext/spo/settings") + existing_templates = [f.stem for f in settings_path.glob("*.yaml")] + + template_mode = st.radio("Template Mode", ["Use Existing", "Create New"]) + + if template_mode == "Use Existing": + template_name = st.selectbox("Select Template", existing_templates) + else: + template_name = st.text_input("New Template Name") + if template_name and not template_name.endswith(".yaml"): + template_name = f"{template_name}" + + # LLM Settings + st.subheader("LLM Settings") + opt_model = st.selectbox( + "Optimization Model", ["claude-3-5-sonnet-20240620", "gpt-4o", "gpt-4o-mini", "deepseek-chat"], index=0 + ) + opt_temp = st.slider("Optimization Temperature", 0.0, 1.0, 0.7) + + eval_model = st.selectbox( + "Evaluation Model", ["gpt-4o-mini", "claude-3-5-sonnet-20240620", "gpt-4o", "deepseek-chat"], index=0 + ) + eval_temp = st.slider("Evaluation Temperature", 0.0, 1.0, 0.3) + + exec_model = st.selectbox( + "Execution Model", ["gpt-4o-mini", "claude-3-5-sonnet-20240620", "gpt-4o", "deepseek-chat"], index=0 + ) + exec_temp = st.slider("Execution Temperature", 0.0, 1.0, 0.0) + + # Optimizer Settings + st.subheader("Optimizer Settings") + initial_round = st.number_input("Initial Round", 1, 100, 1) + max_rounds = st.number_input("Maximum Rounds", 1, 100, 10) + + # Main content area + st.header("Template Configuration") + + if template_name: + template_path = settings_path / f"{template_name}.yaml" + template_data = load_yaml_template(template_path) + + if "current_template" not in st.session_state or st.session_state.current_template != template_name: + st.session_state.current_template = template_name + st.session_state.qas = template_data.get("qa", []) + + # Edit template sections + prompt = st.text_area("Prompt", template_data.get("prompt", ""), height=100) + requirements = st.text_area("Requirements", template_data.get("requirements", ""), height=100) + + # qa section + st.subheader("Q&A Examples") + + # Add new qa button + if st.button("Add New Q&A"): + st.session_state.qas.append({"question": "", "answer": ""}) + + # Edit qas + new_qas = [] + for i in range(len(st.session_state.qas)): + st.markdown(f"**QA #{i + 1}**") + col1, col2, col3 = st.columns([45, 45, 10]) + + with col1: + question = st.text_area( + f"Question {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100 + ) + with col2: + answer = st.text_area( + f"Answer {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100 + ) + with col3: + if st.button("🗑️", key=f"delete_{i}"): + st.session_state.qas.pop(i) + st.rerun() + + new_qas.append({"question": question, "answer": answer}) + + # Save template button + if st.button("Save Template"): + new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas} + + save_yaml_template(template_path, new_template_data) + + st.session_state.qas = new_qas + st.success(f"Template saved to {template_path}") + + st.subheader("Current Template Preview") + preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt} + st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml") + + st.subheader("Optimization Logs") + log_container = st.empty() + + class StreamlitSink: + def write(self, message): + current_logs = st.session_state.get("logs", []) + current_logs.append(message.strip()) + 
st.session_state.logs = current_logs + + log_container.code("\n".join(current_logs), language="plaintext") + + streamlit_sink = StreamlitSink() + _logger.remove() + + def prompt_optimizer_filter(record): + return "optimizer" in record["name"].lower() + + _logger.add( + streamlit_sink.write, + format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", + filter=prompt_optimizer_filter, + ) + _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG") + + # Start optimization button + if st.button("Start Optimization"): + try: + # Initialize LLM + SPO_LLM.initialize( + optimize_kwargs={"model": opt_model, "temperature": opt_temp}, + evaluate_kwargs={"model": eval_model, "temperature": eval_temp}, + execute_kwargs={"model": exec_model, "temperature": exec_temp}, + ) + + # Create optimizer instance + optimizer = PromptOptimizer( + optimized_path="workspace", + initial_round=initial_round, + max_rounds=max_rounds, + template=f"{template_name}.yaml", + name=template_name, + ) + + # Run optimization with progress bar + with st.spinner("Optimizing prompts..."): + optimizer.optimize() + + st.success("Optimization completed!") + + st.header("Optimization Results") + + prompt_path = optimizer.root_path / "prompts" + result_data = optimizer.data_utils.load_results(prompt_path) + + st.session_state.optimization_results = result_data + + except Exception as e: + st.error(f"An error occurred: {str(e)}") + _logger.error(f"Error during optimization: {str(e)}") + + if st.session_state.optimization_results: + st.header("Optimization Results") + display_optimization_results(st.session_state.optimization_results) + + st.markdown("---") + st.subheader("Test Optimized Prompt") + col1, col2 = st.columns(2) + + with col1: + test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt") + + with col2: + test_question = st.text_area("Your Question", value="", height=200, key="test_question") + + if st.button("Test Prompt"): + if test_prompt and test_question: + try: + with st.spinner("Generating response..."): + SPO_LLM.initialize( + optimize_kwargs={"model": opt_model, "temperature": opt_temp}, + evaluate_kwargs={"model": eval_model, "temperature": eval_temp}, + execute_kwargs={"model": exec_model, "temperature": exec_temp}, + ) + + llm = SPO_LLM.get_instance() + messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}] + + async def get_response(): + return await llm.responser(request_type=RequestType.EXECUTE, messages=messages) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(get_response()) + finally: + loop.close() + + st.subheader("Response:") + st.markdown(response) + + except Exception as e: + st.error(f"Error generating response: {str(e)}") + else: + st.warning("Please enter both prompt and question.") + + +if __name__ == "__main__": + main() diff --git a/metagpt/ext/spo/components/__init__.py b/metagpt/ext/spo/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/spo/components/evaluator.py b/metagpt/ext/spo/components/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..952ef211bad0b1f4b7ba1ab80823f00cf754a2d9 --- /dev/null +++ b/metagpt/ext/spo/components/evaluator.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# @Date : 8/23/2024 10:00 AM +# @Author : all +# @Desc : Evaluation for different datasets +import asyncio +import 
random +from typing import Any, Dict + +from metagpt.ext.spo.prompts.evaluate_prompt import EVALUATE_PROMPT +from metagpt.ext.spo.utils import load +from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType, extract_content +from metagpt.logs import logger + + +class QuickExecute: + """ + Execute Prompt + """ + + def __init__(self, prompt: str): + self.prompt = prompt + self.llm = SPO_LLM.get_instance() + + async def prompt_execute(self) -> tuple[Any]: + _, _, qa, _ = load.load_meta_data() + answers = [] + + async def fetch_answer(q: str) -> Dict[str, Any]: + messages = [{"role": "user", "content": f"{self.prompt}\n\n{q}"}] + try: + answer = await self.llm.responser(request_type=RequestType.EXECUTE, messages=messages) + return {"question": q, "answer": answer} + except Exception as e: + return {"question": q, "answer": str(e)} + + tasks = [fetch_answer(item["question"]) for item in qa] + answers = await asyncio.gather(*tasks) + + return answers + + +class QuickEvaluate: + """ + Complete the evaluation for different answers here. + """ + + def __init__(self): + self.llm = SPO_LLM.get_instance() + + async def prompt_evaluate(self, samples: dict, new_samples: dict) -> bool: + _, requirement, qa, _ = load.load_meta_data() + + if random.random() < 0.5: + samples, new_samples = new_samples, samples + is_swapped = True + else: + is_swapped = False + + messages = [ + { + "role": "user", + "content": EVALUATE_PROMPT.format( + requirement=requirement, sample=samples, new_sample=new_samples, answers=str(qa) + ), + } + ] + + try: + response = await self.llm.responser(request_type=RequestType.EVALUATE, messages=messages) + choose = extract_content(response, "choose") + return choose == "A" if is_swapped else choose == "B" + + except Exception as e: + logger.error(e) + return False diff --git a/metagpt/ext/spo/components/optimizer.py b/metagpt/ext/spo/components/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce588f44b7911ae0e4ebe1497d962e6cb2b129d --- /dev/null +++ b/metagpt/ext/spo/components/optimizer.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +# @Date : 8/12/2024 22:00 PM +# @Author : issac +# @Desc : optimizer for prompt + +import asyncio +from pathlib import Path +from typing import List + +from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT +from metagpt.ext.spo.utils import load +from metagpt.ext.spo.utils.data_utils import DataUtils +from metagpt.ext.spo.utils.evaluation_utils import EvaluationUtils +from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType, extract_content +from metagpt.ext.spo.utils.prompt_utils import PromptUtils +from metagpt.logs import logger + + +class PromptOptimizer: + def __init__( + self, + optimized_path: str = None, + initial_round: int = 1, + max_rounds: int = 10, + name: str = "", + template: str = "", + ) -> None: + self.name = name + self.root_path = Path(optimized_path) / self.name + self.top_scores = [] + self.round = initial_round + self.max_rounds = max_rounds + self.template = template + + self.prompt_utils = PromptUtils(self.root_path) + self.data_utils = DataUtils(self.root_path) + self.evaluation_utils = EvaluationUtils(self.root_path) + self.llm = SPO_LLM.get_instance() + + def optimize(self): + for opt_round in range(self.max_rounds): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._optimize_prompt()) + self.round += 1 + + self.show_final_result() + + def show_final_result(self): + best_round = self.data_utils.get_best_round() + 
+ logger.info("\n" + "=" * 50) + logger.info("\n🏆 OPTIMIZATION COMPLETED - FINAL RESULTS 🏆\n") + logger.info(f"\n📌 Best Performing Round: {best_round['round']}") + logger.info(f"\n🎯 Final Optimized Prompt:\n{best_round['prompt']}") + logger.info("\n" + "=" * 50 + "\n") + + async def _optimize_prompt(self): + prompt_path = self.root_path / "prompts" + load.set_file_name(self.template) + data = self.data_utils.load_results(prompt_path) + + if self.round == 1: + await self._handle_first_round(prompt_path, data) + return + + directory = self.prompt_utils.create_round_directory(prompt_path, self.round) + new_prompt = await self._generate_optimized_prompt() + self.prompt = new_prompt + + logger.info(f"\nRound {self.round} Prompt: {self.prompt}\n") + self.prompt_utils.write_prompt(directory, prompt=self.prompt) + + success, answers = await self._evaluate_new_prompt(prompt_path, data, directory) + self._log_optimization_result(success) + + return self.prompt + + async def _handle_first_round(self, prompt_path: Path, data: List[dict]) -> None: + logger.info("\n⚡ RUNNING Round 1 PROMPT ⚡\n") + directory = self.prompt_utils.create_round_directory(prompt_path, self.round) + + prompt, _, _, _ = load.load_meta_data() + self.prompt = prompt + self.prompt_utils.write_prompt(directory, prompt=self.prompt) + + new_samples = await self.evaluation_utils.execute_prompt(self, directory) + _, answers = await self.evaluation_utils.evaluate_prompt( + self, None, new_samples, path=prompt_path, data=data, initial=True + ) + self.prompt_utils.write_answers(directory, answers=answers) + + async def _generate_optimized_prompt(self): + _, requirements, qa, count = load.load_meta_data() + samples = self.data_utils.get_best_round() + + logger.info(f"\n🚀Round {self.round} OPTIMIZATION STARTING 🚀\n") + logger.info(f"\nSelecting prompt for round {samples['round']} and advancing to the iteration phase\n") + + golden_answer = self.data_utils.list_to_markdown(qa) + best_answer = self.data_utils.list_to_markdown(samples["answers"]) + + optimize_prompt = PROMPT_OPTIMIZE_PROMPT.format( + prompt=samples["prompt"], + answers=best_answer, + requirements=requirements, + golden_answers=golden_answer, + count=count, + ) + + response = await self.llm.responser( + request_type=RequestType.OPTIMIZE, messages=[{"role": "user", "content": optimize_prompt}] + ) + + modification = extract_content(response, "modification") + logger.info(f"Modification of {self.round} round: {modification}") + + prompt = extract_content(response, "prompt") + return prompt if prompt else "" + + async def _evaluate_new_prompt(self, prompt_path, data, directory): + logger.info("\n⚡ RUNNING OPTIMIZED PROMPT ⚡\n") + new_samples = await self.evaluation_utils.execute_prompt(self, directory) + + logger.info("\n📊 EVALUATING OPTIMIZED PROMPT 📊\n") + samples = self.data_utils.get_best_round() + success, answers = await self.evaluation_utils.evaluate_prompt( + self, samples, new_samples, path=prompt_path, data=data, initial=False + ) + + self.prompt_utils.write_answers(directory, answers=answers) + return success, answers + + def _log_optimization_result(self, success): + logger.info("\n🎯 OPTIMIZATION RESULT 🎯\n") + logger.info(f"\nRound {self.round} Optimization: {'✅ SUCCESS' if success else '❌ FAILED'}\n") diff --git a/metagpt/ext/spo/prompts/evaluate_prompt.py b/metagpt/ext/spo/prompts/evaluate_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..80a9b093bfba7ee5f7559cb14b681da341803c5e --- /dev/null +++ 
b/metagpt/ext/spo/prompts/evaluate_prompt.py @@ -0,0 +1,20 @@ +EVALUATE_PROMPT = """ +Based on the original requirements, evaluate the two responses, A and B, and determine which one better meets the requirements. If a reference answer is provided, strictly follow the format/content of the reference answer. + +# Requirement +{requirement} + +# A +{sample} + +# B +{new_sample} + +# Golden answer +{answers} + +Provide your analysis and the choice you believe is better, using XML tags to encapsulate your response. + +<analyse>Some analysis</analyse> +<choose>A/B (the better answer in your opinion)</choose> +""" diff --git a/metagpt/ext/spo/prompts/optimize_prompt.py b/metagpt/ext/spo/prompts/optimize_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ca81e3342b7b244ec6e66f6720d76087bac1df --- /dev/null +++ b/metagpt/ext/spo/prompts/optimize_prompt.py @@ -0,0 +1,32 @@ +PROMPT_OPTIMIZE_PROMPT = """ +You are building a prompt to address the user requirement. Based on the given prompt, +please reconstruct and optimize it. You can add, modify, or delete prompts. Please include a single modification in +<modification> XML tags in your reply. During the optimization, you can incorporate any thinking models. +This is a prompt that performed excellently in a previous iteration. You must make further optimizations and improvements based on this prompt. The modified prompt must differ from the provided example. + +requirements: +``` +{requirements} +``` + +reference prompt: +``` +{prompt} +``` + +The execution result of this reference prompt is (some cases): +``` +{answers} +``` + +The best answer we expect (some cases): +``` +{golden_answers} +``` + +Provide your analysis, optimization points, and the complete optimized prompt using the following XML format: + +<analyse>Analyze what drawbacks exist in the results produced by the reference prompt and how to improve them.</analyse> +<modification>Summarize the key points for improvement in one sentence</modification> +<prompt>Provide the complete optimized prompt {count}</prompt> +""" diff --git a/metagpt/ext/spo/settings/Navigate.yaml b/metagpt/ext/spo/settings/Navigate.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b20a6de973767bbc8eaed9a77cfd2c06ca6271e --- /dev/null +++ b/metagpt/ext/spo/settings/Navigate.yaml @@ -0,0 +1,47 @@ +prompt: | + Please think step by step. + Ensure the response concludes with the answer in the XML format: + <answer>[Yes or No]</answer>. + +requirements: | + Must put the final answer at the end with XML <answer></answer>. (<answer>(Yes or No)</answer>, such as <answer>Yes</answer>) + The provided prompt needs to adapt to all current types of questions. + +count: None + +qa: + - question: | + If you follow these instructions, do you return to the starting point? Always face forward. Take 7 steps left. Take 2 steps backward. Take 7 steps backward. Take 7 steps backward. Take 3 steps forward. + Options: + - Yes + - No + + answer: | + A lot of thinking and analysis processes. + ... + Final Answer: + <answer>(Yes or No)</answer> + + - question: | + If you follow these instructions, do you return to the starting point? Always face forward. Take 6 steps backward. Take 8 steps left. Take 3 steps right. Take 7 steps forward. Take 3 steps right. Take 9 steps right. Take 1 step backward. Take 7 steps left. + Options: + - Yes + - No + + answer: | + A lot of thinking and analysis processes. + ... + Final Answer: + <answer>(Yes or No)</answer> + + - question: | + If you follow these instructions, do you return to the starting point? Turn left. Turn left. Take 6 steps. Take 3 steps. Turn around. Take 1 step. Take 3 steps. Take 5 steps. 
+ Options: + - Yes + - No + + answer: | + A lot of thinking and analysis processes. + ... + Final Answer: + <answer>(Yes or No)</answer> diff --git a/metagpt/ext/spo/settings/Poem.yaml b/metagpt/ext/spo/settings/Poem.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dba690c45b998452a64bea30a92f11ae5f52eb01 --- /dev/null +++ b/metagpt/ext/spo/settings/Poem.yaml @@ -0,0 +1,23 @@ +prompt: | + Create poetry in the requested style and format. + +requirements: | + None + +count: None + +qa: + - question: | + Write a modern sonnet about climate change + answer: | + None + + - question: | + Create a haiku series about New York City + answer: | + None + + - question: | + Write a free verse poem about social media + answer: | + None diff --git a/metagpt/ext/spo/utils/__init__.py b/metagpt/ext/spo/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/spo/utils/data_utils.py b/metagpt/ext/spo/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..17771c0213d353ad0c07a71c9e4895645310dba5 --- /dev/null +++ b/metagpt/ext/spo/utils/data_utils.py @@ -0,0 +1,106 @@ +import datetime +import json +from pathlib import Path +from typing import Dict, List, Union + +import pandas as pd + +from metagpt.logs import logger + + +class DataUtils: + def __init__(self, root_path: Path): + self.root_path = root_path + self.top_scores = [] + + def load_results(self, path: Path) -> list: + result_path = self.get_results_file_path(path) + if result_path.exists(): + try: + return json.loads(result_path.read_text()) + except json.JSONDecodeError: + return [] + return [] + + def get_best_round(self): + self._load_scores() + + for entry in self.top_scores: + if entry["succeed"]: + return entry + + return None + + def get_results_file_path(self, prompt_path: Path) -> Path: + return prompt_path / "results.json" + + def create_result_data(self, round: int, answers: list[dict], prompt: str, succeed: bool, tokens: int) -> dict: + now = datetime.datetime.now() + return {"round": round, "answers": answers, "prompt": prompt, "succeed": succeed, "tokens": tokens, "time": now} + + def save_results(self, json_file_path: Path, data: Union[List, Dict]): + json_path = json_file_path + json_path.write_text(json.dumps(data, default=str, indent=4)) + + def _load_scores(self): + rounds_dir = self.root_path / "prompts" + result_file = rounds_dir / "results.json" + self.top_scores = [] + + try: + if not result_file.exists(): + logger.warning(f"Results file not found at {result_file}") + return self.top_scores + + data = json.loads(result_file.read_text(encoding="utf-8")) + df = pd.DataFrame(data) + + for index, row in df.iterrows(): + self.top_scores.append( + { + "round": row["round"], + "succeed": row["succeed"], + "prompt": row["prompt"], + "answers": row["answers"], + } + ) + + self.top_scores.sort(key=lambda x: x["round"], reverse=True) + + except FileNotFoundError: + logger.error(f"Could not find results file: {result_file}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON format in file: {result_file}") + except Exception as e: + logger.error(f"Unexpected error loading scores: {str(e)}") + + return self.top_scores + + def list_to_markdown(self, questions_list: list): + """ + Convert a list of question-answer dictionaries to a formatted Markdown string. 
+ + Args: + questions_list (list): List of dictionaries containing 'question' and 'answer' keys + + Returns: + str: Formatted Markdown string + """ + markdown_text = "```\n" + + for i, qa_pair in enumerate(questions_list, 1): + # Add question section + markdown_text += f"Question {i}\n\n" + markdown_text += f"{qa_pair['question']}\n\n" + + # Add answer section + markdown_text += f"Answer {i}\n\n" + markdown_text += f"{qa_pair['answer']}\n\n" + + # Add separator between QA pairs except for the last one + if i < len(questions_list): + markdown_text += "---\n\n" + + markdown_text += "\n```" + + return markdown_text diff --git a/metagpt/ext/spo/utils/evaluation_utils.py b/metagpt/ext/spo/utils/evaluation_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb026a21e323990276cf34ad592451d65e17254 --- /dev/null +++ b/metagpt/ext/spo/utils/evaluation_utils.py @@ -0,0 +1,81 @@ +import asyncio +from pathlib import Path +from typing import Any, List, Optional, Tuple + +import tiktoken + +from metagpt.ext.spo.components.evaluator import QuickEvaluate, QuickExecute +from metagpt.logs import logger + +EVALUATION_REPETITION = 4 + + +def count_tokens(sample: dict): + if not sample: + return 0 + else: + encoding = tiktoken.get_encoding("cl100k_base") + return len(encoding.encode(str(sample["answers"]))) + + +class EvaluationUtils: + def __init__(self, root_path: Path) -> None: + self.root_path = root_path + + async def execute_prompt(self, optimizer: Any, prompt_path: Path) -> dict: + optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path) + executor = QuickExecute(prompt=optimizer.prompt) + + answers = await executor.prompt_execute() + + cur_round = optimizer.round + + new_data = {"round": cur_round, "answers": answers, "prompt": optimizer.prompt} + + return new_data + + async def evaluate_prompt( + self, + optimizer: Any, + samples: Optional[dict], + new_samples: dict, + path: Path, + data: List[dict], + initial: bool = False, + ) -> Tuple[bool, dict]: + evaluator = QuickEvaluate() + new_token = count_tokens(new_samples) + + if initial is True: + succeed = True + else: + evaluation_results = [] + + evaluation_results.extend( + await asyncio.gather( + *( + evaluator.prompt_evaluate(samples=samples, new_samples=new_samples) + for _ in range(EVALUATION_REPETITION) + ) + ) + ) + + logger.info(f"Evaluation Results {evaluation_results}") + + true_count = evaluation_results.count(True) + false_count = evaluation_results.count(False) + succeed = true_count > false_count + + new_data = optimizer.data_utils.create_result_data( + new_samples["round"], new_samples["answers"], new_samples["prompt"], succeed, new_token + ) + + data.append(new_data) + + result_path = optimizer.data_utils.get_results_file_path(path) + + optimizer.data_utils.save_results(result_path, data) + + answers = new_samples["answers"] + + return succeed, answers diff --git a/metagpt/ext/spo/utils/llm_client.py b/metagpt/ext/spo/utils/llm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..81524d3c137f491cb4f54dfbd7f616209bcf12cd --- /dev/null +++ b/metagpt/ext/spo/utils/llm_client.py @@ -0,0 +1,107 @@ +import asyncio +import re +from enum import Enum +from typing import Any, List, Optional + +from metagpt.configs.models_config import ModelsConfig +from metagpt.llm import LLM +from metagpt.logs import logger + + +class RequestType(Enum): + OPTIMIZE = "optimize" + EVALUATE = "evaluate" + EXECUTE = "execute" + + +class SPO_LLM: + _instance: Optional["SPO_LLM"] = None + 
+ def __init__( + self, + optimize_kwargs: Optional[dict] = None, + evaluate_kwargs: Optional[dict] = None, + execute_kwargs: Optional[dict] = None, + ) -> None: + self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs)) + self.optimize_llm = LLM(llm_config=self._load_llm_config(optimize_kwargs)) + self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs)) + + def _load_llm_config(self, kwargs: dict) -> Any: + model = kwargs.get("model") + if not model: + raise ValueError("'model' parameter is required") + + try: + model_config = ModelsConfig.default().get(model) + if model_config is None: + raise ValueError(f"Model '{model}' not found in configuration") + + config = model_config.model_copy() + + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + + return config + + except AttributeError: + raise ValueError(f"Model '{model}' not found in configuration") + except Exception as e: + raise ValueError(f"Error loading configuration for model '{model}': {str(e)}") + + async def responser(self, request_type: RequestType, messages: List[dict]) -> str: + llm_mapping = { + RequestType.OPTIMIZE: self.optimize_llm, + RequestType.EVALUATE: self.evaluate_llm, + RequestType.EXECUTE: self.execute_llm, + } + + llm = llm_mapping.get(request_type) + if not llm: + raise ValueError(f"Invalid request type. Valid types: {', '.join([t.value for t in RequestType])}") + + response = await llm.acompletion(messages) + return response.choices[0].message.content + + @classmethod + def initialize(cls, optimize_kwargs: dict, evaluate_kwargs: dict, execute_kwargs: dict) -> None: + """Initialize the global instance""" + cls._instance = cls(optimize_kwargs, evaluate_kwargs, execute_kwargs) + + @classmethod + def get_instance(cls) -> "SPO_LLM": + """Get the global instance""" + if cls._instance is None: + raise RuntimeError("SPO_LLM not initialized. 
Call initialize() first.") + return cls._instance + + +def extract_content(xml_string: str, tag: str) -> Optional[str]: + pattern = rf"<{tag}>(.*?)" + match = re.search(pattern, xml_string, re.DOTALL) + return match.group(1).strip() if match else None + + +async def main(): + # test LLM + SPO_LLM.initialize( + optimize_kwargs={"model": "gpt-4o", "temperature": 0.7}, + evaluate_kwargs={"model": "gpt-4o-mini", "temperature": 0.3}, + execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3}, + ) + + llm = SPO_LLM.get_instance() + + # test messages + hello_msg = [{"role": "user", "content": "hello"}] + response = await llm.responser(request_type=RequestType.EXECUTE, messages=hello_msg) + logger(f"AI: {response}") + response = await llm.responser(request_type=RequestType.OPTIMIZE, messages=hello_msg) + logger(f"AI: {response}") + response = await llm.responser(request_type=RequestType.EVALUATE, messages=hello_msg) + logger(f"AI: {response}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/ext/spo/utils/load.py b/metagpt/ext/spo/utils/load.py new file mode 100644 index 0000000000000000000000000000000000000000..6333b2775e41cb4e9eb63d4901ee98c477adfda9 --- /dev/null +++ b/metagpt/ext/spo/utils/load.py @@ -0,0 +1,48 @@ +import random +from pathlib import Path + +import yaml + +FILE_NAME = "" +SAMPLE_K = 3 + + +def set_file_name(name: str): + global FILE_NAME + FILE_NAME = name + + +def load_meta_data(k: int = SAMPLE_K): + # load yaml file + config_path = Path(__file__).parent.parent / "settings" / FILE_NAME + + if not config_path.exists(): + raise FileNotFoundError(f"Configuration file '{FILE_NAME}' not found in settings directory") + + try: + with config_path.open("r", encoding="utf-8") as file: + data = yaml.safe_load(file) + except yaml.YAMLError as e: + raise ValueError(f"Error parsing YAML file '{FILE_NAME}': {str(e)}") + except Exception as e: + raise Exception(f"Error reading file '{FILE_NAME}': {str(e)}") + + qa = [] + + for item in data["qa"]: + question = item["question"] + answer = item["answer"] + qa.append({"question": question, "answer": answer}) + + prompt = data["prompt"] + requirements = data["requirements"] + count = data["count"] + + if isinstance(count, int): + count = f", within {count} words" + else: + count = "" + + random_qa = random.sample(qa, min(k, len(qa))) + + return prompt, requirements, random_qa, count diff --git a/metagpt/ext/spo/utils/prompt_utils.py b/metagpt/ext/spo/utils/prompt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c960bb70ba288f13a34aeb960a25719221f0bf --- /dev/null +++ b/metagpt/ext/spo/utils/prompt_utils.py @@ -0,0 +1,34 @@ +from pathlib import Path + +from metagpt.logs import logger + + +class PromptUtils: + def __init__(self, root_path: Path): + self.root_path = root_path + + def create_round_directory(self, prompt_path: Path, round_number: int) -> Path: + directory = prompt_path / f"round_{round_number}" + directory.mkdir(parents=True, exist_ok=True) + return directory + + def load_prompt(self, round_number: int, prompts_path: Path): + prompt_file = prompts_path / "prompt.txt" + + try: + return prompt_file.read_text(encoding="utf-8") + except FileNotFoundError as e: + logger.info(f"Error loading prompt for round {round_number}: {e}") + raise + + def write_answers(self, directory: Path, answers: dict, name: str = "answers.txt"): + answers_file = directory / name + with answers_file.open("w", encoding="utf-8") as file: + for item in answers: + 
file.write(f"Question:\n{item['question']}\n") + file.write(f"Answer:\n{item['answer']}\n") + file.write("\n") + + def write_prompt(self, directory: Path, prompt: str): + prompt_file = directory / "prompt.txt" + prompt_file.write_text(prompt, encoding="utf-8") diff --git a/metagpt/ext/stanford_town/.DS_Store b/metagpt/ext/stanford_town/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..1c9ea2ba03ceb6df169ad69ca5ac5311d49bc9a7 Binary files /dev/null and b/metagpt/ext/stanford_town/.DS_Store differ diff --git a/metagpt/ext/stanford_town/README.md b/metagpt/ext/stanford_town/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1bdcac145f047b51614645ab6f7fd7ce6292d5f7 --- /dev/null +++ b/metagpt/ext/stanford_town/README.md @@ -0,0 +1,51 @@ +## Stanford Town Game + +### Pre-Description +In order to facilitate GA( [generative_agents](https://github.com/joonspk-research/generative_agents) )'s frontend docking data (to avoid changing its code), you can set the value `temp_storage_path` to `temp_storage` of `generative_agents` when start `run_st_game.py`. like + +`python3 run_st_game.py --temp_storage_path path/to/ga/temp_storage xxx` + +Or change the path under `const.py` like beflow + +``` +STORAGE_PATH = EXAMPLE_PATH.joinpath("storage") +TEMP_STORAGE_PATH = EXAMPLE_PATH.joinpath("temp_storage") +# updated +STORAGE_PATH = Path("{path/to/ga/storage}") +TEMP_STORAGE_PATH = Path("{path/to/ga/temp_storage}") +``` + +This can be used to achieve docking of simulation data without changing the GA code. Otherwise, the GA code must be modified to adapt to the MG output path. + +If you don't want to start from 0, copy other simulation directories under `generative_agents/environment/frontend_server/storage/` to `examples/stanford_town/storage`, and select a directory named `fork_sim_code`. + +### Backend service startup +The execution entry is `python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10` +or +`python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10 --temp_storage_path path/to/ga/temp_storage` + +`idea` is the user's voice to the first Agent, and it is disseminated through this voice to see whether the final multi-agents achieve the goal of hosting or participating in the event. + +### Frontend service startup +Enter project folder `generative_agents` + +Enter `environment/frontend_server` and use `python3 manage.py runserver` to start the front-end service. +Visit `http://localhost:8000/simulator_home` to enter the current simulation interface. + +## Acknowledgements +The reproduction work has referred the [generative_agents](https://github.com/joonspk-research/generative_agents), let's make a general statement here. + +### Citation +```bib +@inproceedings{Park2023GenerativeAgents, +author = {Park, Joon Sung and O'Brien, Joseph C. and Cai, Carrie J. 
and Morris, Meredith Ringel and Liang, Percy and Bernstein, Michael S.}, +title = {Generative Agents: Interactive Simulacra of Human Behavior}, +year = {2023}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +booktitle = {In the 36th Annual ACM Symposium on User Interface Software and Technology (UIST '23)}, +keywords = {Human-AI interaction, agents, generative AI, large language models}, +location = {San Francisco, CA, USA}, +series = {UIST '23} +} +``` \ No newline at end of file diff --git a/metagpt/ext/stanford_town/README_CN.md b/metagpt/ext/stanford_town/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..3daf68d08f4494a1137cf3ff4a981c85ed41f4cd --- /dev/null +++ b/metagpt/ext/stanford_town/README_CN.md @@ -0,0 +1,50 @@ +## Stanford Town Game + +### Prerequisites +To make it easy for the GA ( [generative_agents](https://github.com/joonspk-research/generative_agents) ) frontend to consume the simulation data (avoiding changes to its code), pass `temp_storage_path` pointing to the `temp_storage` path of `generative_agents` when starting `run_st_game.py`. For example + +`python3 run_st_game.py --temp_storage_path path/to/ga/temp_storage xxx` + +Or update the following in `const.py` + +``` +STORAGE_PATH = EXAMPLE_PATH.joinpath("storage") +TEMP_STORAGE_PATH = EXAMPLE_PATH.joinpath("temp_storage") +# updated to +STORAGE_PATH = Path("{path/to/ga/storage}") +TEMP_STORAGE_PATH = Path("{path/to/ga/temp_storage}") +``` +This shares the simulation data with GA without changing the GA code; otherwise, the GA code has to be adapted to the MetaGPT output path. + +If you don't want to start from scratch, copy another simulation directory under `generative_agents/environment/frontend_server/storage/` to `examples/stanford_town/storage`, and pick one directory name as the `fork_sim_code`. + +### Backend service startup +The execution entry is: `python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10` +or +`python3 run_st_game.py "Host a open lunch party at 13:00 pm" "base_the_ville_isabella_maria_klaus" "test_sim" 10 --temp_storage_path path/to/ga/temp_storage` + +`idea` is the user's inner voice passed to the first agent; it spreads from there, and you can watch whether the multi-agent town eventually reaches the goal of hosting or attending the event. + +### Frontend service startup +Enter the `generative_agents` project directory. + +Enter `environment/frontend_server` and use `python3 manage.py runserver` to start the frontend service. +Visit `http://localhost:8000/simulator_home` to enter the current simulation interface. + +## Acknowledgements +This reproduction references [generative_agents](https://github.com/joonspk-research/generative_agents); many thanks to its authors. + +### Citation +```bib +@inproceedings{Park2023GenerativeAgents, +author = {Park, Joon Sung and O'Brien, Joseph C.
and Morris, Meredith Ringel and Liang, Percy and Bernstein, Michael S.}, +title = {Generative Agents: Interactive Simulacra of Human Behavior}, +year = {2023}, +publisher = {Association for Computing Machinery}, +address = {New York, NY, USA}, +booktitle = {In the 36th Annual ACM Symposium on User Interface Software and Technology (UIST '23)}, +keywords = {Human-AI interaction, agents, generative AI, large language models}, +location = {San Francisco, CA, USA}, +series = {UIST '23} +} +``` diff --git a/metagpt/ext/stanford_town/__init__.py b/metagpt/ext/stanford_town/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56ea35c9f719f30ad6e8b0accf7f4480cefc98bb --- /dev/null +++ b/metagpt/ext/stanford_town/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : stanford town implement diff --git a/metagpt/ext/stanford_town/actions/__init__.py b/metagpt/ext/stanford_town/actions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/stanford_town/actions/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/stanford_town/actions/agent_chat_sum_rel.py b/metagpt/ext/stanford_town/actions/agent_chat_sum_rel.py new file mode 100644 index 0000000000000000000000000000000000000000..98d370bb075b6933a7b53c964b253a2d311c2f97 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/agent_chat_sum_rel.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : summarize relationship in a agent chat + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class AgentChatSumRel(STAction): + name: str = "AgentChatSumRel" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + _ = llm_resp.split('"')[0].strip() + resp = True + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> str: + return llm_resp.split('"')[0].strip() + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, init_role: "STRole", target_role: "STRole", statements: str) -> str: + def create_prompt_input(init_role: "STRole", target_role: "STRole", statements: str) -> str: + prompt_input = [statements, init_role.name, target_role.name] + return prompt_input + + prompt_input = create_prompt_input(init_role, target_role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "summarize_chat_relationship_v2.txt") + + example_output = "Jane Doe is working on a project" + special_instruction = "The output should be a string that responds to the question." 
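A note on the validate/cleanup pattern that `AgentChatSumRel` and most other `STAction` subclasses follow: `_func_validate` simply dry-runs `_func_cleanup`, and cleanup keeps only the text before the first double quote. A minimal standalone sketch with a hypothetical response (not from the diff):

```python
# Minimal sketch of the STAction validate/cleanup contract used by AgentChatSumRel:
# validation just dry-runs the cleanup; cleanup keeps the text before the first '"'.
def cleanup(llm_resp: str) -> str:
    return llm_resp.split('"')[0].strip()

def validate(llm_resp: str) -> bool:
    try:
        cleanup(llm_resp)
        return True
    except Exception:
        return False

raw = 'Jane and Klaus are classmates" (extra tokens the model appended)'
assert validate(raw)
print(cleanup(raw))  # -> Jane and Klaus are classmates
```

Note that `str.split` never raises on string input, so this validation effectively always passes; the try/except mainly guards against non-string responses.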
+ output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {init_role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/decide_to_talk.py b/metagpt/ext/stanford_town/actions/decide_to_talk.py new file mode 100644 index 0000000000000000000000000000000000000000..a393f31af71495bc3a9ee07e0a6e9d6810ab2fe9 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/decide_to_talk.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : device to talk to another role, return yes or no + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class DecideToTalk(STAction): + name: str = "DecideToTalk" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + if llm_resp.split("Answer in yes or no:")[-1].strip().lower() in ["yes", "no"]: + resp = True + except ValueError: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> str: + return llm_resp.split("Answer in yes or no:")[-1].strip().lower() + + def _func_fail_default_resp(self) -> str: + return "yes" + + async def run(self, init_role: "STRole", target_role: "STRole", retrieved: dict, *args, **kwargs) -> bool: + """Run action""" + + def create_prompt_input(init_role: "STRole", target_role: "STRole", retrieved: dict) -> str: + scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + last_chat = init_role.rc.memory.get_last_chat(target_role.name) + last_chatted_time = "" + last_chat_about = "" + if last_chat: + last_chatted_time = last_chat.created.strftime("%B %d, %Y, %H:%M:%S") + last_chat_about = last_chat.description + + context = "" + for c_node in retrieved["events"]: + curr_desc = c_node.description.split(" ") + curr_desc[2:3] = ["was"] + curr_desc = " ".join(curr_desc) + context += f"{curr_desc}. " + context += "\n" + for c_node in retrieved["thoughts"]: + context += f"{c_node.description}. 
" + + curr_time = scratch.curr_time.strftime("%B %d, %Y, %H:%M:%S %p") + init_act_desc = scratch.act_description + if "(" in init_act_desc: + init_act_desc = init_act_desc.split("(")[-1][:-1] + + if len(scratch.planned_path) == 0 and "waiting" not in init_act_desc: + init_p_desc = f"{init_role.name} is already {init_act_desc}" + elif "waiting" in init_act_desc: + init_p_desc = f"{init_role.name} is {init_act_desc}" + else: + init_p_desc = f"{init_role.name} is on the way to {init_act_desc}" + + target_act_desc = scratch.act_description + if "(" in target_act_desc: + target_act_desc = target_act_desc.split("(")[-1][:-1] + + if len(target_scratch.planned_path) == 0 and "waiting" not in init_act_desc: + target_p_desc = f"{target_role.name} is already {target_act_desc}" + elif "waiting" in init_act_desc: + target_p_desc = f"{init_role.name} is {init_act_desc}" + else: + target_p_desc = f"{target_role.name} is on the way to {target_act_desc}" + + prompt_input = [] + prompt_input += [context] + + prompt_input += [curr_time] + + prompt_input += [init_role.name] + prompt_input += [target_role.name] + prompt_input += [last_chatted_time] + prompt_input += [last_chat_about] + + prompt_input += [init_p_desc] + prompt_input += [target_p_desc] + prompt_input += [init_role.name] + prompt_input += [target_role.name] + return prompt_input + + prompt_input = create_prompt_input(init_role, target_role, retrieved) + prompt = self.generate_prompt_with_tmpl_filename( + prompt_input=prompt_input, tmpl_filename="decide_to_talk_v2.txt" + ) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=20) # yes or no + result = True if output == "yes" else False + logger.info(f"Role: {init_role.name} Action: {self.cls_name} output: {result}") + return result diff --git a/metagpt/ext/stanford_town/actions/dummy_action.py b/metagpt/ext/stanford_town/actions/dummy_action.py new file mode 100644 index 0000000000000000000000000000000000000000..a5004d5ef36028e5761c270ee3c916ca9440ce3c --- /dev/null +++ b/metagpt/ext/stanford_town/actions/dummy_action.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : dummy action to make every STRole can deal DummyMessage which is caused by DummyAction + +from metagpt.actions import Action +from metagpt.schema import Message + + +class DummyAction(Action): + async def run(self, *args, **kwargs): + raise NotImplementedError + + +class DummyMessage(Message): + """ + dummy message to pass to role and make them to have a execution every round + """ + + content: str = "dummy" + cause_by: str = "DummyAction" diff --git a/metagpt/ext/stanford_town/actions/gen_action_details.py b/metagpt/ext/stanford_town/actions/gen_action_details.py new file mode 100644 index 0000000000000000000000000000000000000000..8e268a723a361217ed5e899d993526ded784d3af --- /dev/null +++ b/metagpt/ext/stanford_town/actions/gen_action_details.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : gen_action_details + +import random + +from metagpt.environment.stanford_town.env_space import EnvObsParams, EnvObsType +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class GenActionSector(STAction): + name: str = "GenActionSector" + + def _func_cleanup(self, llm_resp: str, prompt: str): + cleaned_response = llm_resp.split("}")[0] + return cleaned_response + + def _func_validate(self, llm_resp: str, prompt: str): + if len(llm_resp.strip()) < 1: + return False + if 
"}" not in llm_resp: + return False + if "," in llm_resp: + return False + return True + + def _func_fail_default_resp(self): + fs = "kitchen" + return fs + + async def run(self, role: "STRole", access_tile: dict[str, str], act_desp: str): + def create_prompt_input(role, access_tile: dict[str, str], act_desp): + act_world = f"{access_tile['world']}" + + prompt_input = [] + + prompt_input += [role.scratch.get_str_name()] + prompt_input += [role.scratch.living_area.split(":")[1]] + x = f"{act_world}:{role.scratch.living_area.split(':')[1]}" + prompt_input += [role.s_mem.get_str_accessible_sector_arenas(x)] + + prompt_input += [role.scratch.get_str_name()] + prompt_input += [f"{access_tile['sector']}"] + x = f"{act_world}:{access_tile['sector']}" + prompt_input += [role.s_mem.get_str_accessible_sector_arenas(x)] + + if role.scratch.get_str_daily_plan_req() != "": + prompt_input += [f"\n{role.scratch.get_str_daily_plan_req()}"] + else: + prompt_input += [""] + + # MAR 11 TEMP + prompt_input = [] + act_world = access_tile["world"] + accessible_sector_str = role.s_mem.get_str_accessible_sectors(act_world) + curr = accessible_sector_str.split(", ") + fin_accessible_sectors = [] + for i in curr: + if "'s house" in i: + if role.scratch.last_name in i: + fin_accessible_sectors += [i] + else: + fin_accessible_sectors += [i] + accessible_sector_str = ", ".join(fin_accessible_sectors) + # END MAR 11 TEMP + + prompt_input += [accessible_sector_str] + + act_desp_1 = act_desp + act_desp_2 = act_desp + if "(" in act_desp: + act_desp_1 = act_desp.split("(")[0].strip() + act_desp_2 = act_desp.split("(")[-1][:-1] + prompt_input += [role.scratch.get_str_name()] + prompt_input += [act_desp_1] + + prompt_input += [act_desp_2] + prompt_input += [role.scratch.get_str_name()] + return prompt_input + + prompt_template = "action_location_sector_v1.txt" + prompt_input = create_prompt_input(role, access_tile, act_desp) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=15) + y = f"{access_tile['world']}" + x = [i.strip() for i in role.s_mem.get_str_accessible_sectors(y).split(",")] + if output not in x: + # output = random.choice(x) + output = role.scratch.living_area.split(":")[1] + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenActionArena(STAction): + name: str = "GenActionArena" + + def _func_cleanup(self, llm_resp: str, prompt: str): + cleaned_response = llm_resp.split("}")[0] + return cleaned_response + + def _func_validate(self, llm_resp: str, prompt: str): + if len(llm_resp.strip()) < 1: + return False + if "}" not in llm_resp: + return False + if "," in llm_resp: + return False + return True + + def _func_fail_default_resp(self): + fs = "kitchen" + return fs + + async def run(self, role: "STRole", act_desp: str, act_world: str, act_sector: str): + def create_prompt_input(role, act_desp, act_world, act_sector): + prompt_input = [] + prompt_input += [role.scratch.get_str_name()] + x = f"{act_world}:{act_sector}" + prompt_input += [act_sector] + + # MAR 11 TEMP + accessible_arena_str = role.s_mem.get_str_accessible_sector_arenas(x) + curr = accessible_arena_str.split(", ") + fin_accessible_arenas = [] + for i in curr: + if "'s room" in i: + if role.scratch.last_name in i: + fin_accessible_arenas += [i] + else: + fin_accessible_arenas += [i] + accessible_arena_str = ", ".join(fin_accessible_arenas) + # END 
MAR 11 TEMP + prompt_input += [accessible_arena_str] + act_desp_1 = act_desp + act_desp_2 = act_desp + if "(" in act_desp: + act_desp_1 = act_desp.split("(")[0].strip() + act_desp_2 = act_desp.split("(")[-1][:-1] + prompt_input += [role.scratch.get_str_name()] + prompt_input += [act_desp_1] + + prompt_input += [act_desp_2] + prompt_input += [role.scratch.get_str_name()] + + prompt_input += [act_sector] + prompt_input += [accessible_arena_str] + return prompt_input + + prompt_template = "action_location_object_vMar11.txt" + prompt_input = create_prompt_input(role, act_desp, act_world, act_sector) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=15) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenActionObject(STAction): + name: str = "GenActionObject" + + def _func_validate(self, llm_resp: str, prompt: str): + if len(llm_resp.strip()) < 1: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str): + cleaned_response = llm_resp.strip() + return cleaned_response + + def _func_fail_default_resp(self): + fs = "bed" + return fs + + async def run(self, role: "STRole", act_desp: str, temp_address: str): + def create_prompt_input(role, act_desp, temp_address): + prompt_input = [] + if "(" in act_desp: + act_desp = act_desp.split("(")[-1][:-1] + + prompt_input += [act_desp] + prompt_input += [role.s_mem.get_str_accessible_arena_game_objects(temp_address)] + return prompt_input + + prompt_template = "action_object_v2.txt" + prompt_input = create_prompt_input(role, act_desp, temp_address) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=15) + x = [i.strip() for i in role.s_mem.get_str_accessible_arena_game_objects(temp_address).split(",")] + if output not in x: + output = random.choice(x) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenPronunciatio(STAction): + name: str = "GenPronunciatio" + + def _func_cleanup(self, llm_resp: str, prompt: str): + cr = llm_resp.strip() + if len(cr) > 3: + cr = cr[:3] + return cr + + def _func_validate(self, llm_resp: str, prompt: str): + try: + self._func_cleanup(llm_resp, prompt="") + if len(llm_resp) == 0: + return False + except Exception: + return False + return True + + def _func_fail_default_resp(self): + fs = "😋" + return fs + + async def run(self, role: "STRole", act_desp: str): + def create_prompt_input(act_desp): + if "(" in act_desp: + act_desp = act_desp.split("(")[-1].split(")")[0] + prompt_input = [act_desp] + return prompt_input + + prompt_template = "generate_pronunciatio_v1.txt" + prompt_input = create_prompt_input(act_desp) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + example_output = "🛁🧖‍♀️" + special_instruction = "The value for the output must ONLY contain the emojis." 
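One caveat about `GenPronunciatio` above: its cleanup caps the reply at three characters, and Python slicing counts Unicode code points rather than rendered emoji. A small illustration (the emoji are examples, not game assets):

```python
# Python slicing counts code points, not rendered emoji: a ZWJ sequence such as
# "woman in steamy room" is four code points, so cr[:3] can cut it mid-sequence.
resp = "\U0001f6c1\U0001f9d6\u200d\u2640\ufe0f"  # "🛁🧖‍♀️": two rendered emoji
print(len(resp))   # 5 code points
print(resp[:3])    # '🛁🧖\u200d' -- truncated inside the ZWJ sequence
```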
+ self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenEventTriple(STAction): + name: str = "GenEventTriple" + + def _func_cleanup(self, llm_resp: str, prompt: str): + cr = llm_resp.strip() + cr = [i.strip() for i in cr.split(")")[0].split(",")] + return cr + + def _func_validate(self, llm_resp: str, prompt: str): + try: + llm_resp = self._func_cleanup(llm_resp, prompt="") + if len(llm_resp) != 2: + return False + except Exception: + return False + return True + + def _func_fail_default_resp(self, role): + fs = (role.name, "is", "idle") + return fs + + async def run(self, role: "STRole", act_desp: str): + def create_prompt_input(role, act_desp): + if "(" in act_desp: + act_desp = act_desp.split("(")[-1].split(")")[0] + prompt_input = [role.name, act_desp, role.name] + return prompt_input + + prompt_template = "generate_event_triple_v1.txt" + prompt_input = create_prompt_input(role, act_desp) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp(role) + output = await self._run_gpt35_max_tokens(prompt, max_tokens=30) + output = (role.name, output[0], output[1]) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenActObjDescription(STAction): + name: str = "GenActObjDescription" + + def _func_cleanup(self, llm_resp: str, prompt: str): + cr = llm_resp.strip() + if cr[-1] == ".": + cr = cr[:-1] + return cr + + def _func_validate(self, llm_resp: str, prompt: str): + try: + llm_resp = self._func_cleanup(llm_resp, prompt="") + except Exception: + return False + return True + + def _func_fail_default_resp(self, act_game_object): + fs = f"{act_game_object} is idle" + return fs + + async def run(self, role: "STRole", act_game_object: str, act_desp: str): + def create_prompt_input(act_game_object, act_desp, role): + prompt_input = [act_game_object, role.name, act_desp, act_game_object, act_game_object] + return prompt_input + + prompt_template = "generate_obj_event_v1.txt" + prompt_input = create_prompt_input(act_game_object, act_desp, role) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + example_output = "being fixed" + special_instruction = "The output should ONLY contain the phrase that should go in <fill in>."
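The triple cleanup shared by `GenEventTriple` above (and `GenObjEventTriple` below) expects the model to complete a pre-opened tuple, e.g. `is, writing a research paper)`. A standalone sketch with a hypothetical response:

```python
# Sketch of the (subject, predicate, object) assembly: cleanup keeps the first
# ")"-terminated group and validation requires exactly two comma-separated fields.
def cleanup_triple(llm_resp: str) -> list[str]:
    return [part.strip() for part in llm_resp.strip().split(")")[0].split(",")]

fields = cleanup_triple("is, writing a research paper) trailing text")
assert len(fields) == 2
print(("Klaus Mueller", fields[0], fields[1]))
# ('Klaus Mueller', 'is', 'writing a research paper')
```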
+ self.fail_default_resp = self._func_fail_default_resp(act_game_object) + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenObjEventTriple(STAction): + name: str = "GenObjEventTriple" + + def _func_cleanup(self, llm_resp: str, prompt: str): + cr = llm_resp.strip() + cr = [i.strip() for i in cr.split(")")[0].split(",")] + return cr + + def _func_validate(self, llm_resp: str, prompt: str): + try: + llm_resp = self._func_cleanup(llm_resp, prompt="") + if len(llm_resp) != 2: + return False + except Exception: + return False + return True + + def _func_fail_default_resp(self, act_game_object: str): + fs = (act_game_object, "is", "idle") + return fs + + async def run(self, role: "STRole", act_game_object, act_obj_desp): + def create_prompt_input(act_game_object, act_obj_desp): + prompt_input = [act_game_object, act_obj_desp, act_game_object] + return prompt_input + + prompt_template = "generate_event_triple_v1.txt" + prompt_input = create_prompt_input(act_game_object, act_obj_desp) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp(act_game_object) + output = await self._run_gpt35_max_tokens(prompt, max_tokens=30) + output = (act_game_object, output[0], output[1]) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +class GenActionDetails(STAction): + name: str = "GenActionDetails" + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + pass + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + # TODO -- this sometimes generates error + try: + self._func_cleanup(llm_resp) + except Exception: + return False + return True + + def _func_fail_default_resp(self): + fs = {} + return fs + + async def run(self, role: "STRole", act_desp: str, act_dura): + access_tile = role.rc.env.observe( + obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=role.scratch.curr_tile) + ) + act_world = access_tile["world"] + act_sector = await GenActionSector().run(role, access_tile, act_desp) + act_arena = await GenActionArena().run(role, act_desp, act_world, act_sector) + act_address = f"{act_world}:{act_sector}:{act_arena}" + if not role.s_mem.get_str_accessible_arena_game_objects(act_address): + act_game_object = "" + else: + act_game_object = await GenActionObject().run(role, act_desp, act_address) + new_address = f"{act_world}:{act_sector}:{act_arena}:{act_game_object}" + act_pron = await GenPronunciatio().run(role, act_desp) + act_event = await GenEventTriple().run(role, act_desp) + # Persona's actions also influence the object states. We set those up here. 
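For orientation, the sub-actions above feed `GenActionDetails`, which composes their outputs into a colon-separated address. A toy illustration (the place names are illustrative stand-ins, not read from the game maps):

```python
# Hypothetical values standing in for GenActionSector/GenActionArena/GenActionObject
# outputs; GenActionDetails joins them into the action_address field.
act_world, act_sector, act_arena = "the Ville", "Hobbs Cafe", "cafe kitchen"
act_game_object = "stove"
act_address = f"{act_world}:{act_sector}:{act_arena}"
new_address = f"{act_address}:{act_game_object}"
print(new_address)  # the Ville:Hobbs Cafe:cafe kitchen:stove
```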
+ act_obj_desp = await GenActObjDescription().run(role, act_game_object, act_desp) + act_obj_pron = await GenPronunciatio().run(role, act_obj_desp) + act_obj_event = await GenObjEventTriple().run(role, act_game_object, act_obj_desp) + result_dict = { + "action_address": new_address, + "action_duration": int(act_dura), + "action_description": act_desp, + "action_pronunciatio": act_pron, + "action_event": act_event, + "chatting_with": None, + "chat": None, + "chatting_with_buffer": None, + "chatting_end_time": None, + "act_obj_description": act_obj_desp, + "act_obj_pronunciatio": act_obj_pron, + "act_obj_event": act_obj_event, + } + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {result_dict}") + return result_dict diff --git a/metagpt/ext/stanford_town/actions/gen_daily_schedule.py b/metagpt/ext/stanford_town/actions/gen_daily_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..5dffa8995260467c76f9e9810eefd748f960d334 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/gen_daily_schedule.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : gen_daily_schedule + + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class GenDailySchedule(STAction): + name: str = "GenDailySchedule" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt="") + except Exception: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + cr = [] + _cr = llm_resp.split(")") + for i in _cr: + if i[-1].isdigit(): + i = i[:-1].strip() + if i[-1] == "." or i[-1] == ",": + cr += [i[:-1].strip()] + return cr + + def _func_fail_default_resp(self) -> int: + fs = [ + "wake up and complete the morning routine at 6:00 am", + "eat breakfast at 7:00 am", + "read a book from 8:00 am to 12:00 pm", + "have lunch at 12:00 pm", + "take a nap from 1:00 pm to 4:00 pm", + "relax and watch TV from 7:00 pm to 8:00 pm", + "go to bed at 11:00 pm", + ] + return fs + + async def run(self, role: "STRole", wake_up_hour: str): + def create_prompt_input(role, wake_up_hour): + prompt_input = [] + prompt_input += [role.scratch.get_str_iss()] + prompt_input += [role.scratch.get_str_lifestyle()] + prompt_input += [role.scratch.get_str_curr_date_str()] + prompt_input += [role.scratch.get_str_firstname()] + prompt_input += [f"{str(wake_up_hour)}:00 am"] + return prompt_input + + wake_up_hour = int(wake_up_hour) + prompt_template = "daily_planning_v6.txt" + prompt_input = create_prompt_input(role, wake_up_hour) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=500) + output = [f"wake up and complete the morning routine at {wake_up_hour}:00 am"] + output + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/gen_hourly_schedule.py b/metagpt/ext/stanford_town/actions/gen_hourly_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..5d59f96ddaa81f918d324933df70ebabcf6fb634 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/gen_hourly_schedule.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : gen_hourly_schedule + +import random +import string + +from metagpt.logs import logger + +from .st_action import STAction + + +def get_random_alphanumeric(i=6, 
j=6): + """ + Returns a random alpha numeric strength that has the length of somewhere + between i and j. + + INPUT: + i: min_range for the length + j: max_range for the length + OUTPUT: + an alpha numeric str with the length of somewhere between i and j. + """ + k = random.randint(i, j) + x = "".join(random.choices(string.ascii_letters + string.digits, k=k)) + return x + + +class GenHourlySchedule(STAction): + name: str = "GenHourlySchedule" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt="") + except Exception: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + cr = llm_resp.strip() + if cr[-1] == ".": + cr = cr[:-1] + # to only use the first line of output + cr = cr.split("\n")[0] + return cr + + def _func_fail_default_resp(self) -> int: + fs = "asleep" + return fs + + async def _generate_schedule_for_given_hour( + self, role: "STRole", curr_hour_str, p_f_ds_hourly_org, hour_str, intermission2=None + ): + def create_prompt_input(persona, curr_hour_str, p_f_ds_hourly_org, hour_str, intermission2=None): + schedule_format = "" + for i in hour_str: + schedule_format += f"[{persona.scratch.get_str_curr_date_str()} -- {i}]" + schedule_format += " Activity: [Fill in]\n" + schedule_format = schedule_format[:-1] + + intermission_str = "Here the originally intended hourly breakdown of" + intermission_str += f" {persona.scratch.get_str_firstname()}'s schedule today: " + for count, i in enumerate(persona.scratch.daily_req): + intermission_str += f"{str(count + 1)}) {i}, " + intermission_str = intermission_str[:-2] + + prior_schedule = "" + if p_f_ds_hourly_org: + prior_schedule = "\n" + for count, i in enumerate(p_f_ds_hourly_org): + prior_schedule += f"[(ID:{get_random_alphanumeric()})" + prior_schedule += f" {persona.scratch.get_str_curr_date_str()} --" + prior_schedule += f" {hour_str[count]}] Activity:" + prior_schedule += f" {persona.scratch.get_str_firstname()}" + prior_schedule += f" is {i}\n" + + prompt_ending = f"[(ID:{get_random_alphanumeric()})" + prompt_ending += f" {persona.scratch.get_str_curr_date_str()}" + prompt_ending += f" -- {curr_hour_str}] Activity:" + prompt_ending += f" {persona.scratch.get_str_firstname()} is" + + if intermission2: + intermission2 = f"\n{intermission2}" + + prompt_input = [] + prompt_input += [schedule_format] + prompt_input += [persona.scratch.get_str_iss()] + + prompt_input += [prior_schedule + "\n"] + prompt_input += [intermission_str] + if intermission2: + prompt_input += [intermission2] + else: + prompt_input += [""] + prompt_input += [prompt_ending] + + return prompt_input + + prompt_template = "generate_hourly_schedule_v2.txt" + prompt_input = create_prompt_input(role, curr_hour_str, p_f_ds_hourly_org, hour_str, intermission2) + prompt_input_str = "\n".join(prompt_input) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, prompt_template) + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=50) + logger.info( + f"Role: {role.name} _generate_schedule_for_given_hour prompt_input: {prompt_input_str}, " + f"output: {output}" + ) + return output + + async def run(self, role: "STRole", wake_up_hour: int): + hour_str = [ + "00:00 AM", + "01:00 AM", + "02:00 AM", + "03:00 AM", + "04:00 AM", + "05:00 AM", + "06:00 AM", + "07:00 AM", + "08:00 AM", + "09:00 AM", + "10:00 AM", + "11:00 AM", + "12:00 PM", + "01:00 PM", + "02:00 PM", + "03:00 PM", + "04:00 PM", + "05:00 PM", + 
"06:00 PM", + "07:00 PM", + "08:00 PM", + "09:00 PM", + "10:00 PM", + "11:00 PM", + ] + n_m1_activity = [] + diversity_repeat_count = 1 # TODO mg 1->3 + for i in range(diversity_repeat_count): + logger.info(f"diversity_repeat_count idx: {i}") + n_m1_activity_set = set(n_m1_activity) + if len(n_m1_activity_set) < 5: + n_m1_activity = [] + for count, curr_hour_str in enumerate(hour_str): + if wake_up_hour > 0: + n_m1_activity += ["sleeping"] + wake_up_hour -= 1 + else: + logger.info(f"_generate_schedule_for_given_hour idx: {count}, n_m1_activity: {n_m1_activity}") + n_m1_activity += [ + await self._generate_schedule_for_given_hour(role, curr_hour_str, n_m1_activity, hour_str) + ] + + # Step 1. Compressing the hourly schedule to the following format: + # The integer indicates the number of hours. They should add up to 24. + # [['sleeping', 6], ['waking up and starting her morning routine', 1], + # ['eating breakfast', 1], ['getting ready for the day', 1], + # ['working on her painting', 2], ['taking a break', 1], + # ['having lunch', 1], ['working on her painting', 3], + # ['taking a break', 2], ['working on her painting', 2], + # ['relaxing and watching TV', 1], ['going to bed', 1], ['sleeping', 2]] + _n_m1_hourly_compressed = [] + prev = None + prev_count = 0 + for i in n_m1_activity: + if i != prev: + prev_count = 1 + _n_m1_hourly_compressed += [[i, prev_count]] + prev = i + elif _n_m1_hourly_compressed: + _n_m1_hourly_compressed[-1][1] += 1 + + # Step 2. Expand to min scale (from hour scale) + # [['sleeping', 360], ['waking up and starting her morning routine', 60], + # ['eating breakfast', 60],.. + n_m1_hourly_compressed = [] + for task, duration in _n_m1_hourly_compressed: + n_m1_hourly_compressed += [[task, duration * 60]] + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {n_m1_hourly_compressed}") + return n_m1_hourly_compressed diff --git a/metagpt/ext/stanford_town/actions/gen_iter_chat_utt.py b/metagpt/ext/stanford_town/actions/gen_iter_chat_utt.py new file mode 100644 index 0000000000000000000000000000000000000000..40f6d3af0ed87d5a030a8e7594b297014814bc54 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/gen_iter_chat_utt.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : generate_iterative_chat_utt + +from metagpt.environment.stanford_town.env_space import EnvObsParams, EnvObsType +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.ext.stanford_town.utils.utils import extract_first_json_dict +from metagpt.logs import logger + + +class GenIterChatUTT(STAction): + name: str = "GenIterChatUTT" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + _ = extract_first_json_dict(llm_resp) + resp = True + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> dict: + gpt_response = extract_first_json_dict(llm_resp) + + cleaned_dict = dict() + cleaned = [] + for key, val in gpt_response.items(): + cleaned += [val] + cleaned_dict["utterance"] = cleaned[0] + cleaned_dict["end"] = True + if "f" in str(cleaned[1]) or "F" in str(cleaned[1]): + cleaned_dict["end"] = False + + return cleaned_dict + + def _func_fail_default_resp(self) -> dict: + cleaned_dict = dict() + cleaned_dict["utterance"] = "..." 
+ cleaned_dict["end"] = False + return cleaned_dict + + async def run( + self, + init_role: "STRole", + target_role: "STRole", + retrieved: dict, + curr_context: str, + curr_chat: list[str], + *args, + **kwargs, + ) -> dict: + def create_prompt_input( + access_tile: dict[str, str], + init_role: "STRole", + target_role: "STRole", + retrieved: dict, + curr_context: str, + curr_chat: list[str], + ): + role = init_role + scratch = role.rc.scratch + target_scratch = target_role.rc.scratch + prev_convo_insert = "\n" + if role.rc.memory.chat_list: + for i in role.rc.memory.chat_list: + if i.object == target_role.name: + v1 = int((scratch.curr_time - i.created).total_seconds() / 60) + prev_convo_insert += ( + f"{str(v1)} minutes ago, {scratch.name} and " + f"{target_scratch.name} were already {i.description} " + f"This context takes place after that conversation." + ) + break + if prev_convo_insert == "\n": + prev_convo_insert = "" + if role.rc.memory.chat_list: + if int((scratch.curr_time - role.rc.memory.chat_list[-1].created).total_seconds() / 60) > 480: + prev_convo_insert = "" + logger.info(f"prev_convo_insert: {prev_convo_insert}") + + curr_sector = f"{access_tile['sector']}" + curr_arena = f"{access_tile['arena']}" + curr_location = f"{curr_arena} in {curr_sector}" + + retrieved_str = "" + for key, vals in retrieved.items(): + for v in vals: + retrieved_str += f"- {v.description}\n" + + convo_str = "" + for i in curr_chat: + convo_str += ": ".join(i) + "\n" + if convo_str == "": + convo_str = "[The conversation has not started yet -- start it!]" + + init_iss = f"Here is Here is a brief description of {scratch.name}.\n{scratch.get_str_iss()}" + prompt_input = [ + init_iss, + scratch.name, + retrieved_str, + prev_convo_insert, + curr_location, + curr_context, + scratch.name, + target_scratch.name, + convo_str, + scratch.name, + target_scratch.name, + scratch.name, + scratch.name, + scratch.name, + ] + return prompt_input + + access_tile = init_role.rc.env.observe( + obs_params=EnvObsParams(obs_type=EnvObsType.GET_TITLE, coord=init_role.scratch.curr_tile) + ) + prompt_input = create_prompt_input(access_tile, init_role, target_role, retrieved, curr_context, curr_chat) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "iterative_convo_v1.txt") + # original using `ChatGPT_safe_generate_response_OLD` + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_wo_extra_prompt(prompt) + logger.info(f"Role: {init_role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/inner_voice_action.py b/metagpt/ext/stanford_town/actions/inner_voice_action.py new file mode 100644 index 0000000000000000000000000000000000000000..83cfa037ba8de69309a1f9438509b2a5ec8de8b6 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/inner_voice_action.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class AgentWhisperThoughtAction(STAction): + name: str = "AgentWhisperThoughtAction" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> list: + return llm_resp.split('"')[0].strip() + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) 
-> str: + def create_prompt_input(role: "STRole", statements, test_input=None): + prompt_input = [role.scratch.name, statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "whisper_inner_thought_v1.txt") + + output = await self._run_gpt35_max_tokens(prompt, max_tokens=50) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/new_decomp_schedule.py b/metagpt/ext/stanford_town/actions/new_decomp_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..759ec170f464622a304aac85269efe043568d16d --- /dev/null +++ b/metagpt/ext/stanford_town/actions/new_decomp_schedule.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : new_decomp_schedule + +import datetime + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class NewDecompSchedule(STAction): + name: str = "NewDecompSchedule" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + llm_resp = self._func_cleanup(llm_resp, prompt) + dur_sum = 0 + for act, dur in llm_resp: + dur_sum += dur + if isinstance(act, str): + return False + if isinstance(dur, int): + return False + x = prompt.split("\n")[0].split("originally planned schedule from")[-1].strip()[:-1] + x = [datetime.datetime.strptime(i.strip(), "%H:%M %p") for i in x.split(" to ")] + delta_min = int((x[1] - x[0]).total_seconds() / 60) + + if int(dur_sum) != int(delta_min): + return False + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + new_schedule = prompt + " " + llm_resp.strip() + new_schedule = new_schedule.split("The revised schedule:")[-1].strip() + new_schedule = new_schedule.split("\n") + + ret_temp = [] + for i in new_schedule: + ret_temp += [i.split(" -- ")] + + ret = [] + for time_str, action in ret_temp: + start_time = time_str.split(" ~ ")[0].strip() + end_time = time_str.split(" ~ ")[1].strip() + delta = datetime.datetime.strptime(end_time, "%H:%M") - datetime.datetime.strptime(start_time, "%H:%M") + delta_min = int(delta.total_seconds() / 60) + if delta_min < 0: + delta_min = 0 + ret += [[action, delta_min]] + + return ret + + def _func_fail_default_resp(self, main_act_dur: int, truncated_act_dur: int) -> int: + dur_sum = 0 + for act, dur in main_act_dur: + dur_sum += dur + + ret = truncated_act_dur[:] + ret += main_act_dur[len(ret) - 1 :] + + # If there are access, we need to trim... 
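The fallback being built here must keep the schedule's total minutes unchanged: it splices the already-elapsed (truncated) activities onto the tail of the original plan and trims any overshoot. A hedged sketch of that invariant, with simplified trim logic rather than the exact code:

```python
# Hedged sketch of the duration-conserving fallback in NewDecompSchedule: splice the
# truncated (already-performed) entries onto the remainder of the original plan,
# then trim so the total minutes still equal the original schedule's total.
main_act_dur = [["working on her painting", 60], ["taking a break", 60]]
truncated_act_dur = [["working on her painting", 50], ["conversing with Klaus", 30]]

ret = truncated_act_dur[:] + main_act_dur[len(truncated_act_dur) - 1:]
target = sum(dur for _, dur in main_act_dur)  # 120 minutes must be preserved
out, total = [], 0
for act, dur in ret:
    dur = min(dur, target - total)  # trim the entry that overshoots the target
    if dur <= 0:
        break
    out.append([act, dur])
    total += dur
print(out, total)
# [['working on her painting', 50], ['conversing with Klaus', 30],
#  ['taking a break', 40]] 120
```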
+ ret_dur_sum = 0 + count = 0 + over = None + for act, dur in ret: + ret_dur_sum += dur + if ret_dur_sum == dur_sum: + break + if ret_dur_sum > dur_sum: + over = ret_dur_sum - dur_sum + break + count += 1 + + if over: + ret = ret[: count + 1] + ret[-1][1] -= over + + return ret + + async def run( + self, + role: "STRole", + main_act_dur: int, + truncated_act_dur: int, + start_time_hour: datetime, + end_time_hour: datetime, + inserted_act: str, + inserted_act_dur: int, + *args, + **kwargs, + ): + def create_prompt_input( + role: "STRole", + main_act_dur: int, + truncated_act_dur: int, + start_time_hour: datetime, + end_time_hour: datetime, + inserted_act: str, + inserted_act_dur: int, + ): + persona_name = role.name + start_hour_str = start_time_hour.strftime("%H:%M %p") + end_hour_str = end_time_hour.strftime("%H:%M %p") + + original_plan = "" + for_time = start_time_hour + for i in main_act_dur: + original_plan += ( + f'{for_time.strftime("%H:%M")} ~ ' + f'{(for_time + datetime.timedelta(minutes=int(i[1]))).strftime("%H:%M")} -- ' + i[0] + ) + original_plan += "\n" + for_time += datetime.timedelta(minutes=int(i[1])) + + new_plan_init = "" + for_time = start_time_hour + for count, i in enumerate(truncated_act_dur): + new_plan_init += ( + f'{for_time.strftime("%H:%M")} ~ ' + f'{(for_time + datetime.timedelta(minutes=int(i[1]))).strftime("%H:%M")} -- ' + i[0] + ) + new_plan_init += "\n" + if count < len(truncated_act_dur) - 1: + for_time += datetime.timedelta(minutes=int(i[1])) + + new_plan_init += (for_time + datetime.timedelta(minutes=int(i[1]))).strftime("%H:%M") + " ~" + + prompt_input = [ + persona_name, + start_hour_str, + end_hour_str, + original_plan, + persona_name, + inserted_act, + inserted_act_dur, + persona_name, + start_hour_str, + end_hour_str, + end_hour_str, + new_plan_init, + ] + return prompt_input + + prompt_input = create_prompt_input( + role, main_act_dur, truncated_act_dur, start_time_hour, end_time_hour, inserted_act, inserted_act_dur + ) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "new_decomp_schedule_v1.txt") + self.fail_default_resp = self._func_fail_default_resp(main_act_dur, truncated_act_dur) + output = await self._run_gpt35_max_tokens(prompt, max_tokens=1000) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/run_reflect_action.py b/metagpt/ext/stanford_town/actions/run_reflect_action.py new file mode 100644 index 0000000000000000000000000000000000000000..895f6828f03d629b08f8b8f2207250ce1940c736 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/run_reflect_action.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Integration Reflect Action + +import re + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +# Run GPT Prompt Focal Point method +class AgentFocusPt(STAction): + name: str = "AgentFocusPt" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + try: + """ + Cleanup handling has been completed for run_v2 + """ + return llm_resp + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, n: int, test_input=None) -> str: + def create_prompt_input(role: "STRole", 
statements, n, test_input=None): + prompt_input = [statements, str(n)] + return prompt_input + + prompt_input = create_prompt_input(role, statements, n) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "generate_focal_pt_v1.txt") + + example_output = '["What should Jane do for lunch", "Does Jane like strawberry", "Who is Jane"]' + special_instruction = "Output must be a list of str." + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Insight and Guidance +class AgentInsightAndGuidance(STAction): + name: str = "AgentInsightAndGuidance" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> dict: + try: + llm_resp = "1. " + llm_resp.strip() + ret = dict() + for i in llm_resp.split("\n"): + row = " ".join(i.split(". ")[1:]) + if "(because of " not in row: + continue + thought = row.split("(because of ")[0].strip() + if ")" not in row.split("(because of ")[1]: + continue + evi_raw = row.split("(because of ")[1].split(")")[0].strip() + evi_raw = re.findall(r"\d+", evi_raw) + evi_raw = [int(i.strip()) for i in evi_raw] + ret[thought] = evi_raw + return ret + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self, n: int) -> str: + return ["I am hungry"] * n + + async def run(self, role: "STRole", statements: str, n: int, test_input=None) -> dict: + def create_prompt_input(role, statements, n, test_input=None): + prompt_input = [statements, str(n)] + return prompt_input + + prompt_input = create_prompt_input(role, statements, n) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "insight_and_evidence_v1.txt") + + self.fail_default_resp = self._func_fail_default_resp(n) + output = await self._run_gpt35_max_tokens(prompt, max_tokens=150) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Event Triple +class AgentEventTriple(STAction): + name: str = "AgentEventTriple" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + llm_resp = self._func_cleanup(llm_resp, prompt="") + if len(llm_resp) != 2: + return False + except Exception: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> list: + try: + cr = llm_resp.strip() + cr = [i.strip() for i in cr.split(")")[0].split(",")] + if len(cr) != 2: + return cr[-2:] + return cr + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, statements: str, role: "STRole", verbose=False) -> tuple: + def create_prompt_input(statements, role): + if "(" in statements: + statements = statements.split("(")[-1].split(")")[0] + prompt_input = [role.scratch.name, statements, role.scratch.name] + return prompt_input + + prompt_input = create_prompt_input(statements, role) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "generate_event_triple_v1.txt") + + output = await self._run_gpt35_max_tokens(prompt, max_tokens=30) + output = (role.scratch.name, output[0], output[1]) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Event Poignancy +class AgentEventPoignancy(STAction): + name: 
str = "AgentEventPoignancy" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> int: + try: + llm_resp = int(llm_resp.strip()) + return llm_resp + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str: + def create_prompt_input(role: "STRole", statements: str, test_input=None): + prompt_input = [role.scratch.name, role.scratch.get_str_iss(), role.scratch.name, statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "poignancy_event_v1.txt") + + example_output = "5" # ######## + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Chat Poignancy +class AgentChatPoignancy(STAction): + name: str = "AgentChatPoignancy" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> int: + try: + llm_resp = int(llm_resp.strip()) + return llm_resp + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str: + def create_prompt_input(role: "STRole", statements, test_input=None): + prompt_input = [role.scratch.name, role.scratch.get_str_iss(), role.scratch.name, statements] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "poignancy_chat_v1.txt") + + example_output = "5" # ######## + special_instruction = "The output should ONLY contain ONE integer value on the scale of 1 to 10." 
+ output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Planning Thought on Convo +class AgentPlanThoughtOnConvo(STAction): + name: str = "AgentPlanThoughtOnConvo" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + try: + return llm_resp.split('"')[0].strip() + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str: + def create_prompt_input(role, statements, test_input=None): + prompt_input = [statements, role.scratch.name, role.scratch.name, role.scratch.name] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "planning_thought_on_convo_v1.txt") + + output = await self._run_gpt35_max_tokens(prompt, max_tokens=50) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output + + +# Run GPT Prompt Memory on Convo +class AgentMemoryOnConvo(STAction): + name: str = "AgentMemoryOnConvo" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + try: + self._func_cleanup(llm_resp, prompt) + return True + except Exception: + return False + + def _func_cleanup(self, llm_resp: str, prompt: str = "") -> str: + try: + return llm_resp.split('"')[0].strip() + except Exception as exp: + logger.error(f"{self.cls_name} with error {exp}") + + def _func_fail_default_resp(self) -> str: + pass + + async def run(self, role: "STRole", statements: str, test_input=None, verbose=False) -> str: + def create_prompt_input(role, statements, test_input=None): + prompt_input = [statements, role.scratch.name, role.scratch.name, role.scratch.name] + return prompt_input + + prompt_input = create_prompt_input(role, statements) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "memo_on_convo_v1.txt") + example_output = "Jane Doe was interesting to talk to." 
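Both poignancy actions above expect a bare integer on a 1-10 scale, and validation simply tries the `int` conversion. A sketch of that handling; the explicit range check is our addition for illustration, since the diff's cleanup does not range-check:

```python
# Sketch of the poignancy reply handling: int() failure -> validation fails and the
# action falls back to its default; the range check is not in the original cleanup.
def parse_poignancy(llm_resp: str) -> int:
    score = int(llm_resp.strip())
    if not 1 <= score <= 10:  # added safeguard, not in the diff
        raise ValueError(f"poignancy score out of range: {score}")
    return score

print(parse_poignancy(" 5 "))  # 5
```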
+ special_instruction = ( + "The output should ONLY contain a string that summarizes anything interesting " + "that the agent may have noticed" + ) + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/st_action.py b/metagpt/ext/stanford_town/actions/st_action.py new file mode 100644 index 0000000000000000000000000000000000000000..48cda353cc804e13555af9b783a244cecce2268a --- /dev/null +++ b/metagpt/ext/stanford_town/actions/st_action.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : StanfordTown Action +import json +import time +from abc import abstractmethod +from pathlib import Path +from typing import Any, Optional, Union + +from metagpt.actions.action import Action +from metagpt.ext.stanford_town.utils.const import PROMPTS_DIR +from metagpt.logs import logger + + +class STAction(Action): + name: str = "STAction" + prompt_dir: Path = PROMPTS_DIR + fail_default_resp: Optional[str] = None + + @property + def cls_name(self): + return self.__class__.__name__ + + @abstractmethod + def _func_validate(self, llm_resp: str, prompt: str): + raise NotImplementedError + + @abstractmethod + def _func_cleanup(self, llm_resp: str, prompt: str): + raise NotImplementedError + + @abstractmethod + def _func_fail_default_resp(self): + raise NotImplementedError + + def generate_prompt_with_tmpl_filename(self, prompt_input: Union[str, list], tmpl_filename) -> str: + """ + same with `generate_prompt` + Args: + prompt_input: the input we want to feed in (IF THERE ARE MORE THAN ONE INPUT, THIS CAN BE A LIST.) + tmpl_filename: prompt template filename + Returns: + a str prompt that will be sent to LLM server. + """ + if isinstance(prompt_input, str): + prompt_input = [prompt_input] + prompt_input = [str(i) for i in prompt_input] + + f = open(str(self.prompt_dir.joinpath(tmpl_filename)), "r") + prompt = f.read() + f.close() + for count, i in enumerate(prompt_input): + prompt = prompt.replace(f"!<INPUT {count}>!", i) + if "###" in prompt: + prompt = prompt.split("###")[1] + return prompt.strip() + + async def _aask(self, prompt: str) -> str: + return await self.llm.aask(prompt) + + async def _run_gpt35_max_tokens(self, prompt: str, max_tokens: int = 50, retry: int = 3): + for idx in range(retry): + try: + tmp_max_tokens_rsp = getattr(self.config.llm, "max_token", 1500) + setattr(self.config.llm, "max_token", max_tokens) + self.llm.use_system_prompt = False # make it behave like a non-chat completion + + llm_resp = await self._aask(prompt) + + setattr(self.config.llm, "max_token", tmp_max_tokens_rsp) + logger.info(f"Action: {self.cls_name} llm _run_gpt35_max_tokens raw resp: {llm_resp}") + if self._func_validate(llm_resp, prompt): + return self._func_cleanup(llm_resp, prompt) + except Exception as exp: + logger.warning(f"Action: {self.cls_name} _run_gpt35_max_tokens exp: {exp}") + time.sleep(5) + return self.fail_default_resp + + async def _run_gpt35( + self, prompt: str, example_output: str, special_instruction: str, retry: int = 3 + ) -> Union[bool, Any]: + """same with `gpt_structure.ChatGPT_safe_generate_response`""" + prompt = '"""\n' + prompt + '\n"""\n' + prompt += f"Output the response to the prompt above in json.
{special_instruction}\n" + prompt += "Example output json:\n" + prompt += '{"output": "' + str(example_output) + '"}' + + for idx in range(retry): + try: + llm_resp = await self._aask(prompt) + logger.info(f"Action: {self.cls_name} llm _run_gpt35 raw resp: {llm_resp}") + end_idx = llm_resp.strip().rfind("}") + 1 + llm_resp = llm_resp[:end_idx] + llm_resp = json.loads(llm_resp)["output"] + + if self._func_validate(llm_resp, prompt): + return self._func_cleanup(llm_resp, prompt) + except Exception as exp: + logger.warning(f"Action: {self.cls_name} _run_gpt35 exp: {exp}") + time.sleep(5) # usually avoid `Rate limit` + return False + + async def _run_gpt35_wo_extra_prompt(self, prompt: str, retry: int = 3) -> str: + for idx in range(retry): + try: + llm_resp = await self._aask(prompt) + llm_resp = llm_resp.strip() + logger.info(f"Action: {self.cls_name} llm _run_gpt35_wo_extra_prompt raw resp: {llm_resp}") + if self._func_validate(llm_resp, prompt): + return self._func_cleanup(llm_resp, prompt) + except Exception as exp: + logger.warning(f"Action: {self.cls_name} _run_gpt35_wo_extra_prompt exp: {exp}") + time.sleep(5) # usually avoid `Rate limit` + return self.fail_default_resp + + async def run(self, *args, **kwargs): + """Run action""" + raise NotImplementedError("The run method should be implemented in a subclass.") diff --git a/metagpt/ext/stanford_town/actions/summarize_conv.py b/metagpt/ext/stanford_town/actions/summarize_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..5be5fcaa4381b55946f1208ec54f356ca74922a4 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/summarize_conv.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : summarize the content of agents' conversation + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class SummarizeConv(STAction): + name: str = "SummarizeConv" + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + resp = False + try: + _ = self._func_cleanup(llm_resp, prompt) + resp = True + except Exception: + pass + return resp + + def _func_cleanup(self, llm_resp: str, prompt: str) -> str: + ret = "conversing about " + llm_resp.strip() + return ret + + def _func_fail_default_resp(self) -> str: + return "conversing with a housemate about morning greetings" + + async def run(self, conv: list): + def create_prompt_input(conversation: list): + convo_str = "" + for row in conversation: + convo_str += f'{row[0]}: "{row[1]}"\n' + prompt_input = [convo_str] + return prompt_input + + prompt_input = create_prompt_input(conv) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "summarize_conversation_v1.txt") + + example_output = "conversing about what to eat for lunch" + special_instruction = ( + "The output must continue the sentence above by filling in the tag. " + "Don't start with 'this is a conversation about...' Just finish the sentence " + "but do not miss any important details (including who are chatting)." 
+ ) + output = await self._run_gpt35(prompt, example_output, special_instruction) + logger.info(f"Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/task_decomp.py b/metagpt/ext/stanford_town/actions/task_decomp.py new file mode 100644 index 0000000000000000000000000000000000000000..3a23a73456e190811609f095c0552cc80050932b --- /dev/null +++ b/metagpt/ext/stanford_town/actions/task_decomp.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : task_decomp + +import datetime + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class TaskDecomp(STAction): + name: str = "TaskDecomp" + + def _func_cleanup(self, llm_resp: str, prompt: str) -> list: + # TODO SOMETHING HERE sometimes fails... See screenshot + temp = [i.strip() for i in llm_resp.split("\n")] + _cr = [] + cr = [] + for count, i in enumerate(temp): + if count != 0: + _cr += [" ".join([j.strip() for j in i.split(" ")][3:])] + else: + _cr += [i] + for count, i in enumerate(_cr): + k = [j.strip() for j in i.split("(duration in minutes:")] + task = k[0] + if task[-1] == ".": + task = task[:-1] + duration = int(k[1].split(",")[0].strip()) + cr += [[task, duration]] + + total_expected_min = int(prompt.split("(total duration in minutes")[-1].split("):")[0].strip()) + + # TODO -- now, you need to make sure that this is the same as the sum of + # the current action sequence. + curr_min_slot = [ + ["dummy", -1], + ] # (task_name, task_index) + for count, i in enumerate(cr): + i_task = i[0] + i_duration = i[1] + + i_duration -= i_duration % 5 + if i_duration > 0: + for j in range(i_duration): + curr_min_slot += [(i_task, count)] + curr_min_slot = curr_min_slot[1:] + + if len(curr_min_slot) > total_expected_min: + last_task = curr_min_slot[60] + for i in range(1, 6): + curr_min_slot[-1 * i] = last_task + elif len(curr_min_slot) < total_expected_min: + last_task = curr_min_slot[-1] + for i in range(total_expected_min - len(curr_min_slot)): + curr_min_slot += [last_task] + + cr_ret = [ + ["dummy", -1], + ] + for task, task_index in curr_min_slot: + if task != cr_ret[-1][0]: + cr_ret += [[task, 1]] + else: + cr_ret[-1][1] += 1 + cr = cr_ret[1:] + + return cr + + def _func_validate(self, llm_resp: str, prompt: str) -> bool: + # TODO -- this sometimes generates error + try: + self._func_cleanup(llm_resp, prompt) + except Exception: + return False + return True + + def _func_fail_default_resp(self) -> int: + fs = [["asleep", 0]] + return fs + + async def run(self, role: "STRole", task_desc: int, truncated_act_dur: int, *args, **kwargs): + def create_prompt_input(role, task, duration): + """ + Today is Saturday June 25. From 00:00 ~ 06:00am, Maeve is + planning on sleeping, 06:00 ~ 07:00am, Maeve is + planning on waking up and doing her morning routine, + and from 07:00am ~08:00am, Maeve is planning on having breakfast. 
+ """ + + curr_f_org_index = role.scratch.get_f_daily_schedule_hourly_org_index() + all_indices = [] + # if curr_f_org_index > 0: + # all_indices += [curr_f_org_index-1] + all_indices += [curr_f_org_index] + if curr_f_org_index + 1 <= len(role.scratch.f_daily_schedule_hourly_org): + all_indices += [curr_f_org_index + 1] + if curr_f_org_index + 2 <= len(role.scratch.f_daily_schedule_hourly_org): + all_indices += [curr_f_org_index + 2] + + curr_time_range = "" + + logger.debug("DEBUG") + logger.debug(role.scratch.f_daily_schedule_hourly_org) + logger.debug(all_indices) + + summ_str = f'Today is {role.scratch.curr_time.strftime("%B %d, %Y")}. ' + summ_str += "From " + for index in all_indices: + logger.debug(f"index {index}") + if index < len(role.scratch.f_daily_schedule_hourly_org): + start_min = 0 + for i in range(index): + start_min += role.scratch.f_daily_schedule_hourly_org[i][1] + end_min = start_min + role.scratch.f_daily_schedule_hourly_org[index][1] + start_time = datetime.datetime.strptime("00:00:00", "%H:%M:%S") + datetime.timedelta( + minutes=start_min + ) + end_time = datetime.datetime.strptime("00:00:00", "%H:%M:%S") + datetime.timedelta( + minutes=end_min + ) + start_time_str = start_time.strftime("%H:%M%p") + end_time_str = end_time.strftime("%H:%M%p") + summ_str += ( + f"{start_time_str} ~ {end_time_str}, {role.name} is planning " + f"on {role.scratch.f_daily_schedule_hourly_org[index][0]}, " + ) + if curr_f_org_index + 1 == index: + curr_time_range = f"{start_time_str} ~ {end_time_str}" + summ_str = summ_str[:-2] + "." + + prompt_input = [] + prompt_input += [role.scratch.get_str_iss()] + prompt_input += [summ_str] + # prompt_input += [role.scratch.get_str_curr_date_str()] + prompt_input += [role.scratch.get_str_firstname()] + prompt_input += [role.scratch.get_str_firstname()] + prompt_input += [task] + prompt_input += [curr_time_range] + prompt_input += [duration] + prompt_input += [role.scratch.get_str_firstname()] + return prompt_input + + prompt_input = create_prompt_input(role, task_desc, truncated_act_dur) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "task_decomp_v3.txt") + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=1000) + logger.info(f"Role: {role.name} {self.cls_name} output: {output}") + + fin_output = [] + time_sum = 0 + for i_task, i_duration in output: + time_sum += i_duration + # HM????????? 
+ # if time_sum < duration: + if time_sum <= truncated_act_dur: + fin_output += [[i_task, i_duration]] + else: + break + ftime_sum = 0 + for fi_task, fi_duration in fin_output: + ftime_sum += fi_duration + + fin_output[-1][1] += truncated_act_dur - ftime_sum + output = fin_output + + task_decomp = output + ret = [] + for decomp_task, duration in task_decomp: + ret += [[f"{task_desc} ({decomp_task})", duration]] + output = ret + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/actions/wake_up.py b/metagpt/ext/stanford_town/actions/wake_up.py new file mode 100644 index 0000000000000000000000000000000000000000..ea44cd3a427d3526cd673e7619ca462db870d873 --- /dev/null +++ b/metagpt/ext/stanford_town/actions/wake_up.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : wake_up + + +from metagpt.ext.stanford_town.actions.st_action import STAction +from metagpt.logs import logger + + +class WakeUp(STAction): + name: str = "WakeUp" + + def _func_validate(self, llm_resp: str, prompt: str = None) -> bool: + try: + self._func_cleanup(llm_resp, prompt="") + except Exception: + return False + return True + + def _func_cleanup(self, llm_resp: str, prompt: str) -> int: + cr = int(llm_resp.strip().lower().split("am")[0]) + return cr + + def _func_fail_default_resp(self) -> int: + fs = 8 + return fs + + async def run(self, role: "STRole"): + def create_prompt_input(role): + prompt_input = [ + role.scratch.get_str_iss(), + role.scratch.get_str_lifestyle(), + role.scratch.get_str_firstname(), + ] + return prompt_input + + prompt_input = create_prompt_input(role) + prompt = self.generate_prompt_with_tmpl_filename(prompt_input, "wake_up_hour_v1.txt") + self.fail_default_resp = self._func_fail_default_resp() + output = await self._run_gpt35_max_tokens(prompt, max_tokens=5) + logger.info(f"Role: {role.name} Action: {self.cls_name} output: {output}") + return output diff --git a/metagpt/ext/stanford_town/memory/__init__.py b/metagpt/ext/stanford_town/memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/metagpt/ext/stanford_town/memory/agent_memory.py b/metagpt/ext/stanford_town/memory/agent_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..d212232f42c560c90d264eb2e2ebc63d6674b2a6 --- /dev/null +++ b/metagpt/ext/stanford_town/memory/agent_memory.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : BasicMemory,AgentMemory实现 + +from datetime import datetime +from pathlib import Path +from typing import Optional + +from pydantic import Field, field_serializer, model_validator + +from metagpt.logs import logger +from metagpt.memory.memory import Memory +from metagpt.schema import Message +from metagpt.utils.common import read_json_file, write_json_file + + +class BasicMemory(Message): + """ + BasicMemory继承于MG的Message类,其中content属性替代description属性 + Message类中对于Chat类型支持的非常好,对于Agent个体的Perceive,Reflection,Plan支持的并不多 + 在Type设计上,我们延续GA的三个种类,但是对于Chat种类的对话进行特别设计(具体怎么设计还没想好) + """ + + memory_id: Optional[str] = Field(default=None) # 记忆ID + memory_count: int = -1 # 第几个记忆,实际数值与Memory相等 + type_count: int = -1 # 第几种记忆,类型为整数 + memory_type: Optional[str] = Field(default=None) # 记忆类型,包含 event,thought,chat三种类型 + depth: int = -1 # 记忆深度,类型为整数 + created: Optional[datetime] = Field(default=None) # 创建时间 + expiration: Optional[datetime] = Field(default=None) # 记忆失效时间,默认为空() + last_accessed: 
Optional[datetime] = Field(default=None)  # time of last access; initialized to match self.created
+    subject: Optional[str] = Field(default=None)  # subject
+    predicate: Optional[str] = Field(default=None)  # predicate
+    object: Optional[str] = Field(default=None)  # object
+
+    description: Optional[str] = Field(default=None)
+    embedding_key: Optional[str] = Field(default=None)  # matches self.content
+    poignancy: int = -1  # importance score
+    keywords: list[str] = Field(default=[])  # keywords
+    filling: list = Field(default=[])  # list of the memory_ids associated with this node
+
+    # support hashing in AgentMemory; the `object` field above shadows the builtin,
+    # so use an explicit identity hash instead of `object.__hash__`
+    def __hash__(self):
+        return id(self)
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_values(cls, values):
+        if "created" in values:
+            values["last_accessed"] = values["created"]
+        if "content" in values:
+            values["description"] = values["content"]
+        if "filling" in values:
+            values["filling"] = values["filling"] or []
+        return values
+
+    @field_serializer("created", "expiration")
+    def transform_time_field(self, time_field: Optional[datetime]) -> str:
+        if time_field:
+            time_field = time_field.strftime("%Y-%m-%d %H:%M:%S")
+        return time_field
+
+    def summary(self):
+        return self.subject, self.predicate, self.object
+
+    def save_to_dict(self) -> dict:
+        """
+        Convert a BasicMemory instance into a dict for JSON storage.
+        Note that cause_by is not compatible with GA, so a format conversion is needed.
+        """
+        memory_dict = dict()
+        node_id = self.memory_id
+        basic_mem_obj = self.model_dump(
+            include=[
+                "memory_count",
+                "type_count",
+                "memory_type",
+                "depth",
+                "created",
+                "expiration",
+                "subject",
+                "predicate",
+                "object",
+                "description",
+                "embedding_key",
+                "poignancy",
+                "keywords",
+                "filling",
+                "cause_by",
+            ]
+        )
+        # keep GA's on-disk schema, which load() expects ("node_count" / "type")
+        basic_mem_obj["node_count"] = basic_mem_obj.pop("memory_count", None)
+        basic_mem_obj["type"] = basic_mem_obj.pop("memory_type", None)
+
+        memory_dict[node_id] = basic_mem_obj
+        return memory_dict
+
+
+class AgentMemory(Memory):
+    """
+    GA persists three JSON files, which this class mirrors:
+    1. embeddings.json (dict of embedding_key -> embedding)
+    2. nodes.json (dict of node_id -> node)
+    3. kw_strength.json
+    """
+
+    storage: list[BasicMemory] = []  # overrides Memory.storage; holds every BasicMemory node
+    event_list: list[BasicMemory] = []  # event memories
+    thought_list: list[BasicMemory] = []  # thought memories
+    chat_list: list[BasicMemory] = []  # chat-related memory
+
+    event_keywords: dict[str, list[BasicMemory]] = dict()  # keyword index
+    thought_keywords: dict[str, list[BasicMemory]] = dict()
+    chat_keywords: dict[str, list[BasicMemory]] = dict()
+
+    kw_strength_event: dict[str, int] = dict()
+    kw_strength_thought: dict[str, int] = dict()
+
+    memory_saved: Optional[Path] = Field(default=None)
+    embeddings: dict[str, list[float]] = dict()
+
+    def set_mem_path(self, memory_saved: Path):
+        self.memory_saved = memory_saved
+        self.load(memory_saved)
+
+    def save(self, memory_saved: Path):
+        """
+        Store the BasicMemory nodes as nodes.json and reproduce GA's kw_strength.json format;
+        only the target path needs to be supplied.
+        TODO: nodes are written in reverse order here; still to be verified (test_memory passes).
+        """
+        memory_json = dict()
+        for i in range(len(self.storage)):
+            memory_node = self.storage[len(self.storage) - i - 1]
+            memory_node = memory_node.save_to_dict()
+            memory_json.update(memory_node)
+        write_json_file(memory_saved.joinpath("nodes.json"), memory_json)
+        write_json_file(memory_saved.joinpath("embeddings.json"), self.embeddings)
+
+        strength_json = dict()
+        strength_json["kw_strength_event"] = self.kw_strength_event
+        strength_json["kw_strength_thought"] = self.kw_strength_thought
+        write_json_file(memory_saved.joinpath("kw_strength.json"), strength_json)
+
+    def load(self, memory_saved: Path):
+        """
+        Parse GA's JSON files and populate this AgentMemory instance.
+        """
+        self.embeddings = read_json_file(memory_saved.joinpath("embeddings.json"))
+        memory_load = read_json_file(memory_saved.joinpath("nodes.json"))
+        for count in range(len(memory_load.keys())):
+            node_id = f"node_{str(count + 1)}"
+            node_details = memory_load[node_id]
+            node_type = node_details["type"]
+            created = datetime.strptime(node_details["created"], "%Y-%m-%d %H:%M:%S")
+            expiration = None
+            if node_details["expiration"]:
+                expiration = datetime.strptime(node_details["expiration"], "%Y-%m-%d %H:%M:%S")
+
+            s = node_details["subject"]
+            p = node_details["predicate"]
+            o = node_details["object"]
+
+            description = node_details["description"]
+            embedding_pair = (node_details["embedding_key"], self.embeddings[node_details["embedding_key"]])
+            poignancy = node_details["poignancy"]
+            keywords = set(node_details["keywords"])
+            filling = node_details["filling"]
+            if node_type == "thought":
+                self.add_thought(
+                    created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling
+                )
+            if node_type == "event":
+                self.add_event(created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling)
+            if node_type == "chat":
+                self.add_chat(created, expiration, s, p, o, description, keywords, poignancy, embedding_pair, filling)
+
+        strength_keywords_load = read_json_file(memory_saved.joinpath("kw_strength.json"))
+        if strength_keywords_load["kw_strength_event"]:
+            self.kw_strength_event = strength_keywords_load["kw_strength_event"]
+        if strength_keywords_load["kw_strength_thought"]:
+            self.kw_strength_thought = strength_keywords_load["kw_strength_thought"]
+
+    def add(self, memory_basic: BasicMemory):
+        """
+        Add a new message to storage, while updating the index.
+        Overrides Memory.add: stores BasicMemory instead of the original Message type
+        and dispatches each node to its per-type list.
+        """
+        if memory_basic.memory_id in [mem.memory_id for mem in self.storage]:
+            return
+        self.storage.append(memory_basic)
+        if memory_basic.memory_type == "chat":
+            self.chat_list[0:0] = [memory_basic]
+            return
+        if memory_basic.memory_type == "thought":
+            self.thought_list[0:0] = [memory_basic]
+            return
+        if memory_basic.memory_type == "event":
+            self.event_list[0:0] = [memory_basic]
+            return
+
+    def add_chat(
+        self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling, cause_by=""
+    ):
+        """
+        Initialize a chat memory via add(); the embedding must already be available at creation time.
+        """
+        memory_count = len(self.storage) + 1
+        type_count = len(self.chat_list) + 1  # fixed: count chats, not thoughts
+        memory_type = "chat"
+        memory_id = f"node_{str(memory_count)}"
+        depth = 1
+
+        memory_node = BasicMemory(
+            memory_id=memory_id,
+            memory_count=memory_count,
+            type_count=type_count,
+            memory_type=memory_type,
+            depth=depth,
+            created=created,
+            expiration=expiration,
+            subject=s,
+            predicate=p,
+            object=o,
+            description=content,
+            embedding_key=embedding_pair[0],
+            poignancy=poignancy,
+            keywords=keywords,
+            filling=filling,
+            cause_by=cause_by,
+        )
+
+        keywords = [i.lower() for i in keywords]
+        for kw in keywords:
+            if kw in self.chat_keywords:
+                self.chat_keywords[kw][0:0] = [memory_node]
+            else:
+                self.chat_keywords[kw] = [memory_node]
+
+        self.add(memory_node)
+
+        self.embeddings[embedding_pair[0]] = embedding_pair[1]
+        return memory_node
+
+    def add_thought(self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling):
+        """
+        Initialize a thought memory via add().
+        """
+        memory_count = len(self.storage) + 1
+        type_count = len(self.thought_list) + 1
+        memory_type = "thought"
+        memory_id = f"node_{str(memory_count)}"
+        depth = 1
+
+        try:
+            if filling:
+                depth_list = [memory_node.depth for memory_node in self.storage if memory_node.memory_id in filling]
+                depth += max(depth_list)
+        except Exception as exp:
+            logger.warning(f"filling depth init failed: {exp}")
+
+        memory_node = BasicMemory(
+            memory_id=memory_id,
+            memory_count=memory_count,
+            type_count=type_count,
+            memory_type=memory_type,
+            depth=depth,
+            created=created,
+            expiration=expiration,
+            subject=s,
+            predicate=p,
+            object=o,
+            description=content,
+            embedding_key=embedding_pair[0],
+            poignancy=poignancy,
+            keywords=keywords,
+            filling=filling,
+        )
+
+        keywords = [i.lower() for i in keywords]
+        for kw in keywords:
+            if kw in self.thought_keywords:
+                self.thought_keywords[kw][0:0] = [memory_node]
+            else:
+                self.thought_keywords[kw] = [memory_node]
+
+        self.add(memory_node)
+
+        if f"{p} {o}" != "is idle":
+            for kw in keywords:
+                if kw in self.kw_strength_thought:
+                    self.kw_strength_thought[kw] += 1
+                else:
+                    self.kw_strength_thought[kw] = 1
+
+        self.embeddings[embedding_pair[0]] = embedding_pair[1]
+        return memory_node
+
+    def add_event(self, created, expiration, s, p, o, content, keywords, poignancy, embedding_pair, filling):
+        """
+        Initialize an event memory via add().
+        """
+        memory_count = len(self.storage) + 1
+        type_count = len(self.event_list) + 1
+        memory_type = "event"
+        memory_id = f"node_{str(memory_count)}"
+        depth = 0
+
+        if "(" in content:
+            content = " ".join(content.split()[:3]) + " " + content.split("(")[-1][:-1]
+
+        memory_node = BasicMemory(
+            memory_id=memory_id,
+            memory_count=memory_count,
+            type_count=type_count,
+            memory_type=memory_type,
+            depth=depth,
+            created=created,
+            expiration=expiration,
+            subject=s,
+            predicate=p,
+            object=o,
+            description=content,
+            embedding_key=embedding_pair[0],
+            poignancy=poignancy,
+            keywords=keywords,
+            filling=filling,
+        )
+
+        keywords = [i.lower() for i in keywords]
+        for 
kw in keywords: + if kw in self.kw_strength_event: + self.kw_strength_event[kw] += 1 + else: + self.kw_strength_event[kw] = 1 + + self.embeddings[embedding_pair[0]] = embedding_pair[1] + return memory_node + + def get_summarized_latest_events(self, retention): + ret_set = set() + for e_node in self.event_list[:retention]: + ret_set.add(e_node.summary()) + return ret_set + + def get_last_chat(self, target_role_name: str): + if target_role_name.lower() in self.chat_keywords: + return self.chat_keywords[target_role_name.lower()][0] + else: + return False + + def retrieve_relevant_thoughts(self, s_content: str, p_content: str, o_content: str) -> set: + contents = [s_content, p_content, o_content] + + ret = [] + for i in contents: + if i in self.thought_keywords: + ret += self.thought_keywords[i.lower()] + + ret = set(ret) + return ret + + def retrieve_relevant_events(self, s_content: str, p_content: str, o_content: str) -> set: + contents = [s_content, p_content, o_content] + + ret = [] + for i in contents: + if i in self.event_keywords: + ret += self.event_keywords[i] + + ret = set(ret) + return ret diff --git a/metagpt/ext/stanford_town/memory/retrieve.py b/metagpt/ext/stanford_town/memory/retrieve.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b32f965037f2ba1b931dd8aca77ae56d0a9b36 --- /dev/null +++ b/metagpt/ext/stanford_town/memory/retrieve.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Retrieve函数实现 + +import datetime + +from numpy import dot +from numpy.linalg import norm + +from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory +from metagpt.ext.stanford_town.utils.utils import get_embedding + + +def agent_retrieve( + agent_memory, + curr_time: datetime.datetime, + memory_forget: float, + query: str, + nodes: list[BasicMemory], + topk: int = 4, +) -> list[BasicMemory]: + """ + Retrieve需要集合Role使用,原因在于Role才具有AgentMemory,scratch + 逻辑:Role调用该函数,self.rc.AgentMemory,self.rc.scratch.curr_time,self.rc.scratch.memory_forget + 输入希望查询的内容与希望回顾的条数,返回TopK条高分记忆,即List[BasicMemory] + + Score_lists示例 + { + "memory": memories[i], BasicMemory类 + "importance": memories[i].poignancy + "recency": 衰减因子计算结果 + "relevance": 搜索结果 + } + """ + memories = nodes + agent_memory_embedding = agent_memory.embeddings + memories = sorted(memories, key=lambda memory_node: memory_node.last_accessed, reverse=True) + + score_list = [] + score_list = extract_importance(memories, score_list) + score_list = extract_recency(curr_time, memory_forget, score_list) + score_list = extract_relevance(agent_memory_embedding, query, score_list) + score_list = normalize_score_floats(score_list, 0, 1) + + total_dict = {} + gw = [1, 1, 1] # 三个因素的权重,重要性,近因性,相关性, + for i in range(len(score_list)): + total_score = ( + score_list[i]["importance"] * gw[0] + score_list[i]["recency"] * gw[1] + score_list[i]["relevance"] * gw[2] + ) + total_dict[score_list[i]["memory"].memory_id] = total_score + + result = top_highest_x_values(total_dict, topk) + + return result # 返回的是一个BasicMemory列表 + + +def new_agent_retrieve(role, focus_points: list, n_count=30) -> dict: + """ + 输入为role,关注点列表,返回记忆数量 + 输出为字典,键为focus_point,值为对应的记忆列表 + """ + retrieved = dict() + for focal_pt in focus_points: + nodes = [ + [i.last_accessed, i] + for i in role.memory.event_list + role.memory.thought_list + if "idle" not in i.embedding_key + ] + nodes = sorted(nodes, key=lambda x: x[0]) + nodes = [i for created, i in nodes] + results = agent_retrieve( + role.memory, role.scratch.curr_time, 
role.scratch.recency_decay, focal_pt, nodes, n_count + ) + final_result = [] + for n in results: + for i in role.memory.storage: + if i.memory_id == n: + i.last_accessed = role.scratch.curr_time + final_result.append(i) + + retrieved[focal_pt] = final_result + + return retrieved + + +def top_highest_x_values(d, x): + """ + 输入字典,Topx + 返回以字典值排序,字典键组成的List[BasicMemory] + """ + top_v = [item[0] for item in sorted(d.items(), key=lambda item: item[1], reverse=True)[:x]] + return top_v + + +def extract_importance(memories, score_list): + """ + 抽取重要性 + """ + for i in range(len(memories)): + score = {"memory": memories[i], "importance": memories[i].poignancy} + score_list.append(score) + return score_list + + +def extract_relevance(agent_memory_embedding, query, score_list): + """ + 抽取相关性 + """ + query_embedding = get_embedding(query) + # 进行 + for i in range(len(score_list)): + node_embedding = agent_memory_embedding[score_list[i]["memory"].embedding_key] + result = cos_sim(node_embedding, query_embedding) + score_list[i]["relevance"] = result + + return score_list + + +def extract_recency(curr_time, memory_forget, score_list): + """ + 抽取近因性,目前使用的现实世界过一天走一个衰减因子 + """ + for i in range(len(score_list)): + day_count = (curr_time - score_list[i]["memory"].created).days + score_list[i]["recency"] = memory_forget**day_count + return score_list + + +def cos_sim(a, b): + """ + 计算余弦相似度 + """ + return dot(a, b) / (norm(a) * norm(b)) + + +def normalize_list_floats(single_list, target_min, target_max): + """ + 单个列表归一化 + """ + if len(single_list) == 0: + return [] + + min_val = min(single_list) + max_val = max(single_list) + range_val = max_val - min_val + + if range_val == 0: + for i in range(len(single_list)): + single_list[i] = (target_max - target_min) / 2 + else: + for i in range(len(single_list)): + single_list[i] = (single_list[i] - min_val) * (target_max - target_min) / range_val + target_min + return single_list + + +def normalize_score_floats(score_list, target_min, target_max): + """ + 整体归一化 + """ + importance_list = [] + relevance_list = [] + recency_list = [] + + for i in range(len(score_list)): + importance_list.append(score_list[i]["importance"]) + relevance_list.append(score_list[i]["relevance"]) + recency_list.append(score_list[i]["recency"]) + + # 进行归一化操作 + importance_list = normalize_list_floats(importance_list, target_min, target_max) + relevance_list = normalize_list_floats(relevance_list, target_min, target_max) + recency_list = normalize_list_floats(recency_list, target_min, target_max) + + for i in range(len(score_list)): + score_list[i]["importance"] = importance_list[i] + score_list[i]["relevance"] = relevance_list[i] + score_list[i]["recency"] = recency_list[i] + + return score_list diff --git a/metagpt/ext/stanford_town/memory/scratch.py b/metagpt/ext/stanford_town/memory/scratch.py new file mode 100644 index 0000000000000000000000000000000000000000..b4036f839fb555ef2302345a8065bf38ce7c4494 --- /dev/null +++ b/metagpt/ext/stanford_town/memory/scratch.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : Scratch类实现(角色信息类) + +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional, Union + +from pydantic import BaseModel, Field, field_serializer, field_validator + +from metagpt.utils.common import read_json_file, write_json_file + + +class Scratch(BaseModel): + # 类别1:人物超参 + vision_r: int = 4 + att_bandwidth: int = 3 + retention: int = 5 + + # 类别2:世界信息 + curr_time: Optional[datetime] = Field(default=None) + curr_tile: 
Optional[list[int]] = Field(default=None) + daily_plan_req: Optional[str] = Field(default=None) + + # 类别3:人物角色的核心身份 + name: Optional[str] = Field(default=None) + first_name: Optional[str] = Field(default=None) + last_name: Optional[str] = Field(default=None) + age: Optional[int] = Field(default=None) + innate: Optional[str] = Field(default=None) # L0 permanent core traits. + learned: Optional[str] = Field(default=None) # L1 stable traits. + currently: Optional[str] = Field(default=None) # L2 external implementation. + lifestyle: Optional[str] = Field(default=None) + living_area: Optional[str] = Field(default=None) + + # 类别4:旧反思变量 + concept_forget: int = 100 + daily_reflection_time: int = 60 * 3 + daily_reflection_size: int = 5 + overlap_reflect_th: int = 2 + kw_strg_event_reflect_th: int = 4 + kw_strg_thought_reflect_th: int = 4 + + # 类别5:新反思变量 + recency_w: int = 1 + relevance_w: int = 1 + importance_w: int = 1 + recency_decay: float = 0.99 + importance_trigger_max: int = 150 + importance_trigger_curr: int = 150 + importance_ele_n: int = 0 + thought_count: int = 5 + + # 类别6:个人计划 + daily_req: list[str] = Field(default=[]) + f_daily_schedule: list[list[Union[int, str]]] = Field(default=[]) + f_daily_schedule_hourly_org: list[list[Union[int, str]]] = Field(default=[]) + + # 类别7:当前动作 + act_address: Optional[str] = Field(default=None) + act_start_time: Optional[datetime] = Field(default=None) + act_duration: Optional[int] = Field(default=None) + act_description: Optional[str] = Field(default=None) + act_pronunciatio: Optional[str] = Field(default=None) + act_event: list[Optional[str]] = [None, None, None] + + act_obj_description: Optional[str] = Field(default=None) + act_obj_pronunciatio: Optional[str] = Field(default=None) + act_obj_event: list[Optional[str]] = [None, None, None] + + chatting_with: Optional[str] = Field(default=None) + chat: Optional[str] = Field(default=None) + chatting_with_buffer: dict = dict() + chatting_end_time: Optional[datetime] = Field(default=None) + + act_path_set: bool = False + planned_path: list[list[int]] = Field(default=[]) + + @field_validator("curr_time", "act_start_time", "chatting_end_time", mode="before") + @classmethod + def check_time_filed(cls, time_filed): + val = datetime.strptime(time_filed, "%B %d, %Y, %H:%M:%S") if time_filed else None + return val + + @field_serializer("curr_time", "act_start_time", "chatting_end_time") + def transform_time_field(self, time_filed: Optional[datetime]) -> str: + if time_filed: + time_filed = time_filed.strftime("%B %d, %Y, %H:%M:%S") + return time_filed + + @classmethod + def init_scratch_from_path(cls, f_saved: Path): + scratch_load = read_json_file(f_saved) + scratch = Scratch(**scratch_load) + return scratch + + def save(self, out_json: Path): + """ + Save persona's scratch. + + INPUT: + out_json: The file where we wil be saving our persona's state. + OUTPUT: + None + """ + scratch = self.model_dump() + write_json_file(out_json, scratch, encoding="utf-8") + + def get_f_daily_schedule_index(self, advance=0): + """ + We get the current index of self.f_daily_schedule. + + Recall that self.f_daily_schedule stores the decomposed action sequences + up until now, and the hourly sequences of the future action for the rest + of today. Given that self.f_daily_schedule is a list of list where the + inner list is composed of [task, duration], we continue to add up the + duration until we reach "if elapsed > today_min_elapsed" condition. The + index where we stop is the index we will return. 
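+        For example (editor's illustrative case, not from the original): if
+        curr_time is 08:30 (510 minutes into the day) and f_daily_schedule is
+        [["sleeping", 480], ["morning routine", 120]], the cumulative duration
+        first exceeds 510 at the second entry, so index 1 is returned.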
+
+        INPUT
+          advance: Integer value of the number of minutes we want to look into the
+                   future. This allows us to get the index of a future timeframe.
+        OUTPUT
+          an integer value for the current index of f_daily_schedule.
+        """
+        # We first calculate the number of minutes elapsed today.
+        today_min_elapsed = 0
+        today_min_elapsed += self.curr_time.hour * 60
+        today_min_elapsed += self.curr_time.minute
+        today_min_elapsed += advance
+
+        # We then calculate the current index based on that.
+        curr_index = 0
+        elapsed = 0
+        for task, duration in self.f_daily_schedule:
+            elapsed += duration
+            if elapsed > today_min_elapsed:
+                return curr_index
+            curr_index += 1
+
+        return curr_index
+
+    def get_f_daily_schedule_hourly_org_index(self, advance=0):
+        """
+        We get the current index of self.f_daily_schedule_hourly_org.
+        It is otherwise the same as get_f_daily_schedule_index.
+
+        INPUT
+          advance: Integer value of the number of minutes we want to look into the
+                   future. This allows us to get the index of a future timeframe.
+        OUTPUT
+          an integer value for the current index of f_daily_schedule.
+        """
+        # We first calculate the number of minutes elapsed today.
+        today_min_elapsed = 0
+        today_min_elapsed += self.curr_time.hour * 60
+        today_min_elapsed += self.curr_time.minute
+        today_min_elapsed += advance
+        # We then calculate the current index based on that.
+        curr_index = 0
+        elapsed = 0
+        for task, duration in self.f_daily_schedule_hourly_org:
+            elapsed += duration
+            if elapsed > today_min_elapsed:
+                return curr_index
+            curr_index += 1
+        return curr_index
+
+    def get_str_iss(self):
+        """
+        ISS stands for "identity stable set." This describes the commonset summary
+        of this persona -- basically, the bare minimum description of the persona
+        that gets used in almost all prompts that need to call on the persona.
+
+        INPUT
+          None
+        OUTPUT
+          the identity stable set summary of the persona in a string form.
+        EXAMPLE STR OUTPUT
+          "Name: Dolores Heitmiller
+           Age: 28
+           Innate traits: hard-edged, independent, loyal
+           Learned traits: Dolores is a painter who wants to live quietly and paint
+             while enjoying her everyday life.
+           Currently: Dolores is preparing for her first solo show. She mostly
+             works from home.
+           Lifestyle: Dolores goes to bed around 11pm, sleeps for 7 hours, eats
+             dinner around 6pm.
+           Daily plan requirement: Dolores is planning to stay at home all day and
+             never go out."
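+           (Editor's note: this port also appends a "Current Date: ..." line
+           when curr_time is set; see the body below.)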
+ """ + commonset = "" + commonset += f"Name: {self.name}\n" + commonset += f"Age: {self.age}\n" + commonset += f"Innate traits: {self.innate}\n" + commonset += f"Learned traits: {self.learned}\n" + commonset += f"Currently: {self.currently}\n" + commonset += f"Lifestyle: {self.lifestyle}\n" + commonset += f"Daily plan requirement: {self.daily_plan_req}\n" + commonset += f"Current Date: {self.curr_time.strftime('%A %B %d') if self.curr_time else ''}\n" + return commonset + + def get_str_name(self): + return self.name + + def get_str_firstname(self): + return self.first_name + + def get_str_lastname(self): + return self.last_name + + def get_str_age(self): + return str(self.age) + + def get_str_innate(self): + return self.innate + + def get_str_learned(self): + return self.learned + + def get_str_currently(self): + return self.currently + + def get_str_lifestyle(self): + return self.lifestyle + + def get_str_daily_plan_req(self): + return self.daily_plan_req + + def get_str_curr_date_str(self): + return self.curr_time.strftime("%A %B %d") + + def get_curr_event(self): + if not self.act_address: + return self.name, None, None + else: + return self.act_event + + def get_curr_event_and_desc(self): + if not self.act_address: + return self.name, None, None, None + else: + return self.act_event[0], self.act_event[1], self.act_event[2], self.act_description + + def get_curr_obj_event_and_desc(self): + if not self.act_address: + return "", None, None, None + else: + return self.act_address, self.act_obj_event[1], self.act_obj_event[2], self.act_obj_description + + def add_new_action( + self, + action_address, + action_duration, + action_description, + action_pronunciatio, + action_event, + chatting_with, + chat, + chatting_with_buffer, + chatting_end_time, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time=None, + ): + self.act_address = action_address + self.act_duration = action_duration + self.act_description = action_description + self.act_pronunciatio = action_pronunciatio + self.act_event = action_event + + self.chatting_with = chatting_with + self.chat = chat + if chatting_with_buffer: + self.chatting_with_buffer.update(chatting_with_buffer) + self.chatting_end_time = chatting_end_time + + self.act_obj_description = act_obj_description + self.act_obj_pronunciatio = act_obj_pronunciatio + self.act_obj_event = act_obj_event + + self.act_start_time = self.curr_time + + self.act_path_set = False + + def act_time_str(self): + """ + Returns a string output of the current time. + + INPUT + None + OUTPUT + A string output of the current time. + EXAMPLE STR OUTPUT + "14:05 P.M." + """ + return self.act_start_time.strftime("%H:%M %p") + + def act_check_finished(self): + """ + Checks whether the self.Action instance has finished. + + INPUT + curr_datetime: Current time. If current time is later than the action's + start time + its duration, then the action has finished. + OUTPUT + Boolean [True]: Action has finished. + Boolean [False]: Action has not finished and is still ongoing. + """ + if not self.act_address: + return True + + if self.chatting_with: + end_time = self.chatting_end_time + else: + x = self.act_start_time + if x.second != 0: + x = x.replace(second=0) + x = x + timedelta(minutes=1) + end_time = x + timedelta(minutes=self.act_duration) + + if end_time.strftime("%H:%M:%S") == self.curr_time.strftime("%H:%M:%S"): + return True + return False + + def act_summarize(self): + """ + Summarize the current action as a dictionary. 
+ + INPUT + None + OUTPUT + ret: A human readable summary of the action. + """ + exp = dict() + exp["persona"] = self.name + exp["address"] = self.act_address + exp["start_datetime"] = self.act_start_time + exp["duration"] = self.act_duration + exp["description"] = self.act_description + exp["pronunciatio"] = self.act_pronunciatio + return exp + + def act_summary_str(self): + """ + Returns a string summary of the current action. Meant to be + human-readable. + + INPUT + None + OUTPUT + ret: A human readable summary of the action. + """ + start_datetime_str = self.act_start_time.strftime("%A %B %d -- %H:%M %p") + ret = f"[{start_datetime_str}]\n" + ret += f"Activity: {self.name} is {self.act_description}\n" + ret += f"Address: {self.act_address}\n" + ret += f"Duration in minutes (e.g., x min): {str(self.act_duration)} min\n" + return ret + + def get_daily_schedule(self, daily_schedule: list[list[str]]): + ret = "" + curr_min_sum = 0 + for row in daily_schedule: + curr_min_sum += row[1] + hour = int(curr_min_sum / 60) + minute = curr_min_sum % 60 + ret += f"{hour:02}:{minute:02} || {row[0]}\n" + return ret + + def get_str_daily_schedule_summary(self): + return self.get_daily_schedule(self.f_daily_schedule) + + def get_str_daily_schedule_hourly_org_summary(self): + return self.get_daily_schedule(self.f_daily_schedule_hourly_org) diff --git a/metagpt/ext/stanford_town/memory/spatial_memory.py b/metagpt/ext/stanford_town/memory/spatial_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..71b8569079c9663ff8f3bb4944b766ebf47367f2 --- /dev/null +++ b/metagpt/ext/stanford_town/memory/spatial_memory.py @@ -0,0 +1,116 @@ +""" +Author: Joon Sung Park (joonspk@stanford.edu) + +File: spatial_memory.py +Description: Defines the MemoryTree class that serves as the agents' spatial +memory that aids in grounding their behavior in the game world. +""" +from pathlib import Path + +from pydantic import BaseModel, Field + +from metagpt.logs import logger +from metagpt.utils.common import read_json_file, write_json_file + + +class MemoryTree(BaseModel): + tree: dict = Field(default=dict) + + def set_mem_path(self, f_saved: Path): + self.tree = read_json_file(f_saved) + + def print_tree(self) -> None: + def _print_tree(tree, depth): + dash = " >" * depth + if isinstance(tree, list): + if tree: + logger.info(f"{dash} {tree}") + return + + for key, val in tree.items(): + if key: + logger.info(f"{dash} {tree}") + _print_tree(val, depth + 1) + + _print_tree(self.tree, 0) + + def save(self, out_json: Path) -> None: + write_json_file(out_json, self.tree) + + def get_str_accessible_sectors(self, curr_world: str) -> str: + """ + Returns a summary string of all the arenas that the persona can access + within the current sector. + + Note that there are places a given persona cannot enter. This information + is provided in the persona sheet. We account for this in this function. + + INPUT + None + OUTPUT + A summary string of all the arenas that the persona can access. + EXAMPLE STR OUTPUT + "bedroom, kitchen, dining room, office, bathroom" + """ + x = ", ".join(list(self.tree[curr_world].keys())) + return x + + def get_str_accessible_sector_arenas(self, sector: str) -> str: + """ + Returns a summary string of all the arenas that the persona can access + within the current sector. + + Note that there are places a given persona cannot enter. This information + is provided in the persona sheet. We account for this in this function. 
+ + INPUT + None + OUTPUT + A summary string of all the arenas that the persona can access. + EXAMPLE STR OUTPUT + "bedroom, kitchen, dining room, office, bathroom" + """ + curr_world, curr_sector = sector.split(":") + if not curr_sector: + return "" + x = ", ".join(list(self.tree[curr_world][curr_sector].keys())) + return x + + def get_str_accessible_arena_game_objects(self, arena: str) -> str: + """ + Get a str list of all accessible game objects that are in the arena. If + temp_address is specified, we return the objects that are available in + that arena, and if not, we return the objects that are in the arena our + persona is currently in. + + INPUT + temp_address: optional arena address + OUTPUT + str list of all accessible game objects in the gmae arena. + EXAMPLE STR OUTPUT + "phone, charger, bed, nightstand" + """ + curr_world, curr_sector, curr_arena = arena.split(":") + + if not curr_arena: + return "" + + try: + x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena])) + except Exception: + x = ", ".join(list(self.tree[curr_world][curr_sector][curr_arena.lower()])) + return x + + def add_tile_info(self, tile_info: dict) -> None: + if tile_info["world"]: + if tile_info["world"] not in self.tree: + self.tree[tile_info["world"]] = {} + if tile_info["sector"]: + if tile_info["sector"] not in self.tree[tile_info["world"]]: + self.tree[tile_info["world"]][tile_info["sector"]] = {} + if tile_info["arena"]: + if tile_info["arena"] not in self.tree[tile_info["world"]][tile_info["sector"]]: + self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] = [] + if tile_info["game_object"]: + if tile_info["game_object"] not in self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]]: + self.tree[tile_info["world"]][tile_info["sector"]][tile_info["arena"]] += [tile_info["game_object"]] diff --git a/metagpt/ext/stanford_town/plan/__init__.py b/metagpt/ext/stanford_town/plan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bcf8efd09712339308e72659e84450d3fa829fd --- /dev/null +++ b/metagpt/ext/stanford_town/plan/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/metagpt/ext/stanford_town/plan/converse.py b/metagpt/ext/stanford_town/plan/converse.py new file mode 100644 index 0000000000000000000000000000000000000000..8eefbc9b42b4e0bd5f359f61e924bd4a455e0127 --- /dev/null +++ b/metagpt/ext/stanford_town/plan/converse.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : conversation between two agents + +from typing import Tuple + +from metagpt.ext.stanford_town.actions.agent_chat_sum_rel import AgentChatSumRel +from metagpt.ext.stanford_town.actions.gen_iter_chat_utt import GenIterChatUTT +from metagpt.ext.stanford_town.memory.retrieve import new_agent_retrieve +from metagpt.logs import logger + + +async def agent_conversation(init_role: "STRole", target_role: "STRole", conv_rounds: int = 8) -> list[list[str]]: + curr_chat = [] + logger.info(f"Role: {init_role.name} starts a conversation with Role: {target_role.name}") + + for idx in range(conv_rounds): + logger.info(f"Conv round: {idx} between {init_role.name} and {target_role.name}") + scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + + focal_points = [f"{target_scratch.name}"] + retrieved = new_agent_retrieve(init_role, focal_points, 50) + relationship = await generate_summarize_agent_relationship(init_role, target_role, retrieved) + logger.info(f"The relationship between 
{init_role.name} and {target_role.name}: {relationship}") + last_chat = "" + for i in curr_chat[-4:]: + last_chat += ": ".join(i) + "\n" + if last_chat: + focal_points = [f"{relationship}", f"{target_scratch.name} is {target_scratch.act_description}", last_chat] + else: + focal_points = [f"{relationship}", f"{target_scratch.name} is {target_scratch.act_description}"] + retrieved = new_agent_retrieve(init_role, focal_points, 15) + utt, end = await generate_one_utterance(init_role, target_role, retrieved, curr_chat) + + curr_chat += [[scratch.name, utt]] + if end: + break + + focal_points = [f"{scratch.name}"] + retrieved = new_agent_retrieve(target_role, focal_points, 50) + relationship = await generate_summarize_agent_relationship(target_role, init_role, retrieved) + logger.info(f"The relationship between {target_role.name} and {init_role.name}: {relationship}") + last_chat = "" + for i in curr_chat[-4:]: + last_chat += ": ".join(i) + "\n" + if last_chat: + focal_points = [f"{relationship}", f"{scratch.name} is {scratch.act_description}", last_chat] + else: + focal_points = [f"{relationship}", f"{scratch.name} is {scratch.act_description}"] + retrieved = new_agent_retrieve(target_role, focal_points, 15) + utt, end = await generate_one_utterance(target_role, init_role, retrieved, curr_chat) + + curr_chat += [[target_scratch.name, utt]] + if end: + break + + logger.warning(f"Conversations between {target_role.name} and {init_role.name}:") + for row in curr_chat: + logger.info(row) + + return curr_chat + + +async def generate_summarize_agent_relationship(init_role: "STRole", target_role: "STRole", retrieved: dict) -> str: + all_embedding_keys = list() + for key, val in retrieved.items(): + for i in val: + all_embedding_keys += [i.embedding_key] + all_embedding_key_str = "" + for i in all_embedding_keys: + all_embedding_key_str += f"{i}\n" + + summarized_relationship = await AgentChatSumRel().run(init_role, target_role, all_embedding_key_str) + return summarized_relationship + + +async def generate_one_utterance(init_role, target_role, retrieved: dict, curr_chat: list) -> Tuple[str, str]: + # Chat version optimized for speed via batch generation + scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + curr_context = ( + f"{scratch.name} " + + f"was {scratch.act_description} " + + f"when {scratch.name} " + + f"saw {target_scratch.name} " + + f"in the middle of {target_scratch.act_description}.\n" + ) + curr_context += f"{scratch.name} " + "is initiating a conversation with " + f"{target_scratch.name}." 
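# [Editor's note, not part of this patch] Given how the result is unpacked in
# the return statement below, GenIterChatUTT().run is expected to yield a dict
# shaped like (hypothetical values):
#   {"utterance": "Hey Klaus, how is the essay going?", "end": False}
# where "end" signals that the speaker wants to close the conversation.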
+
+    x = await GenIterChatUTT().run(init_role, target_role, retrieved, curr_context, curr_chat)
+
+    return x["utterance"], x["end"]
diff --git a/metagpt/ext/stanford_town/plan/st_plan.py b/metagpt/ext/stanford_town/plan/st_plan.py
new file mode 100644
index 0000000000000000000000000000000000000000..f63052fc5324f06b67d8426f687046852d76952d
--- /dev/null
+++ b/metagpt/ext/stanford_town/plan/st_plan.py
@@ -0,0 +1,706 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# @Desc : StanfordTown planning execution
+
+import datetime
+import math
+import random
+from typing import Tuple, Union
+
+from metagpt.ext.stanford_town.actions.decide_to_talk import DecideToTalk
+from metagpt.ext.stanford_town.actions.gen_action_details import GenActionDetails
+from metagpt.ext.stanford_town.actions.gen_daily_schedule import GenDailySchedule
+from metagpt.ext.stanford_town.actions.gen_hourly_schedule import GenHourlySchedule
+from metagpt.ext.stanford_town.actions.new_decomp_schedule import NewDecompSchedule
+from metagpt.ext.stanford_town.actions.summarize_conv import SummarizeConv
+from metagpt.ext.stanford_town.actions.task_decomp import TaskDecomp
+from metagpt.ext.stanford_town.actions.wake_up import WakeUp
+from metagpt.ext.stanford_town.memory.retrieve import new_agent_retrieve
+from metagpt.ext.stanford_town.plan.converse import agent_conversation
+from metagpt.ext.stanford_town.utils.utils import get_embedding
+from metagpt.llm import LLM
+from metagpt.logs import logger
+
+
+async def plan(role: "STRole", roles: dict["STRole"], new_day: Union[bool, str], retrieved: dict) -> str:
+    # PART 1: Generate the hourly schedule.
+    if new_day:
+        await _long_term_planning(role, new_day)
+
+    # PART 2: If the current action has expired, we want to create a new plan.
+    act_check_finished = role.scratch.act_check_finished()
+    logger.info(f"Role: {role.name} act_check_finished is {act_check_finished}")
+    if act_check_finished:
+        await _determine_action(role)
+
+    # PART 3: If the role perceived an event that needs to be responded to (saw
+    # another role), and retrieved relevant information.
+    # Step 1: Retrieved may have multiple events represented in it. The first
+    #         job here is to determine which of the events we want to focus
+    #         on for the role.
+    #         <focused_event> takes the form of a dictionary like this:
+    #         dictionary {["curr_event"] = <BasicMemory>,
+    #                     ["events"] = [<BasicMemory>, ...],
+    #                     ["thoughts"] = [<BasicMemory>, ...]}
+    focused_event = False
+    if retrieved.keys():
+        focused_event = _choose_retrieved(role.name, retrieved)
+
+    # Step 2: Once we choose an event, we need to determine whether the
+    #         role will take any actions for the perceived event. There are
+    #         three possible modes of reaction returned by _should_react.
+    #         a) "chat with {target_role.name}"
+    #         b) "react"
+    #         c) False
+    logger.info(f"Role: {role.name} focused_event: {focused_event}")
+    if focused_event:
+        reaction_mode = await _should_react(role, focused_event, roles)
+        logger.info(f"Role: {role.name} reaction_mode: {reaction_mode}")
+        if reaction_mode:
+            # If we do want to chat, then we generate conversation
+            if reaction_mode[:9] == "chat with":
+                await _chat_react(role, reaction_mode, roles)
+            elif reaction_mode[:4] == "wait":
+                await _wait_react(role, reaction_mode)
+
+    # Step 3: Chat-related state clean up.
+    # If the persona is not chatting with anyone, we clean up any of the
+    # chat-related states here.
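# [Editor's illustration, not part of this patch] The cleanup below also decays
# chatting_with_buffer, a per-partner cooldown that keeps two roles from
# re-entering a chat immediately after finishing one. The mechanism in
# isolation, with hypothetical data:
buffer = {"Klaus Mueller": 800, "Maria Lopez": 0}
chatting_with = None  # the role is not chatting right now
for name in buffer:
    if name != chatting_with:
        buffer[name] -= 1  # everyone we are not actively chatting with cools down
assert buffer == {"Klaus Mueller": 799, "Maria Lopez": -1}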
+    if role.rc.scratch.act_event[1] != "chat with":
+        role.rc.scratch.chatting_with = None
+        role.rc.scratch.chat = None
+        role.rc.scratch.chatting_end_time = None
+    # We want to make sure that the persona does not keep conversing with each
+    # other in an infinite loop. So, chatting_with_buffer maintains a form of
+    # buffer that makes the persona wait before talking to the same target
+    # again immediately after chatting once. We keep track of the buffer value here.
+    curr_persona_chat_buffer = role.rc.scratch.chatting_with_buffer
+    for persona_name, buffer_count in curr_persona_chat_buffer.items():
+        if persona_name != role.rc.scratch.chatting_with:
+            role.rc.scratch.chatting_with_buffer[persona_name] -= 1
+
+    return role.rc.scratch.act_address
+
+
+def _choose_retrieved(role_name: str, retrieved: dict) -> Union[None, dict]:
+    """
+    Retrieved elements have multiple core "curr_events". We need to choose one
+    event to which we are going to react. We pick that event here.
+    Args:
+        role_name: Current role instance's name whose action we are determining.
+        retrieved: A dictionary of <BasicMemory>s that were retrieved from
+                   the role's associative memory. This dictionary takes the
+                   following form:
+                   dictionary[event.description] =
+                       {["curr_event"] = <BasicMemory>,
+                        ["events"] = [<BasicMemory>, ...],
+                        ["thoughts"] = [<BasicMemory>, ...]}
+    """
+    # Once we are done with the reflection, we might want to build a more
+    # complex structure here.
+
+    # We do not want to take self events... for now
+    copy_retrieved = retrieved.copy()
+    for event_desc, rel_ctx in copy_retrieved.items():
+        curr_event = rel_ctx["curr_event"]
+        if curr_event.subject == role_name:
+            del retrieved[event_desc]
+
+    # Always choose role first.
+    priority = []
+    for event_desc, rel_ctx in retrieved.items():
+        curr_event = rel_ctx["curr_event"]
+        if ":" not in curr_event.subject and curr_event.subject != role_name:
+            priority += [rel_ctx]
+    if priority:
+        return random.choice(priority)
+
+    # Skip idle.
+    for event_desc, rel_ctx in retrieved.items():
+        if "is idle" not in event_desc:
+            priority += [rel_ctx]
+    if priority:
+        return random.choice(priority)
+    return None
+
+
+async def _should_react(role: "STRole", retrieved: dict, roles: dict):
+    """
+    Determines what form of reaction the role should exhibit given the
+    retrieved values.
+    INPUT
+      role: Current <"STRole"> instance whose action we are determining.
+      retrieved: A dictionary of <BasicMemory>s that were retrieved from
+                 the role's associative memory. This dictionary takes the
+                 following form:
+                 dictionary[event.description] =
+                     {["curr_event"] = <BasicMemory>,
+                      ["events"] = [<BasicMemory>, ...],
+                      ["thoughts"] = [<BasicMemory>, ...]}
+      roles: A dictionary that contains all role names as keys, and the
+             <"STRole"> instance as values.
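+    OUTPUT (editor's note, derived from the body below)
+      False when no reaction is warranted, "chat with {role name}" when a
+      conversation should start, or "wait: {end datetime}" when the role
+      should wait out the target's current action. The nested lets_react
+      helper may also return "do other things", which plan() currently
+      ignores.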
+ """ + + async def lets_talk(init_role: "STRole", target_role: "STRole", retrieved: dict): + if init_role.name == target_role.name: + logger.info(f"Role: {role.name} _should_react lets_talk meet same role, return False") + return False + + scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + if ( + not target_scratch.act_address + or not target_scratch.act_description + or not scratch.act_address + or not scratch.act_description + ): + return False + + if "sleeping" in target_scratch.act_description or "sleeping" in scratch.act_description: + return False + + if scratch.curr_time.hour == 23: + return False + + if "" in target_scratch.act_address: + return False + + if target_scratch.chatting_with or scratch.chatting_with: + return False + + if target_role.name in scratch.chatting_with_buffer: + if scratch.chatting_with_buffer[target_role.name] > 0: + return False + + if await DecideToTalk().run(init_role, target_role, retrieved): + return True + + return False + + async def lets_react(init_role: "STRole", target_role: "STRole", retrieved: dict): + if init_role.name == target_role.name: + logger.info(f"Role: {role.name} _should_react lets_react meet same role, return False") + return False + + scratch = init_role.rc.scratch + target_scratch = target_role.rc.scratch + if ( + not target_scratch.act_address + or not target_scratch.act_description + or not scratch.act_address + or not scratch.act_description + ): + return False + + if "sleeping" in target_scratch.act_description or "sleeping" in scratch.act_description: + return False + + # return False + if scratch.curr_time.hour == 23: + return False + + if "waiting" in target_scratch.act_description: + return False + if scratch.planned_path == []: + return False + + if scratch.act_address != target_scratch.act_address: + return False + + react_mode = await DecideToTalk().run(init_role, target_role, retrieved) + + if react_mode == "1": + wait_until = ( + target_scratch.act_start_time + datetime.timedelta(minutes=target_scratch.act_duration - 1) + ).strftime("%B %d, %Y, %H:%M:%S") + return f"wait: {wait_until}" + elif react_mode == "2": + return False + return "do other things" + else: + return False # "keep" + + # If the role is chatting right now, default to no reaction + scratch = role.rc.scratch + if scratch.chatting_with: + return False + if "" in scratch.act_address: + return False + + # Recall that retrieved takes the following form: + # dictionary {["curr_event"] = } + curr_event = retrieved["curr_event"] + logger.info(f"Role: {role.name} _should_react curr_event.subject: {curr_event.subject}") + + if ":" not in curr_event.subject: + # this is a role event. + if await lets_talk(role, roles[curr_event.subject], retrieved): + return f"chat with {curr_event.subject}" + react_mode = await lets_react(role, roles[curr_event.subject], retrieved) + return react_mode + return False + + +async def _chat_react(role: "STRole", reaction_mode: str, roles: dict["STRole"]): + # There are two roles -- the role who is initiating the conversation + # and the role who is the target. We get the role instances here. + init_role = role + target_role = roles[reaction_mode[9:].strip()] + + # Actually creating the conversation here. 
+ convo, duration_min = await generate_convo(init_role, target_role) # 2222 + convo_summary = await generate_convo_summary(convo) + inserted_act = convo_summary + inserted_act_dur = duration_min + + act_start_time = target_role.rc.scratch.act_start_time + + curr_time = target_role.rc.scratch.curr_time + if curr_time.second != 0: + temp_curr_time = curr_time + datetime.timedelta(seconds=60 - curr_time.second) + chatting_end_time = temp_curr_time + datetime.timedelta(minutes=inserted_act_dur) + else: + chatting_end_time = curr_time + datetime.timedelta(minutes=inserted_act_dur) + + for role, p in [("init", init_role), ("target", target_role)]: + if role == "init": + act_address = f" {target_role.name}" + act_event = (p.name, "chat with", target_role.name) + chatting_with = target_role.name + chatting_with_buffer = {} + chatting_with_buffer[target_role.name] = 800 + elif role == "target": + act_address = f" {init_role.name}" + act_event = (p.name, "chat with", init_role.name) + chatting_with = init_role.name + chatting_with_buffer = {} + chatting_with_buffer[init_role.name] = 800 + + act_pronunciatio = "💬" + act_obj_description = None + act_obj_pronunciatio = None + act_obj_event = (None, None, None) + + await _create_react( + p, + inserted_act, + inserted_act_dur, + act_address, + act_event, + chatting_with, + convo, + chatting_with_buffer, + chatting_end_time, + act_pronunciatio, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time, + ) + + +async def _create_react( + role: "STRole", + inserted_act: str, + inserted_act_dur: int, + act_address: str, + act_event: Tuple, + chatting_with: str, + chat: list, + chatting_with_buffer: dict, + chatting_end_time: datetime, + act_pronunciatio: str, + act_obj_description: str, + act_obj_pronunciatio: str, + act_obj_event: Tuple, + act_start_time=None, +): + p = role + scratch = role.rc.scratch + + min_sum = 0 + for i in range(scratch.get_f_daily_schedule_hourly_org_index()): + min_sum += scratch.f_daily_schedule_hourly_org[i][1] + start_hour = int(min_sum / 60) + + if scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] >= 120: + end_hour = ( + start_hour + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] / 60 + ) + + elif ( + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] + + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index() + 1][1] + ): + end_hour = start_hour + ( + ( + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index()][1] + + scratch.f_daily_schedule_hourly_org[scratch.get_f_daily_schedule_hourly_org_index() + 1][1] + ) + / 60 + ) + + else: + end_hour = start_hour + 2 + end_hour = int(end_hour) + + dur_sum = 0 + count = 0 + start_index = None + end_index = None + for act, dur in scratch.f_daily_schedule: + if dur_sum >= start_hour * 60 and start_index is None: + start_index = count + if dur_sum >= end_hour * 60 and end_index is None: + end_index = count + dur_sum += dur + count += 1 + + ret = await generate_new_decomp_schedule(p, inserted_act, inserted_act_dur, start_hour, end_hour) + scratch.f_daily_schedule[start_index:end_index] = ret + scratch.add_new_action( + act_address, + inserted_act_dur, + inserted_act, + act_pronunciatio, + act_event, + chatting_with, + chat, + chatting_with_buffer, + chatting_end_time, + act_obj_description, + act_obj_pronunciatio, + act_obj_event, + act_start_time, + ) + + +async def _wait_react(role: "STRole", 
+async def _wait_react(role: "STRole", reaction_mode: str):
+    scratch = role.rc.scratch
+
+    inserted_act = f'waiting to start {scratch.act_description.split("(")[-1][:-1]}'
+    # reaction_mode is 'wait: <timestamp>': [6:] drops the "wait: " prefix.
+    end_time = datetime.datetime.strptime(reaction_mode[6:].strip(), "%B %d, %Y, %H:%M:%S")
+    inserted_act_dur = (
+        (end_time.minute + end_time.hour * 60) - (scratch.curr_time.minute + scratch.curr_time.hour * 60) + 1
+    )
+
+    act_address = f"<waiting> {scratch.curr_tile[0]} {scratch.curr_tile[1]}"
+    act_event = (role.name, "waiting to start", scratch.act_description.split("(")[-1][:-1])
+    chatting_with = None
+    chat = None
+    chatting_with_buffer = None
+    chatting_end_time = None
+
+    act_pronunciatio = "⌛"
+    act_obj_description = None
+    act_obj_pronunciatio = None
+    act_obj_event = (None, None, None)
+
+    await _create_react(
+        role,
+        inserted_act,
+        inserted_act_dur,
+        act_address,
+        act_event,
+        chatting_with,
+        chat,
+        chatting_with_buffer,
+        chatting_end_time,
+        act_pronunciatio,
+        act_obj_description,
+        act_obj_pronunciatio,
+        act_obj_event,
+    )
+
+
+async def generate_convo(init_role: "STRole", target_role: "STRole") -> Tuple[list, int]:
+    convo = await agent_conversation(init_role, target_role)
+    all_utt = ""
+
+    for row in convo:
+        speaker = row[0]
+        utt = row[1]
+        all_utt += f"{speaker}: {utt}\n"
+
+    # Rough duration estimate: ~8 characters per word and ~30 words per minute.
+    convo_length = math.ceil(int(len(all_utt) / 8) / 30)
+
+    return convo, convo_length
+
+
+async def generate_convo_summary(conv: list[list[str]]) -> str:
+    conv_summary = await SummarizeConv().run(conv)
+    return conv_summary
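+# Arithmetic sketch for generate_convo (editor's example with an assumed length):
+# a 960-character transcript gives int(960 / 8) == 120 estimated words, and
+# math.ceil(120 / 30) == 4, i.e. a 4-minute conversation.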
+async def generate_new_decomp_schedule(
+    role: "STRole", inserted_act: str, inserted_act_dur: int, start_hour: int, end_hour: int
+):
+    # Step 1: Setting up the core variables for the function.
+    # <p> is the role whose schedule we are editing right now.
+    scratch = role.rc.scratch
+    # <today_min_pass> indicates the number of minutes that have passed today.
+    today_min_pass = int(scratch.curr_time.hour) * 60 + int(scratch.curr_time.minute) + 1
+
+    # Step 2: We need to create <main_act_dur> and <truncated_act_dur>.
+    main_act_dur = []
+    truncated_act_dur = []
+    dur_sum = 0  # duration sum
+    count = 0  # enumerate count
+    truncated_fin = False
+
+    logger.debug(f"DEBUG::: {scratch.name}")
+    for act, dur in scratch.f_daily_schedule:
+        if (dur_sum >= start_hour * 60) and (dur_sum < end_hour * 60):
+            main_act_dur += [[act, dur]]
+            if dur_sum <= today_min_pass:
+                truncated_act_dur += [[act, dur]]
+            elif dur_sum > today_min_pass and not truncated_fin:
+                # We need to insert that last act, duration list like this one:
+                # e.g., ['wakes up and completes her morning routine (wakes up...)', 2]
+                truncated_act_dur += [[scratch.f_daily_schedule[count][0], dur_sum - today_min_pass]]
+                truncated_act_dur[-1][-1] -= (
+                    dur_sum - today_min_pass
+                )  # DEC 7 DEBUG;.. is the +1 the right thing to do???
+                # Alternative kept from upstream while the off-by-one stays unresolved:
+                # truncated_act_dur[-1][-1] -= (dur_sum - today_min_pass + 1)
+                logger.debug(f"DEBUG::: {truncated_act_dur}")
+                truncated_fin = True
+        dur_sum += dur
+        count += 1
+
+    x = (
+        truncated_act_dur[-1][0].split("(")[0].strip()
+        + " (on the way to "
+        + truncated_act_dur[-1][0].split("(")[-1][:-1]
+        + ")"
+    )
+    truncated_act_dur[-1][0] = x
+
+    if "(" in truncated_act_dur[-1][0]:
+        inserted_act = truncated_act_dur[-1][0].split("(")[0].strip() + " (" + inserted_act + ")"
+
+    # To do inserted_act_dur+1 below is an important decision but I'm not sure
+    # if I understand the full extent of its implications. Might want to
+    # revisit.
+    truncated_act_dur += [[inserted_act, inserted_act_dur]]
+    start_time_hour = datetime.datetime(2022, 10, 31, 0, 0) + datetime.timedelta(hours=start_hour)
+    end_time_hour = datetime.datetime(2022, 10, 31, 0, 0) + datetime.timedelta(hours=end_hour)
+
+    return await NewDecompSchedule().run(
+        role, main_act_dur, truncated_act_dur, start_time_hour, end_time_hour, inserted_act, inserted_act_dur
+    )
+
+
+async def _long_term_planning(role: "STRole", new_day: Union[str, bool]):
+    """
+    Formulates the role's daily long-term plan if it is the start of a new
+    day. This basically has two components: first, we create the wake-up hour,
+    and second, we create the hourly schedule based on it.
+    INPUT
+        new_day: Indicates whether the current time signals a "First day",
+            "New day", or False (for neither). This is important because we
+            create the roles' long-term planning on the new day.
+    """
+    # We start by creating the wake-up hour for the role.
+    wake_up_hour = await WakeUp().run(role)
+    wake_up_hour = int(wake_up_hour)
+    logger.info(f"Role: {role.name} long_term_planning, wake_up_hour: {wake_up_hour}")
+
+    # When it is a new day, we start by creating the daily_req of the role.
+    # Note that the daily_req is a list of strings that describe the role's
+    # day in broad strokes.
+    if new_day == "First day":
+        # Bootstrapping the daily plan for the start of the generation:
+        # if this is the start of generation (so there is no previous day's
+        # daily requirement), or if we are on a new day, we want to create a
+        # new set of daily requirements.
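+        # Assumed shape (editor's illustration): daily_req is a short list of
+        # broad-stroke strings, e.g.
+        #   ["wake up and complete the morning routine at 7:00 am",
+        #    "work on the painting project from 9:00 am to noon"]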
+        role.scratch.daily_req = await GenDailySchedule().run(role, wake_up_hour)
+        logger.info(f"Role: {role.name} daily requirements: {role.scratch.daily_req}")
+    elif new_day == "New day":
+        revise_identity(role)
+
+        # TODO: We need to create a new daily_req here; for now, yesterday's
+        # requirements are carried over unchanged.
+        role.scratch.daily_req = role.scratch.daily_req
+
+    # Based on the daily_req, we create an hourly schedule for the role,
+    # which is a list of todo items with a time duration (in minutes) that
+    # add up to 24 hours.
+    role.scratch.f_daily_schedule = await GenHourlySchedule().run(role, wake_up_hour)
+    logger.info(f"Role: {role.name} f_daily_schedule: {role.scratch.f_daily_schedule}")
+    role.scratch.f_daily_schedule_hourly_org = role.scratch.f_daily_schedule[:]
+
+    # Added March 4 -- adding plan to the memory.
+    thought = f"This is {role.scratch.name}'s plan for {role.scratch.curr_time.strftime('%A %B %d')}:"
+    for i in role.scratch.daily_req:
+        thought += f" {i},"
+    thought = thought[:-1] + "."
+    created = role.scratch.curr_time
+    expiration = role.scratch.curr_time + datetime.timedelta(days=30)
+    s, p, o = (role.scratch.name, "plan", role.scratch.curr_time.strftime("%A %B %d"))
+    keywords = set(["plan"])
+    thought_poignancy = 5
+    thought_embedding_pair = (thought, get_embedding(thought))
+    role.a_mem.add_thought(
+        created, expiration, s, p, o, thought, keywords, thought_poignancy, thought_embedding_pair, None
+    )
+
+
+async def _determine_action(role: "STRole"):
+    """
+    Creates the next action sequence for the role.
+    The main goal of this function is to run "add_new_action" on the role's
+    scratch space, which sets up all the action-related variables for the next
+    action.
+    As a part of this, the role may need to decompose its hourly schedule as
+    needed.
+    INPUT
+        role: Current instance whose action we are determining.
+    """
+
+    def determine_decomp(act_desp, act_dura):
+        """
+        Given an action description and its duration, we determine whether we need
+        to decompose it. If the action is about the agent sleeping, we generally
+        do not want to decompose it, so that's what we catch here.
+
+        INPUT:
+            act_desp: the description of the action (e.g., "sleeping")
+            act_dura: the duration of the action in minutes.
+        OUTPUT:
+            a boolean. True if we need to decompose, False otherwise.
+        """
+        if "sleep" not in act_desp and "bed" not in act_desp:
+            return True
+        elif "sleeping" in act_desp or "asleep" in act_desp or "in bed" in act_desp:
+            return False
+        elif "sleep" in act_desp or "bed" in act_desp:
+            if act_dura > 60:
+                return False
+        return True
+
+    # The goal of this function is to get us the action associated with
+    # <curr_index>. As a part of this, we may need to decompose some large
+    # chunk actions.
+    # Importantly, we try to decompose at least two hours' worth of schedule at
+    # any given point.
+    curr_index = role.scratch.get_f_daily_schedule_index()
+    curr_index_60 = role.scratch.get_f_daily_schedule_index(advance=60)
+
+    logger.info(f"f_daily_schedule: {role.scratch.f_daily_schedule}")
+    # * Decompose *
+    # During the first hour of the day, we need to decompose a two-hour
+    # sequence. We do that here.
+    if curr_index == 0:
+        # This portion is invoked if it is the first hour of the day.
+        act_desp, act_dura = role.scratch.f_daily_schedule[curr_index]
+        if act_dura >= 60:
+            # We decompose if the next action is longer than an hour, and fits the
+            # criteria described in determine_decomp, as sketched below.
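+            # Decomposition sketch (editor's example with hypothetical items):
+            # ["working on her painting", 180] may be replaced in place by
+            # [["sketching the composition", 60], ["mixing colors", 30],
+            #  ["painting on canvas", 90]] via TaskDecomp below.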
+            if determine_decomp(act_desp, act_dura):
+                role.scratch.f_daily_schedule[curr_index : curr_index + 1] = await TaskDecomp().run(
+                    role, act_desp, act_dura
+                )
+        if curr_index_60 + 1 < len(role.scratch.f_daily_schedule):
+            act_desp, act_dura = role.scratch.f_daily_schedule[curr_index_60 + 1]
+            if act_dura >= 60:
+                if determine_decomp(act_desp, act_dura):
+                    role.scratch.f_daily_schedule[curr_index_60 + 1 : curr_index_60 + 2] = await TaskDecomp().run(
+                        role, act_desp, act_dura
+                    )
+
+    if curr_index_60 < len(role.scratch.f_daily_schedule):
+        # If it is not the first hour of the day, this is always invoked (it is
+        # also invoked during the first hour of the day -- to double up so we can
+        # decompose two hours in one go). Of course, we need to have something to
+        # decompose as well, so we check for that too.
+        if role.scratch.curr_time.hour < 23:
+            # And we don't want to decompose after 11 pm.
+            act_desp, act_dura = role.scratch.f_daily_schedule[curr_index_60]
+            if act_dura >= 60:
+                if determine_decomp(act_desp, act_dura):
+                    role.scratch.f_daily_schedule[curr_index_60 : curr_index_60 + 1] = await TaskDecomp().run(
+                        role, act_desp, act_dura
+                    )
+    # * End of Decompose *
+
+    # Generate an <Action> instance from the action description and duration. By
+    # this point, we assume that all the relevant actions are decomposed and
+    # ready in f_daily_schedule.
+    logger.debug("Decomposed f_daily_schedule:")
+    for i in role.scratch.f_daily_schedule:
+        logger.debug(i)
+    logger.debug(curr_index)
+    logger.debug(len(role.scratch.f_daily_schedule))
+    logger.debug(role.scratch.name)
+
+    # Pad the schedule so the durations add up to a full day (1440 minutes).
+    x_emergency = 0
+    for i in role.scratch.f_daily_schedule:
+        x_emergency += i[1]
+
+    if 1440 - x_emergency > 0:
+        logger.info(f"Schedule covers only {x_emergency} minutes; padding the rest with sleep.")
+        role.scratch.f_daily_schedule += [["sleeping", 1440 - x_emergency]]
+
+    act_desp, act_dura = role.scratch.f_daily_schedule[curr_index]
+
+    new_action_details = await GenActionDetails().run(role, act_desp, act_dura)
+    # Adding the action to the role's queue.
+    role.scratch.add_new_action(**new_action_details)
+
+
+def revise_identity(role: "STRole"):
+    p_name = role.scratch.name
+
+    focal_points = [
+        f"{p_name}'s plan for {role.scratch.get_str_curr_date_str()}.",
+        f"Important recent events for {p_name}'s life.",
+    ]
+    retrieved = new_agent_retrieve(role, focal_points)
+
+    statements = "[Statements]\n"
+    for key, val in retrieved.items():
+        for i in val:
+            statements += f"{i.created.strftime('%A %B %d -- %H:%M %p')}: {i.embedding_key}\n"
+
+    plan_prompt = statements + "\n"
+    plan_prompt += f"Given the statements above, is there anything that {p_name} should remember as they plan for"
+    plan_prompt += f" *{role.scratch.curr_time.strftime('%A %B %d')}*? "
+    plan_prompt += "If there is any scheduling information, be as specific as possible (include date, time, and location if stated in the statement).\n\n"
+    plan_prompt += f"Write the response from {p_name}'s perspective."
+    plan_note = LLM().ask(plan_prompt)
+
+    thought_prompt = statements + "\n"
+    thought_prompt += (
+        f"Given the statements above, how might we summarize {p_name}'s feelings about their days up to now?\n\n"
+    )
+    thought_prompt += f"Write the response from {p_name}'s perspective."
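+    # Assumed behavior (editor's note): plan_note above and thought_note below
+    # come back as short free-text paragraphs; they are concatenated and folded
+    # into the refreshed "currently" status further down.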
+    thought_note = LLM().ask(thought_prompt)
+
+    currently_prompt = (
+        f"{p_name}'s status from {(role.scratch.curr_time - datetime.timedelta(days=1)).strftime('%A %B %d')}:\n"
+    )
+    currently_prompt += f"{role.scratch.currently}\n\n"
+    currently_prompt += f"{p_name}'s thoughts at the end of {(role.scratch.curr_time - datetime.timedelta(days=1)).strftime('%A %B %d')}:\n"
+    currently_prompt += (plan_note + thought_note).replace("\n", "") + "\n\n"
+    currently_prompt += f"It is now {role.scratch.curr_time.strftime('%A %B %d')}. Given the above, write {p_name}'s status for {role.scratch.curr_time.strftime('%A %B %d')} that reflects {p_name}'s thoughts at the end of {(role.scratch.curr_time - datetime.timedelta(days=1)).strftime('%A %B %d')}. Write this in third person, talking about {p_name}."
+    currently_prompt += " If there is any scheduling information, be as specific as possible (include date, time, and location if stated in the statement).\n\n"
+    currently_prompt += "Follow this format below:\nStatus: <new status>"
+    new_currently = LLM().ask(currently_prompt)
+
+    role.scratch.currently = new_currently
+
+    daily_req_prompt = role.scratch.get_str_iss() + "\n"
+    daily_req_prompt += f"Today is {role.scratch.curr_time.strftime('%A %B %d')}. Here is {role.scratch.name}'s plan for today in broad strokes (with the time of day, e.g., have lunch at 12:00 pm, watch TV from 7 to 8 pm).\n\n"
+    daily_req_prompt += "Follow this format (the list should have 4~6 items but no more):\n"
+    daily_req_prompt += "1. wake up and complete the morning routine at