Upload IsoPro Package
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +201 -0
- README.md +276 -3
- isopro/.DS_Store +0 -0
- isopro/__init__.py +84 -0
- isopro/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__init__.py +18 -0
- isopro/adversarial_simulation/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_agent.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_environment.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_envrionment.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/adversarial_simulator.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/__pycache__/attack_utils.cpython-38.pyc +0 -0
- isopro/adversarial_simulation/adversarial_agent.py +51 -0
- isopro/adversarial_simulation/adversarial_environment.py +81 -0
- isopro/adversarial_simulation/adversarial_simulator.py +47 -0
- isopro/adversarial_simulation/attack_utils.py +65 -0
- isopro/adversarial_simulation/main.py +124 -0
- isopro/agents/__init__.py +7 -0
- isopro/agents/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/agents/__pycache__/ai_agent.cpython-38.pyc +0 -0
- isopro/agents/ai_agent.py +44 -0
- isopro/base/__init__.py +8 -0
- isopro/base/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/base/__pycache__/base_component.cpython-38.pyc +0 -0
- isopro/base/__pycache__/base_wrapper.cpython-38.pyc +0 -0
- isopro/base/base_component.py +34 -0
- isopro/base/base_wrapper.py +82 -0
- isopro/car_simulator/__init__.py +12 -0
- isopro/car_simulator/car_llm_agent.py +143 -0
- isopro/car_simulator/car_rl_environment.py +155 -0
- isopro/car_simulator/car_rl_model.zip +3 -0
- isopro/car_simulator/car_rl_training.py +38 -0
- isopro/car_simulator/carviz.py +227 -0
- isopro/car_simulator/llm_main.py +74 -0
- isopro/car_simulator/main.py +48 -0
- isopro/conversation_simulation/README.md +252 -0
- isopro/conversation_simulation/__init__.py +19 -0
- isopro/conversation_simulation/conversation_agent.py +41 -0
- isopro/conversation_simulation/conversation_environment.py +78 -0
- isopro/conversation_simulation/conversation_simulator.py +67 -0
- isopro/conversation_simulation/custom_persona.py +58 -0
- isopro/conversation_simulation/main.py +117 -0
- isopro/conversation_simulation/user_personas.py +112 -0
- isopro/environments/__init__.py +9 -0
- isopro/environments/__pycache__/__init__.cpython-38.pyc +0 -0
- isopro/environments/__pycache__/custom_environment.cpython-38.pyc +0 -0
- isopro/environments/__pycache__/llm_orchestrator.cpython-38.pyc +0 -0
- isopro/environments/__pycache__/simulation_environment.cpython-38.pyc +0 -0
- isopro/environments/custom_environment.py +108 -0
- isopro/environments/llm_orchestrator.py +194 -0
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,276 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ISOPro: Pro Tools for Intelligent Simulation Orchestration for Large Language Models
|
| 2 |
+
|
| 3 |
+
ISOPRO is a powerful and flexible Python package designed for creating, managing, and analyzing simulations involving Large Language Models (LLMs). It provides a comprehensive suite of tools for reinforcement learning, conversation simulations, adversarial testing, custom environment creation, and advanced orchestration of multi-agent systems.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Custom Environment Creation**: Easily create and manage custom simulation environments for LLMs
|
| 8 |
+
- **Conversation Simulation**: Simulate and analyze conversations with AI agents using various user personas
|
| 9 |
+
- **Adversarial Testing**: Conduct adversarial simulations to test the robustness of LLM-based systems
|
| 10 |
+
- **Reinforcement Learning**: Implement and experiment with RL algorithms in LLM contexts
|
| 11 |
+
- **Workflow Automation**: Learn and replicate UI workflows from video demonstrations
|
| 12 |
+
- **Car Environment Simulation**: Train and evaluate RL agents in driving scenarios
|
| 13 |
+
- **Utility Functions**: Analyze simulation results, calculate LLM metrics, and more
|
| 14 |
+
- **Flexible Integration**: Works with popular LLM platforms like OpenAI's GPT models, Claude (Anthropic), and Hugging Face models
|
| 15 |
+
- **Orchestration Simulation**: Manage and execute complex multi-agent simulations with different execution modes
|
| 16 |
+
|
| 17 |
+
## Installation
|
| 18 |
+
|
| 19 |
+
You can install isopro using pip:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install isopro
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
For workflow simulation features, ensure you have the required dependencies:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
pip install opencv-python numpy torch stable-baselines3 gymnasium tqdm
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
If you plan to use Claude capabilities:
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
export ANTHROPIC_API_KEY=your_api_key_here
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
## Usage
|
| 38 |
+
|
| 39 |
+
### Adversarial Simulation
|
| 40 |
+
|
| 41 |
+
Test the robustness of AI models against adversarial attacks.
|
| 42 |
+
|
| 43 |
+
```python
|
| 44 |
+
from isopro.adversarial_simulation import AdversarialSimulator, AdversarialEnvironment
|
| 45 |
+
from isopro.agents.ai_agent import AI_Agent
|
| 46 |
+
import os

import anthropic
|
| 47 |
+
|
| 48 |
+
class ClaudeAgent(AI_Agent):
|
| 49 |
+
def __init__(self, name):
|
| 50 |
+
super().__init__(name)
|
| 51 |
+
self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
|
| 52 |
+
|
| 53 |
+
def run(self, input_data):
|
| 54 |
+
response = self.client.messages.create(
|
| 55 |
+
model="claude-3-opus-20240229",
|
| 56 |
+
max_tokens=100,
|
| 57 |
+
messages=[{"role": "user", "content": input_data['text']}]
|
| 58 |
+
)
|
| 59 |
+
return response.content[0].text
|
| 60 |
+
|
| 61 |
+
# Create the AdversarialEnvironment
|
| 62 |
+
adv_env = AdversarialEnvironment(
|
| 63 |
+
agent_wrapper=ClaudeAgent("Claude Agent"),
|
| 64 |
+
num_adversarial_agents=2,
|
| 65 |
+
attack_types=["textbugger", "deepwordbug"],
|
| 66 |
+
attack_targets=["input", "output"]
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
# Set up the adversarial simulator
|
| 70 |
+
simulator = AdversarialSimulator(adv_env)
|
| 71 |
+
|
| 72 |
+
# Run the simulation
|
| 73 |
+
input_data = ["What is the capital of France?", "How does photosynthesis work?"]
|
| 74 |
+
simulation_results = simulator.run_simulation(input_data, num_steps=1)
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Conversation Simulation
|
| 78 |
+
|
| 79 |
+
Simulate conversations between an AI assistant and various user personas.
|
| 80 |
+
|
| 81 |
+
```python
|
| 82 |
+
from isopro.conversation_simulation.conversation_simulator import ConversationSimulator
|
| 83 |
+
|
| 84 |
+
# Initialize the ConversationSimulator
|
| 85 |
+
simulator = ConversationSimulator(
|
| 86 |
+
ai_prompt="You are an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally."
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Run a simulation with a predefined persona
|
| 90 |
+
conversation_history = simulator.run_simulation("upset", num_turns=3)
|
| 91 |
+
|
| 92 |
+
# Run a simulation with a custom persona
|
| 93 |
+
custom_persona = {
|
| 94 |
+
"name": "Techie Customer",
|
| 95 |
+
"characteristics": ["tech-savvy", "impatient", "detail-oriented"],
|
| 96 |
+
"message_templates": [
|
| 97 |
+
"I've tried rebooting my device, but the error persists. Can you help?",
|
| 98 |
+
"What's the latest update on the cloud service outage?",
|
| 99 |
+
"I need specifics on the API rate limits for the enterprise plan."
|
| 100 |
+
]
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
custom_conversation = simulator.run_custom_simulation(**custom_persona, num_turns=3)
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
### Workflow Simulation
|
| 107 |
+
|
| 108 |
+
Automate UI workflows by learning from video demonstrations.
|
| 109 |
+
|
| 110 |
+
```python
|
| 111 |
+
from isopro.workflow_simulation import WorkflowAutomation, WorkflowSimulator, AgentConfig
|
| 112 |
+
|
| 113 |
+
# Basic workflow automation
|
| 114 |
+
automation = WorkflowAutomation(
|
| 115 |
+
video="path/to/workflow.mp4",
|
| 116 |
+
config="config.json",
|
| 117 |
+
output="output_dir",
|
| 118 |
+
logs="logs_dir"
|
| 119 |
+
)
|
| 120 |
+
automation.run()
|
| 121 |
+
|
| 122 |
+
# Advanced configuration
|
| 123 |
+
agent_config = AgentConfig(
|
| 124 |
+
learning_rate=3e-4,
|
| 125 |
+
pretrain_epochs=10,
|
| 126 |
+
use_demonstration=True,
|
| 127 |
+
use_reasoning=True
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
simulator = WorkflowSimulator(
|
| 131 |
+
video_path="path/to/video.mp4",
|
| 132 |
+
agent_config=agent_config,
|
| 133 |
+
viz_config=visualization_config,
|
| 134 |
+
validation_config=validation_config,
|
| 135 |
+
output_dir="output"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
training_results = simulator.train_agents()
|
| 139 |
+
evaluation_results = simulator.evaluate_agents()
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### Car Reinforcement Learning
|
| 143 |
+
|
| 144 |
+
Train and evaluate RL agents in driving scenarios.
|
| 145 |
+
|
| 146 |
+
```python
|
| 147 |
+
from isopro.car_simulation import CarRLEnvironment, LLMCarRLWrapper, CarVisualization
|
| 148 |
+
|
| 149 |
+
# Create the car environment with LLM integration
|
| 150 |
+
env = CarRLEnvironment()
|
| 151 |
+
llm_env = LLMCarRLWrapper(env)
|
| 152 |
+
|
| 153 |
+
# Initialize visualization
|
| 154 |
+
viz = CarVisualization(env)
|
| 155 |
+
|
| 156 |
+
# Train and visualize
|
| 157 |
+
observation = llm_env.reset()
|
| 158 |
+
for step in range(1000):
|
| 159 |
+
action = llm_env.get_action(observation)
|
| 160 |
+
observation, reward, done, info = llm_env.step(action)
|
| 161 |
+
viz.render(observation)
|
| 162 |
+
|
| 163 |
+
if done:
|
| 164 |
+
observation = llm_env.reset()
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
### Reinforcement Learning with LLM
|
| 168 |
+
|
| 169 |
+
Integrate Large Language Models with reinforcement learning environments.
|
| 170 |
+
|
| 171 |
+
```python
|
| 172 |
+
import gymnasium as gym
|
| 173 |
+
from isopro.rl.rl_agent import RLAgent
|
| 174 |
+
from isopro.rl.rl_environment import LLMRLEnvironment
|
| 175 |
+
from stable_baselines3 import PPO
|
| 176 |
+
from isopro.rl.llm_cartpole_wrapper import LLMCartPoleWrapper
|
| 177 |
+
|
| 178 |
+
agent_prompt = """You are an AI trained to play the CartPole game.
|
| 179 |
+
Your goal is to balance a pole on a moving cart for as long as possible.
|
| 180 |
+
You will receive observations about the cart's position, velocity, pole angle, and angular velocity.
|
| 181 |
+
Based on these, you should decide whether to move the cart left or right."""
|
| 182 |
+
|
| 183 |
+
env = LLMCartPoleWrapper(agent_prompt, llm_call_limit=100, api_key=os.getenv("ANTHROPIC_API_KEY"))
|
| 184 |
+
model = RLAgent("LLM_CartPole_Agent", env, algorithm='PPO')
|
| 185 |
+
|
| 186 |
+
# Train the model
|
| 187 |
+
model.learn(total_timesteps=2)
|
| 188 |
+
|
| 189 |
+
# Test the model
|
| 190 |
+
obs, _ = env.reset()
|
| 191 |
+
for _ in range(1000):
|
| 192 |
+
action, _ = model.predict(obs, deterministic=True)
|
| 193 |
+
obs, reward, done, _, _ = env.step(action)
|
| 194 |
+
if done:
|
| 195 |
+
obs, _ = env.reset()
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
### AI Orchestration
|
| 199 |
+
|
| 200 |
+
Orchestrate multiple AI agents to work together on complex tasks.
|
| 201 |
+
|
| 202 |
+
```python
|
| 203 |
+
from isopro.orchestration_simulation import OrchestrationEnv
|
| 204 |
+
from isopro.orchestration_simulation.components import LLaMAAgent, AnalysisAgent, WritingAgent
|
| 205 |
+
from isopro.orchestration_simulation.evaluator import Evaluator
|
| 206 |
+
|
| 207 |
+
# Create the orchestration environment
|
| 208 |
+
env = OrchestrationEnv()
|
| 209 |
+
|
| 210 |
+
# Add agents to the environment
|
| 211 |
+
env.add_component(LLaMAAgent("Research", "conduct thorough research on the impact of artificial intelligence on job markets"))
|
| 212 |
+
env.add_component(AnalysisAgent("Analysis"))
|
| 213 |
+
env.add_component(WritingAgent("Writing"))
|
| 214 |
+
|
| 215 |
+
# Define the task
|
| 216 |
+
task = "Prepare a comprehensive report on the impact of artificial intelligence on job markets in the next decade."
|
| 217 |
+
|
| 218 |
+
# Run simulations in different modes
|
| 219 |
+
modes = ['parallel', 'sequence', 'node']
|
| 220 |
+
results = {}
|
| 221 |
+
|
| 222 |
+
for mode in modes:
|
| 223 |
+
result = env.run_simulation(mode=mode, input_data={'task': task, 'run_order': 'first'})
|
| 224 |
+
results[mode] = result
|
| 225 |
+
|
| 226 |
+
# Evaluate the results
|
| 227 |
+
evaluator = Evaluator()
|
| 228 |
+
best_mode = evaluator.evaluate(results)
|
| 229 |
+
print(f"The best execution mode for this task was: {best_mode}")
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
## Documentation
|
| 233 |
+
|
| 234 |
+
For more detailed information on each module and its usage, please refer to the [full documentation](https://isopro.readthedocs.io).
|
| 235 |
+
|
| 236 |
+
## Examples
|
| 237 |
+
|
| 238 |
+
The [isopro examples](https://github.com/iso-ai/isopro_examples) repository contains Jupyter notebooks with detailed examples:
|
| 239 |
+
|
| 240 |
+
- `adversarial_example.ipynb`: Demonstrates adversarial testing of language models
|
| 241 |
+
- `conversation_simulation_example.ipynb`: Shows how to simulate conversations with various user personas
|
| 242 |
+
- `workflow_automation_example.ipynb`: Illustrates automated UI workflow learning
|
| 243 |
+
- `car_rl_example.ipynb`: Demonstrates car environment training scenarios
|
| 244 |
+
- `run_cartpole_example.ipynb`: Illustrates the integration of LLMs with reinforcement learning
|
| 245 |
+
- `orchestrator_example.ipynb`: Provides a tutorial on using the AI orchestration capabilities
|
| 246 |
+
|
| 247 |
+
## Contributing
|
| 248 |
+
|
| 249 |
+
We welcome contributions! Please see our [Contributing Guide](CONTRIBUTING.md) for more details.
|
| 250 |
+
|
| 251 |
+
## License
|
| 252 |
+
|
| 253 |
+
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
|
| 254 |
+
|
| 255 |
+
## Support
|
| 256 |
+
|
| 257 |
+
If you encounter any problems or have any questions, please [open an issue](https://github.com/iso-ai/isopro/issues) on our GitHub repository.
|
| 258 |
+
|
| 259 |
+
## Citation
|
| 260 |
+
|
| 261 |
+
If you use ISOPRO in your research, please cite it as follows:
|
| 262 |
+
|
| 263 |
+
```
|
| 264 |
+
@software{isopro2024,
|
| 265 |
+
author = {Jazmia Henry},
|
| 266 |
+
title = {ISOPRO: Intelligent Simulation Orchestration for Large Language Models},
|
| 267 |
+
year = {2024},
|
| 268 |
+
publisher = {GitHub},
|
| 269 |
+
journal = {GitHub repository},
|
| 270 |
+
howpublished = {\url{https://github.com/iso-ai/isopro}}
|
| 271 |
+
}
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
## Contact
|
| 275 |
+
|
| 276 |
+
For questions or support, please open an issue on our [GitHub issue tracker](https://github.com/iso-ai/isopro/issues).
|
isopro/.DS_Store
ADDED
|
Binary file (8.2 kB). View file
|
|
|
isopro/__init__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# isopro/__init__.py

"""
isopro: Intelligent Simulation Orchestration for LLMs

This package provides tools for creating, managing, and analyzing simulations
involving Large Language Models (LLMs), including reinforcement learning,
conversation simulations, and adversarial testing.
"""

__version__ = "0.1.5"

# Core components
from .environments.simulation_environment import SimulationEnvironment
from .environments.custom_environment import CustomEnvironment
from .environments.llm_orchestrator import LLMOrchestrator
from .agents.ai_agent import AI_Agent
from .base.base_component import BaseComponent
from .wrappers.simulation_wrapper import SimulationWrapper
from .rl.rl_environment import BaseRLEnvironment
from .rl.rl_agent import RLAgent
from .conversation_simulation import ConversationSimulator, ConversationEnvironment, ConversationAgent
from .adversarial_simulation import AdversarialSimulator, AdversarialEnvironment, AdversarialAgent
# NOTE(review): this import also binds a name "AI_Agent", shadowing the
# AI_Agent imported from .agents.ai_agent above -- isopro.AI_Agent resolves to
# the orchestration_simulation class.  Confirm this shadowing is intentional.
from .orchestration_simulation import LLaMAAgent, SubAgent, OrchestrationEnv, AI_AgentException, ComponentException, AI_Agent

# Workflow simulation components
from .workflow_simulation import (
    WorkflowSimulator,
    WorkflowEnvironment,
    WorkflowState,
    UIElement,
    UIElementDetector,
    MotionDetector,
    EpisodeMetrics,
    AgentConfig,
    VisualizationConfig,
    ValidationConfig,
    WorkflowAutomation
)

# Car RL components
from .car_simulator import CarRLEnvironment, LLMCarRLWrapper, CarVisualization

__all__ = [
    # Core components
    "LLaMAAgent",
    "SubAgent",
    "OrchestrationEnv",
    "AI_AgentException",
    "ComponentException",
    # Bug fix: "AI_Agent" was listed twice; duplicates in __all__ are
    # harmless at runtime but confuse tooling and readers.
    "AI_Agent",
    "SimulationEnvironment",
    "CustomEnvironment",
    "LLMOrchestrator",
    "BaseComponent",
    "SimulationWrapper",
    "BaseRLEnvironment",
    "RLAgent",
    "ConversationSimulator",
    "ConversationEnvironment",
    "ConversationAgent",
    "AdversarialSimulator",
    "AdversarialEnvironment",
    "AdversarialAgent",

    # Workflow components
    "WorkflowSimulator",
    "WorkflowEnvironment",
    "WorkflowState",
    "UIElement",
    "UIElementDetector",
    "MotionDetector",
    "EpisodeMetrics",
    "AgentConfig",
    "VisualizationConfig",
    "ValidationConfig",
    "WorkflowAutomation",

    # Car RL components
    "CarRLEnvironment",
    "LLMCarRLWrapper",
    "CarVisualization"
]
|
isopro/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (1.56 kB). View file
|
|
|
isopro/adversarial_simulation/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adversarial Simulation Module
|
| 3 |
+
|
| 4 |
+
This module provides tools for simulating adversarial attacks on AI models.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .adversarial_environment import AdversarialEnvironment
|
| 8 |
+
from .adversarial_agent import AdversarialAgent
|
| 9 |
+
from .adversarial_simulator import AdversarialSimulator
|
| 10 |
+
from .attack_utils import get_available_attacks, create_attack
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"AdversarialEnvironment",
|
| 14 |
+
"AdversarialAgent",
|
| 15 |
+
"AdversarialSimulator",
|
| 16 |
+
"get_available_attacks",
|
| 17 |
+
"create_attack",
|
| 18 |
+
]
|
isopro/adversarial_simulation/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (621 Bytes). View file
|
|
|
isopro/adversarial_simulation/__pycache__/adversarial_agent.cpython-38.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
isopro/adversarial_simulation/__pycache__/adversarial_environment.cpython-38.pyc
ADDED
|
Binary file (4.88 kB). View file
|
|
|
isopro/adversarial_simulation/__pycache__/adversarial_envrionment.cpython-38.pyc
ADDED
|
Binary file (4.88 kB). View file
|
|
|
isopro/adversarial_simulation/__pycache__/adversarial_simulator.cpython-38.pyc
ADDED
|
Binary file (2.48 kB). View file
|
|
|
isopro/adversarial_simulation/__pycache__/attack_utils.cpython-38.pyc
ADDED
|
Binary file (2.85 kB). View file
|
|
|
isopro/adversarial_simulation/adversarial_agent.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adversarial Agent
|
| 3 |
+
|
| 4 |
+
This module defines the AdversarialAgent class, which can apply various attacks to input or output text.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
from isopro.agents.ai_agent import AI_Agent
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class AdversarialAgent(AI_Agent):
    """Agent that perturbs one field of the simulation state.

    Depending on ``target``, the configured attack callable is applied either
    to the state's ``'text'`` (input) or ``'output'`` value.
    """

    def __init__(self, name: str, attack, target: str = "input"):
        """
        Initialize the AdversarialAgent.

        Args:
            name (str): The name of the agent.
            attack (callable): The attack function to apply.
            target (str): The target of the attack, either "input" or "output".
        """
        super().__init__(name)
        self.attack = attack
        self.target = target
        logger.info(f"Initialized AdversarialAgent '{name}' targeting {target}")

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply the adversarial attack to the input or output data in place.

        Args:
            input_data (Dict[str, Any]): State containing 'text' and 'output' keys.

        Returns:
            Dict[str, Any]: The (possibly perturbed) state dictionary.

        Raises:
            ValueError: If ``self.target`` is neither "input" nor "output".
        """
        logger.info(f"Running adversarial agent: {self.name}")
        if self.target not in ("input", "output"):
            raise ValueError(f"Invalid target: {self.target}")

        # "input" attacks perturb the prompt text; "output" attacks perturb
        # the model's response.  Empty/missing values are skipped with a warning.
        key = "text" if self.target == "input" else "output"
        value = input_data.get(key)
        if value:
            input_data[key] = self.attack(value)
        elif self.target == "input":
            logger.warning("Input text is empty or missing. Skipping attack.")
        else:
            logger.warning("Output text is empty or missing. Skipping attack.")
        return input_data
|
isopro/adversarial_simulation/adversarial_environment.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adversarial Environment
|
| 3 |
+
|
| 4 |
+
This module defines the AdversarialEnvironment class, which manages adversarial agents and applies attacks to the simulation state.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import random
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
from isopro.environments.simulation_environment import SimulationEnvironment
|
| 10 |
+
from .adversarial_agent import AdversarialAgent
|
| 11 |
+
from .attack_utils import get_model_and_tokenizer, create_attack, get_available_attacks
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class AdversarialEnvironment(SimulationEnvironment):
    """Simulation environment that interposes adversarial agents between the
    raw state and a wrapped agent, perturbing inputs and/or outputs."""

    def __init__(self, agent_wrapper, num_adversarial_agents: int = 1, attack_types: List[str] = None, attack_targets: List[str] = None):
        """
        Initialize the AdversarialEnvironment.

        Args:
            agent_wrapper: The wrapped agent to pass the adversarially modified state to.
            num_adversarial_agents (int): The number of adversarial agents to create.
            attack_types (List[str], optional): Attack types to use; falls back to
                all available attacks when None (or empty).
            attack_targets (List[str], optional): Targets for the attacks; falls
                back to both "input" and "output" when None (or empty).
        """
        super().__init__()
        self.agent_wrapper = agent_wrapper
        self.num_adversarial_agents = num_adversarial_agents
        # Note: an empty list also falls back to the defaults (truthiness check),
        # matching the original behavior.
        self.attack_types = attack_types or get_available_attacks()
        self.attack_targets = attack_targets or ["input", "output"]
        self.model, self.tokenizer = get_model_and_tokenizer()
        self._create_adversarial_agents()
        logger.info(f"Initialized AdversarialEnvironment with {num_adversarial_agents} agents")

    def _create_adversarial_agents(self):
        """Create adversarial agents with random attack types and targets."""
        for index in range(1, self.num_adversarial_agents + 1):
            chosen_type = random.choice(self.attack_types)
            chosen_target = random.choice(self.attack_targets)
            # The agent name encodes (type, target) -- get_attack_distribution
            # parses it back out, so the format must stay in sync.
            self.add_agent(AdversarialAgent(
                name=f"Adversarial Agent {index} ({chosen_type}, {chosen_target})",
                attack=create_attack(chosen_type, self.model, self.tokenizer),
                target=chosen_target,
            ))
        logger.info(f"Created {self.num_adversarial_agents} adversarial agents")

    def step(self, sim_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply every adversarial agent to the state, then step the wrapped agent.

        Args:
            sim_state (Dict[str, Any]): The current simulation state.

        Returns:
            Dict[str, Any]: The state after attacks and the wrapped agent's step.
        """
        for adversary in self.agents:
            sim_state = adversary.run(sim_state)
        return self.agent_wrapper.step(sim_state)

    def reset(self):
        """Reset the environment and recreate adversarial agents."""
        super().reset()
        self._create_adversarial_agents()
        logger.info("Reset AdversarialEnvironment and recreated agents")

    def get_attack_distribution(self) -> Dict[str, int]:
        """
        Count how many agents use each (attack type, target) combination.

        Returns:
            Dict[str, int]: Mapping "{type}_{target}" -> number of agents.
        """
        attack_counts = {f"{kind}_{tgt}": 0
                         for kind in self.attack_types
                         for tgt in self.attack_targets}
        for adversary in self.agents:
            # Recover "(type, target)" from the agent name created above.
            descriptor = adversary.name.split('(')[-1].split(')')[0]
            kind, tgt = descriptor.split(', ')
            attack_counts[f"{kind}_{tgt}"] += 1
        logger.info(f"Current attack distribution: {attack_counts}")
        return attack_counts
|
isopro/adversarial_simulation/adversarial_simulator.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adversarial Simulator
|
| 3 |
+
|
| 4 |
+
This module provides a high-level interface for running adversarial simulations.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import List, Dict, Any
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class AdversarialSimulator:
    """High-level driver: runs each input text through an AdversarialEnvironment
    and records the original vs. perturbed input/output pairs."""

    def __init__(self, environment):
        """
        Initialize the AdversarialSimulator.

        Args:
            environment: The AdversarialEnvironment to use in the simulation.
        """
        self.environment = environment
        logger.info("Initialized AdversarialSimulator")

    def run_simulation(self, input_data: List[str], num_steps: int = 1) -> List[Dict[str, Any]]:
        """
        Run an adversarial simulation.

        Args:
            input_data (List[str]): The list of input texts to use in the simulation.
            num_steps (int): The number of steps to run the simulation for each input.

        Returns:
            List[Dict[str, Any]]: Per-input records with original and perturbed
            inputs and outputs.
        """
        outcomes = []
        for original_text in input_data:
            # Baseline: what the wrapped agent says about the clean input.
            baseline_output = self.environment.agent_wrapper.run({"text": original_text})
            state = {"text": original_text, "output": ""}
            for _ in range(num_steps):
                state = self.environment.step(state)
            outcomes.append({
                "original_input": original_text,
                "perturbed_input": state["text"],
                "original_output": baseline_output,
                "perturbed_output": state["output"],
            })
        logger.info(f"Completed simulation with {len(input_data)} inputs and {num_steps} steps each")
        return outcomes
|
isopro/adversarial_simulation/attack_utils.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Attack Utilities
|
| 3 |
+
|
| 4 |
+
This module provides utility functions for creating and managing adversarial attacks.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from typing import Tuple, Callable
|
| 9 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 10 |
+
from isoadverse.attacks.text_fgsm import text_fgsm_attack
|
| 11 |
+
from isoadverse.attacks.text_pgd import text_pgd_attack
|
| 12 |
+
from isoadverse.attacks.textbugger import textbugger_attack
|
| 13 |
+
from isoadverse.attacks.deepwordbug import deepwordbug_attack
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
def get_model_and_tokenizer(model_name: str = 'bert-base-uncased') -> Tuple[torch.nn.Module, torch.nn.Module]:
    """
    Load a pre-trained sequence-classification model and its tokenizer, moving
    the model to CUDA when available.

    Args:
        model_name (str): The HuggingFace identifier of the model to load.

    Returns:
        Tuple[torch.nn.Module, torch.nn.Module]: The loaded (model, tokenizer).
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target_device)
    logger.info(f"Loaded model {model_name} on {target_device}")
    return model, tokenizer
|
| 34 |
+
|
| 35 |
+
def create_attack(attack_type: str, model: torch.nn.Module, tokenizer: torch.nn.Module) -> Callable:
    """
    Create an attack function based on the specified attack type.

    Gradient-based attacks (fgsm, pgd) close over the model/tokenizer;
    character-level attacks (textbugger, deepwordbug) need only the text.

    Args:
        attack_type (str): The type of attack to create.
        model (torch.nn.Module): The model to use for the attack.
        tokenizer (torch.nn.Module): The tokenizer to use for the attack.

    Returns:
        Callable: A one-argument function mapping text -> perturbed text.

    Raises:
        ValueError: If ``attack_type`` is not a known attack.
    """
    dispatch = {
        "fgsm": lambda x: text_fgsm_attack(model, tokenizer, x, torch.tensor([1]), epsilon=0.3),
        "pgd": lambda x: text_pgd_attack(model, tokenizer, x, torch.tensor([1]), epsilon=0.3, alpha=0.1, num_steps=10),
        "textbugger": lambda x: textbugger_attack(x, num_bugs=5),
        "deepwordbug": lambda x: deepwordbug_attack(x, num_bugs=5),
    }
    try:
        return dispatch[attack_type]
    except KeyError:
        # Suppress the KeyError context so callers see the same bare ValueError
        # the original raised.
        raise ValueError(f"Unknown attack type: {attack_type}") from None
|
| 57 |
+
|
| 58 |
+
def get_available_attacks() -> list:
    """
    List the attack-type identifiers that create_attack understands.

    Returns:
        list: A fresh list of available attack type names.
    """
    supported = ("fgsm", "pgd", "textbugger", "deepwordbug")
    return list(supported)
|
isopro/adversarial_simulation/main.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List
|
| 3 |
+
from .adversarial_simulator import AdversarialSimulator
|
| 4 |
+
from .adversarial_environment import AdversarialEnvironment
|
| 5 |
+
from isopro.utils.analyze_adversarial_sim import analyze_adversarial_results, summarize_adversarial_impact
|
| 6 |
+
from isopro.agents.ai_agent import AI_Agent
|
| 7 |
+
import anthropic
|
| 8 |
+
import os
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import json
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
# Set up logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
class ClaudeAgent(AI_Agent):
    """AI_Agent backed by Anthropic's Claude messages API."""

    def __init__(self, name):
        super().__init__(name)
        # The API key is read from the environment (dotenv is loaded at module
        # import time).
        self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    def run(self, input_data):
        """Send the state's 'text' to Claude and return the reply text."""
        reply = self.client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=100,
            messages=[{"role": "user", "content": input_data['text']}],
        )
        return reply.content[0].text

    def step(self, sim_state):
        """Fill sim_state['output'] with the model's reply and return the state."""
        sim_state['output'] = self.run(sim_state)
        return sim_state
|
| 37 |
+
|
| 38 |
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder for the NumPy/PyTorch values in the analysis results.

    Converts:
        * numpy floating scalars -> float
        * numpy integer scalars  -> int
        * numpy arrays           -> (nested) lists
        * torch tensors          -> (nested) lists
    Anything else falls through to the base encoder (which raises TypeError).
    """

    def default(self, obj):
        if isinstance(obj, np.floating):
            return float(obj)
        # Bug fix: numpy integer scalars (e.g. np.int64 counts) previously
        # raised TypeError because only floating scalars were handled.
        if isinstance(obj, np.integer):
            return int(obj)
        # Bug fix: whole numpy arrays were also unserializable.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        return super(NumpyEncoder, self).default(obj)
|
| 45 |
+
|
| 46 |
+
def setup_logging(log_dir: str, run_id: str) -> None:
    """Configure root logging to write to a run-specific file and to stderr.

    Args:
        log_dir (str): Directory for log files (created if missing).
        run_id (str): Unique identifier embedded in the log file name.
    """
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"adv-{run_id}.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ],
        # Bug fix: this module already calls logging.basicConfig() at import
        # time, so a second basicConfig() was a silent no-op and the log file
        # was never attached.  force=True (Python 3.8+) replaces the existing
        # root handlers.
        force=True,
    )
|
| 58 |
+
|
| 59 |
+
def save_scores(output_dir: str, run_id: str, analysis_results: dict) -> None:
    """Persist the analysis results as pretty-printed JSON.

    Args:
        output_dir (str): Directory for result files (created if missing).
        run_id (str): Unique identifier embedded in the output file name.
        analysis_results (dict): The results to serialize (may contain
            numpy/torch values; NumpyEncoder handles them).
    """
    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, f"adv-{run_id}.json")

    try:
        with open(destination, 'w') as handle:
            json.dump(analysis_results, handle, indent=2, cls=NumpyEncoder)
        logging.info(f"Saved analysis results to {destination}")
    except Exception as e:
        # Best-effort persistence: log and continue rather than abort the run.
        logging.error(f"Error saving analysis results: {str(e)}")
|
| 70 |
+
|
| 71 |
+
def get_sample_inputs() -> List[str]:
    """Return the fixed set of benign prompts used to probe the target model."""
    prompts = (
        "What is the capital of France?",
        "How does photosynthesis work?",
        "Explain the theory of relativity.",
    )
    return list(prompts)
|
| 77 |
+
|
| 78 |
+
def main():
    """Run one end-to-end adversarial simulation against the Claude agent.

    Flow: configure per-run logging -> build the adversarial environment around
    a ClaudeAgent -> run the sample prompts through it -> analyze, print, and
    persist the results.  Any exception is logged with a traceback and re-raised.
    """
    try:
        # Timestamp doubles as the unique run identifier for log/output files.
        run_id = datetime.now().strftime("%Y%m%d-%H%M%S")

        log_dir = "logs"
        setup_logging(log_dir, run_id)

        # NOTE(review): this local assignment makes `logger` local to main();
        # if an exception were raised before this line, the `except` block
        # below would hit UnboundLocalError instead of logging -- confirm.
        logger = logging.getLogger(__name__)
        logger.info(f"Starting adversarial simulation run {run_id}")

        claude_agent = ClaudeAgent("Claude Agent")

        # Create the AdversarialEnvironment
        adv_env = AdversarialEnvironment(
            agent_wrapper=claude_agent,
            num_adversarial_agents=2,
            attack_types=["textbugger", "deepwordbug"],
            attack_targets=["input", "output"]
        )

        # Set up the adversarial simulator with the environment
        simulator = AdversarialSimulator(adv_env)

        input_data = get_sample_inputs()

        logger.info("Starting adversarial simulation...")
        simulation_results = simulator.run_simulation(input_data, num_steps=1)

        logger.info("Analyzing simulation results...")
        analysis_results = analyze_adversarial_results(simulation_results)

        summary = summarize_adversarial_impact(analysis_results)

        print("\nAdversarial Simulation Summary:")
        print(summary)

        output_dir = "output"
        save_scores(output_dir, run_id, analysis_results)

        logger.info("Simulation complete.")

    except Exception as e:
        logger.error(f"An error occurred during the simulation: {str(e)}", exc_info=True)
        raise

if __name__ == "__main__":
    main()
|
isopro/agents/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Agent classes for the isopro package.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .ai_agent import AI_Agent
|
| 6 |
+
|
| 7 |
+
__all__ = ["AI_Agent"]
|
isopro/agents/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (263 Bytes). View file
|
|
|
isopro/agents/__pycache__/ai_agent.cpython-38.pyc
ADDED
|
Binary file (1.62 kB). View file
|
|
|
isopro/agents/ai_agent.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""AI Agent for Simulation Environment."""
|
| 2 |
+
from ..base.base_component import BaseComponent, agent_component
|
| 3 |
+
|
| 4 |
+
@agent_component
class AI_Agent(BaseComponent):
    """Composable AI agent: pipes input data through its registered components."""

    def __init__(self, name):
        """
        Initialize the AI_Agent with an empty component pipeline.

        Args:
            name (str): The name of the agent.
        """
        super().__init__(name)
        self.components = []

    def add_component(self, component):
        """
        Register a component on the agent's pipeline.

        Args:
            component (BaseComponent): The component to add.

        Raises:
            ValueError: If the component is not decorated with @agent_component.
        """
        if not getattr(component, '_is_agent_component', False):
            raise ValueError(f"Component {component} is not decorated with @agent_component")
        self.components.append(component)

    def run(self, input_data):
        """
        Feed input_data through each component in registration order.

        Args:
            input_data (dict): The input data for the agent.

        Returns:
            dict: The output after every component has processed it.
        """
        self.logger.info(f"Running agent: {self.name}")
        current = input_data
        for stage in self.components:
            current = stage.run(current)
        return current
|
isopro/base/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base classes for the isopro package.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .base_wrapper import BaseWrapper
|
| 6 |
+
from .base_component import BaseComponent
|
| 7 |
+
|
| 8 |
+
__all__ = ["BaseWrapper", "BaseComponent"]
|
isopro/base/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (326 Bytes). View file
|
|
|
isopro/base/__pycache__/base_component.cpython-38.pyc
ADDED
|
Binary file (1.44 kB). View file
|
|
|
isopro/base/__pycache__/base_wrapper.cpython-38.pyc
ADDED
|
Binary file (2.86 kB). View file
|
|
|
isopro/base/base_component.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base Component for Simulation Environment."""
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
from ..utils.logging_utils import setup_logger
|
| 4 |
+
|
| 5 |
+
class BaseComponent(ABC):
    """Abstract base for simulation components.

    Gives every instance a name and a dedicated logger, and requires
    subclasses to implement run().
    """

    def __init__(self, name):
        """
        Initialize the BaseComponent.

        Args:
            name (str): The name of the component.
        """
        self.name = name
        # Logger name combines class and instance name so log lines identify
        # the exact component.
        self.logger = setup_logger(f"{self.__class__.__name__}_{self.name}")

    @abstractmethod
    def run(self):
        """Execute the component's main functionality."""
        ...

    def __str__(self):
        return f"{self.__class__.__name__}({self.name})"
|
| 25 |
+
|
| 26 |
+
def agent_component(cls):
    """Class decorator that marks *cls* as an agent component.

    AI_Agent.add_component checks for the marker attribute set here, so only
    decorated classes can be registered on an agent.  The class is returned
    unchanged apart from the marker.
    """
    setattr(cls, '_is_agent_component', True)
    return cls
|
isopro/base/base_wrapper.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base Wrapper for Simulation Environment."""
|
| 2 |
+
from abc import ABC, abstractmethod
|
| 3 |
+
import logging
|
| 4 |
+
from ..utils.logging_utils import setup_logger
|
| 5 |
+
|
| 6 |
+
class BaseWrapper(ABC):
    """Abstract adapter between a simulation environment and a wrapped agent.

    Subclasses implement the gym-style lifecycle (step/reset/render/close) and
    the two state-conversion hooks; attribute lookups that the wrapper itself
    cannot satisfy are delegated to the wrapped agent.
    """

    def __init__(self, agent):
        """
        Wrap *agent* and create a logger named after the concrete wrapper class.

        Args:
            agent: The agent to be wrapped.
        """
        self.agent = agent
        self.logger = setup_logger(self.__class__.__name__)

    @abstractmethod
    def step(self):
        """Execute one time step within the environment."""

    @abstractmethod
    def reset(self):
        """Reset the state of the environment to an initial state."""

    @abstractmethod
    def render(self):
        """Render the environment."""

    @abstractmethod
    def close(self):
        """Close the environment, clean up any resources."""

    @abstractmethod
    def convert_to_agent_input(self, sim_state):
        """
        Convert simulation state to agent input format.

        Args:
            sim_state (dict): The current state of the simulation.

        Returns:
            dict: The converted input for the agent.
        """

    @abstractmethod
    def convert_from_agent_output(self, agent_output):
        """
        Convert agent output to simulation input format.

        Args:
            agent_output (dict): The output from the agent.

        Returns:
            dict: The converted input for the simulation.
        """

    def __getattr__(self, name):
        """
        Delegate unknown attribute access to the wrapped agent.

        Only invoked when normal lookup on the wrapper fails; logs a warning
        and re-raises when the agent does not provide the attribute either.

        Args:
            name (str): The name of the attribute.

        Returns:
            The requested attribute.

        Raises:
            AttributeError: If the attribute is not found in the agent or wrapper.
        """
        try:
            return getattr(self.agent, name)
        except AttributeError:
            self.logger.warning(f"Attribute '{name}' not found in agent or wrapper")
            raise
|
isopro/car_simulator/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Car Reinforcement Learning Package
|
| 3 |
+
|
| 4 |
+
This package contains modules for simulating and visualizing
|
| 5 |
+
reinforcement learning agents in a car driving environment.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from .car_rl_environment import CarRLEnvironment
|
| 9 |
+
from .car_llm_agent import LLMCarRLWrapper
|
| 10 |
+
from .carviz import CarVisualization
|
| 11 |
+
|
| 12 |
+
__all__ = ['CarRLEnvironment', 'LLMCarRLWrapper', 'CarVisualization']
|
isopro/car_simulator/car_llm_agent.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from stable_baselines3 import PPO
|
| 3 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 4 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
| 5 |
+
import numpy as np
|
| 6 |
+
import anthropic
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
from .car_rl_environment import CarRLEnvironment
|
| 10 |
+
import os
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
# Load environment variables from .env file
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
# Set up logging
|
| 17 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
class LLMCarRLWrapper(CarRLEnvironment):
    """Car environment wrapper that periodically asks an Anthropic LLM for
    driving advice and nudges the agent's continuous action accordingly.
    """

    def __init__(self, num_cars=1, time_of_day="12:00", is_rainy=False, is_weekday=True,
                 agent_prompt="You are an expert driving instructor. Provide concise guidance to improve the RL agent's driving performance.",
                 llm_call_limit=100, llm_call_frequency=100):
        """Create the wrapped environment and the Anthropic client.

        Args:
            num_cars, time_of_day, is_rainy, is_weekday: forwarded to CarRLEnvironment.
            agent_prompt: system prompt given to the LLM.
            llm_call_limit: maximum number of LLM API calls over the wrapper's lifetime.
            llm_call_frequency: ask the LLM for fresh guidance every N steps.

        Raises:
            ValueError: if ANTHROPIC_API_KEY is not set in the environment.
        """
        super().__init__(num_cars, time_of_day, is_rainy, is_weekday)
        self.agent_prompt = agent_prompt
        api_key = os.getenv('ANTHROPIC_API_KEY')
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY not found in environment variables")
        self.client = anthropic.Anthropic(api_key=api_key)
        self.llm_call_count = 0
        self.llm_call_limit = llm_call_limit
        self.llm_call_frequency = llm_call_frequency
        self.conversation_history: List[Dict[str, str]] = []
        self.step_count = 0
        # Guidance applied to future actions; "unknown" means "no adjustment".
        self.current_guidance = {"action": "unknown"}

    def reset(self, seed=None, options=None):
        """Reset the environment and clear per-episode guidance state."""
        self.step_count = 0
        self.current_guidance = {"action": "unknown"}
        return super().reset(seed=seed)

    def step(self, action):
        """Step the environment, applying the latest LLM guidance to the action.

        BUG FIX: previously the guidance-adjusted action was computed AFTER
        calling super().step() and then discarded, so the LLM's advice never
        influenced the simulation. The adjustment is now applied before
        stepping, and the base step is no longer duplicated across branches.
        """
        self.step_count += 1

        # Apply the most recent guidance before stepping so it takes effect.
        adjusted_action = self._adjust_action_based_on_guidance(action, self.current_guidance)
        observation, reward, terminated, truncated, info = super().step(adjusted_action)

        # Refresh guidance every llm_call_frequency steps, within the call budget.
        if self.step_count % self.llm_call_frequency == 0 and self.llm_call_count < self.llm_call_limit:
            self.current_guidance = self._get_llm_guidance(observation, reward, terminated)
            self.llm_call_count += 1

        return observation, reward, terminated, truncated, info

    def _get_llm_guidance(self, observation, reward, terminated):
        """Query the LLM with the current transition and parse its advice.

        Returns a guidance dict (see _parse_llm_guidance); on any API error
        falls back to {"action": "unknown"} so the simulation keeps running.
        """
        user_message = f"Current state: {observation}, Reward: {reward}, Terminated: {terminated}. Provide brief driving advice."

        messages = self.conversation_history + [
            {"role": "user", "content": user_message},
        ]

        try:
            response = self.client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=50,
                system=self.agent_prompt,
                messages=messages
            )

            ai_response = response.content[0].text
            # Record the exchange so later calls have conversational context.
            self.conversation_history.append({"role": "user", "content": user_message})
            self.conversation_history.append({"role": "assistant", "content": ai_response})
            logger.debug(f"LLM guidance: {ai_response}")
            return self._parse_llm_guidance(ai_response)
        except Exception as e:
            # Best-effort: LLM guidance is advisory, so never crash the env.
            logger.error(f"Error getting LLM guidance: {e}")
            return {"action": "unknown"}

    def _parse_llm_guidance(self, guidance):
        """Map free-text LLM advice to a structured guidance dict.

        Matching is a simple case-insensitive substring search; the first
        matching phrase wins. Unrecognized advice yields {"action": "unknown"}.
        """
        guidance_lower = guidance.lower()
        actions = {
            "increase speed": {"action": "increase_speed"},
            "decrease speed": {"action": "decrease_speed"},
            "slow down": {"action": "decrease_speed"},
            "turn left": {"action": "turn_left"},
            "turn right": {"action": "turn_right"},
            "stop": {"action": "stop"},
            "start raining": {"environment": "rain", "status": True},
            "increase traffic": {"environment": "traffic", "density": "high"}
        }

        for key, value in actions.items():
            if key in guidance_lower:
                return value

        return {"action": "unknown"}

    def _adjust_action_based_on_guidance(self, action, guidance):
        """Nudge the action vector according to the parsed guidance.

        NOTE: indices 0/1 are the first car's [acceleration, steering], so
        guidance only adjusts car 0. The action is mutated in place and also
        returned. Unknown or environment-type guidance leaves it unchanged.
        """
        adjustments = {
            "increase_speed": (0, 0.1),
            "decrease_speed": (0, -0.1),
            "turn_left": (1, -0.1),
            "turn_right": (1, 0.1),
        }

        if guidance.get("action") in adjustments:
            index, adjustment = adjustments[guidance["action"]]
            action[index] = np.clip(action[index] + adjustment, -1.0, 1.0)

        return action
|
| 112 |
+
|
| 113 |
+
def make_env(llm_call_limit):
    """Return a zero-argument factory building an LLM-wrapped car environment.

    The factory shape is what stable-baselines3's DummyVecEnv expects.
    """
    def _init():
        environment = LLMCarRLWrapper(
            num_cars=3,
            time_of_day="08:00",
            is_rainy=False,
            is_weekday=True,
            llm_call_limit=llm_call_limit,
        )
        return environment
    return _init
|
| 118 |
+
|
| 119 |
+
def train_and_evaluate(env, total_timesteps=100000, eval_episodes=10):
    """Train a fresh PPO agent on *env* and evaluate it.

    Args:
        env: vectorized environment to train on.
        total_timesteps: number of environment steps to train for.
        eval_episodes: episodes used by the post-training evaluation.

    Returns:
        (model, mean_reward) — the trained agent and its mean evaluation reward.
    """
    agent = PPO(
        "MlpPolicy", env, verbose=1, learning_rate=0.0003, n_steps=2048,
        batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2,
    )
    agent.learn(total_timesteps=total_timesteps, progress_bar=True)

    mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=eval_episodes)
    logger.info(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    return agent, mean_reward
|
| 129 |
+
|
| 130 |
+
def main():
    """Train, evaluate, save, and report on an LLM-guided car RL agent."""
    # Default to 10 if not set
    call_budget = int(os.getenv('LLM_CALL_LIMIT', '10'))

    vec_env = DummyVecEnv([make_env(call_budget)])

    model, mean_reward = train_and_evaluate(vec_env)
    model.save("car_rl_llm_ppo_model")

    logger.info("Training and evaluation completed.")
    logger.info(f"Final mean reward: {mean_reward:.2f}")

if __name__ == "__main__":
    main()
|
isopro/car_simulator/car_rl_environment.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from gymnasium import spaces
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import random
|
| 6 |
+
from typing import List, Dict, Tuple, Union
|
| 7 |
+
|
| 8 |
+
class CarRLEnvironment(gym.Env):
    """Minimal multi-car 2D driving environment.

    Observation: [x, y, vx, vy, angle] per car, followed by
    [time_of_day, is_rainy, is_weekday].
    Action: [acceleration, steering] per car, each component in [-1, 1].
    """

    def __init__(self, num_cars=1, time_of_day="12:00", is_rainy=False, is_weekday=True):
        super().__init__()
        self.num_cars = num_cars
        self.time_of_day = self.convert_time(time_of_day)
        self.is_rainy = is_rainy
        self.is_weekday = is_weekday
        # Rain halves the available friction.
        self.friction = 0.4 if is_rainy else 0.8

        # Define action and observation spaces
        self.action_space = spaces.Box(low=-1, high=1, shape=(num_cars * 2,), dtype=np.float32)

        # Observation space: [x, y, vx, vy, angle] for each car + [time_of_day, is_rainy, is_weekday]
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(num_cars * 5 + 3,),
            dtype=np.float32
        )

        self.cars = self.initialize_cars()

    def convert_time(self, time_of_day: Union[str, float]) -> float:
        """Convert time to a float between 0 and 24.

        Accepts "HH:MM" strings or numeric hours; invalid input falls back to
        noon (12.0) with a warning printed.
        """
        if isinstance(time_of_day, str):
            try:
                hours, minutes = map(int, time_of_day.split(':'))
                return float(hours + minutes / 60.0)
            except ValueError:
                print(f"Invalid time format: {time_of_day}. Using default value of 12:00.")
                return 12.0
        elif isinstance(time_of_day, (int, float)):
            return float(time_of_day) % 24.0
        else:
            print(f"Invalid time format: {time_of_day}. Using default value of 12:00.")
            return 12.0

    def initialize_cars(self) -> List[Dict[str, torch.Tensor]]:
        """Initialize car parameters with random position, velocity and heading."""
        return [
            {
                "position": torch.tensor([random.uniform(-1, 1), random.uniform(-1, 1)], dtype=torch.float32),
                "velocity": torch.tensor([random.uniform(-0.5, 0.5), random.uniform(-0.5, 0.5)], dtype=torch.float32),
                "angle": torch.tensor([random.uniform(-np.pi, np.pi)], dtype=torch.float32)
            } for _ in range(self.num_cars)
        ]

    def reset(self, seed=None, options=None) -> Tuple[np.ndarray, Dict]:
        """Reset all cars to fresh random states.

        `options` is accepted (and ignored) for Gymnasium API compatibility;
        existing callers that pass only `seed` are unaffected.
        """
        super().reset(seed=seed)
        self.cars = self.initialize_cars()
        return self.get_observation(), {}

    def get_observation(self) -> np.ndarray:
        """Get the current observation of the environment."""
        car_obs = np.concatenate([
            np.concatenate([
                car["position"].numpy(),
                car["velocity"].numpy(),
                car["angle"].numpy()
            ]) for car in self.cars
        ])
        env_obs = np.array([
            self.time_of_day,
            float(self.is_rainy),
            float(self.is_weekday)
        ], dtype=np.float32)
        return np.concatenate([car_obs, env_obs]).astype(np.float32)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """
        Take a step in the environment.

        Args:
            action (np.ndarray): Array of actions for all cars [acceleration1, steering1, acceleration2, steering2, ...]

        Returns:
            observation, reward, terminated, truncated, info

        Raises:
            ValueError: if the flattened action does not have 2 entries per car.
        """
        # Ensure action is the correct shape
        action = np.array(action).flatten()
        if action.shape[0] != self.num_cars * 2:
            raise ValueError(f"Action shape {action.shape} does not match expected shape ({self.num_cars * 2},)")

        for i in range(self.num_cars):
            car_action = action[i*2:(i+1)*2]
            self.apply_action(self.cars[i], car_action)
            self.update_physics(self.cars[i])

        observation = self.get_observation()
        reward = self.calculate_reward()
        terminated = self.is_terminated()
        truncated = False
        info = {}

        return observation, reward, terminated, truncated, info

    def apply_action(self, car: Dict[str, torch.Tensor], action: np.ndarray):
        """Apply the RL agent's action to the car."""
        if len(action) != 2:
            raise ValueError(f"Expected action to have 2 values, got {len(action)}")

        acceleration, steering = action
        car["velocity"] += torch.tensor([acceleration, 0.0], dtype=torch.float32) * 0.1  # Scale down the acceleration
        car["angle"] += torch.tensor([steering], dtype=torch.float32) * 0.1  # Scale down the steering

    def update_physics(self, car: Dict[str, torch.Tensor], dt: float = 0.1):
        """Update car position and velocity using physics simulation.

        Order matters: friction, then position integration, then heading
        rotation of the velocity, then clamping position to the arena.
        """
        # Update velocity (apply friction)
        car["velocity"] *= (1 - self.friction * dt)

        # Update position
        car["position"] += car["velocity"] * dt

        # Apply steering: rotate the velocity vector by the car's heading
        angle = car["angle"].item()
        rotation_matrix = torch.tensor([
            [np.cos(angle), -np.sin(angle)],
            [np.sin(angle), np.cos(angle)]
        ], dtype=torch.float32)
        car["velocity"] = torch.matmul(rotation_matrix, car["velocity"])

        # Bound the position to keep cars on the screen
        car["position"] = torch.clamp(car["position"], -1, 1)

    def calculate_reward(self) -> float:
        """Calculate the reward: speed is rewarded, drifting from center penalized."""
        reward = 0.0
        for car in self.cars:
            # Reward for moving
            speed = torch.norm(car["velocity"]).item()
            reward += speed * 0.1

            # Penalty for being close to the edge
            distance_from_center = torch.norm(car["position"]).item()
            reward -= distance_from_center * 0.1

        return reward

    def is_terminated(self) -> bool:
        """Terminate when any car reaches the edge of the drivable area.

        BUG FIX: update_physics clamps positions to [-1, 1], so the previous
        strict `> 1` check could never fire and episodes never terminated.
        Using `>= 1` makes hitting the boundary end the episode as intended.
        """
        for car in self.cars:
            if torch.any(torch.abs(car["position"]) >= 1):
                return True
        return False

    def render(self):
        """Render the environment (placeholder for potential future implementation)."""
        pass
|
isopro/car_simulator/car_rl_model.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85303b6b7e544f04d04cb949709ee37ac956a78f098c0390e2b210448bc446bb
|
| 3 |
+
size 164031
|
isopro/car_simulator/car_rl_training.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gymnasium as gym
|
| 2 |
+
from stable_baselines3 import PPO
|
| 3 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 4 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
| 5 |
+
import numpy as np
|
| 6 |
+
from .car_rl_environment import CarRLEnvironment
|
| 7 |
+
|
| 8 |
+
def make_env():
    """Build a fresh CarRLEnvironment for the vectorized-env factory list."""
    env_instance = CarRLEnvironment(num_cars=3, time_of_day="08:00", is_rainy=False, is_weekday=True)
    return env_instance
|
| 11 |
+
|
| 12 |
+
# Create a vectorized environment (single env wrapped for SB3 compatibility)
env = DummyVecEnv([make_env])

# Initialize the PPO agent with standard hyperparameters
model = PPO("MlpPolicy", env, verbose=1, learning_rate=0.0003, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, ent_coef=0.0)

# Train the agent
total_timesteps = 1_000_000
model.learn(total_timesteps=total_timesteps, progress_bar=True)

# Evaluate the trained agent
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

# Save the trained model
model.save("car_rl_ppo_model")

# Test the trained agent with a deterministic rollout
obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, rewards, dones, info = env.step(action)
    env.render()
    if dones.any():  # VecEnv returns an array of done flags; reset on episode end
        obs = env.reset()

env.close()
|
isopro/car_simulator/carviz.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pygame
|
| 2 |
+
import numpy as np
|
| 3 |
+
from .car_rl_environment import CarRLEnvironment
|
| 4 |
+
from stable_baselines3 import PPO
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
|
| 9 |
+
# Initialize Pygame
pygame.init()

# Constants: window, road, and widget geometry (pixels)
SCREEN_WIDTH = 1000
SCREEN_HEIGHT = 800
ROAD_WIDTH = 800
ROAD_HEIGHT = 600
CAR_WIDTH = 40
CAR_HEIGHT = 20
INFO_BOX_WIDTH = 200
INFO_BOX_HEIGHT = 120
UI_PANEL_WIDTH = 200

# Colors (RGB)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
GRAY = (200, 200, 200)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
|
| 31 |
+
|
| 32 |
+
class CarVisualization:
    """Pygame viewer that replays a trained policy in the car environment.

    NOTE(review): __init__ reads env.envs[0], i.e. it assumes a stable-baselines3
    VecEnv; carviz's own main() passes a raw environment — verify callers.
    """

    def __init__(self, env, model):
        self.env = env
        # Raw (unwrapped) environment, used to read cars/weather/time state.
        self.unwrapped_env = env.envs[0]
        self.model = model
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        pygame.display.set_caption("Enhanced Car RL Visualization")
        self.clock = pygame.time.Clock()
        self.font = pygame.font.Font(None, 24)
        self.rain = [self.RainDrop() for _ in range(100)]
        self.obstacles = [self.Obstacle() for _ in range(5)]
        self.time_of_day = self.float_to_datetime(self.unwrapped_env.time_of_day)

    def float_to_datetime(self, time_float):
        """Convert a float time (0-24) to a datetime object."""
        hours = int(time_float)
        minutes = int((time_float - hours) * 60)
        return datetime.min + timedelta(hours=hours, minutes=minutes)

    def datetime_to_string(self, dt):
        """Convert a datetime object to a string in HH:MM format."""
        return dt.strftime("%H:%M")

    def draw_road(self):
        """Draw the road surface (shade depends on time of day) and lane lines."""
        road_rect = pygame.Rect((SCREEN_WIDTH - ROAD_WIDTH) // 2, (SCREEN_HEIGHT - ROAD_HEIGHT) // 2, ROAD_WIDTH, ROAD_HEIGHT)
        road_color = self.get_road_color()
        pygame.draw.rect(self.screen, road_color, road_rect)

        # Draw lane markings
        for i in range(1, 3):
            y = (SCREEN_HEIGHT - ROAD_HEIGHT) // 2 + i * (ROAD_HEIGHT // 3)
            pygame.draw.line(self.screen, WHITE, (road_rect.left, y), (road_rect.right, y), 2)

    def get_road_color(self):
        """Pick the road shade for the current hour (day / dawn-dusk / night)."""
        hour = self.time_of_day.hour
        if 6 <= hour < 18:  # Daytime
            return GRAY
        elif 18 <= hour < 20 or 4 <= hour < 6:  # Dawn/Dusk
            return (150, 150, 170)
        else:  # Night
            return (100, 100, 120)

    def draw_car(self, position, angle, color):
        """Draw one car, mapping env coords ([-1, 1]) to screen pixels."""
        x, y = position
        x = (x + 1) * ROAD_WIDTH / 2 + (SCREEN_WIDTH - ROAD_WIDTH) // 2
        y = (y + 1) * ROAD_HEIGHT / 2 + (SCREEN_HEIGHT - ROAD_HEIGHT) // 2

        car_surface = pygame.Surface((CAR_WIDTH, CAR_HEIGHT), pygame.SRCALPHA)
        pygame.draw.rect(car_surface, color, (0, 0, CAR_WIDTH, CAR_HEIGHT))
        # Black triangle marks the car's front.
        pygame.draw.polygon(car_surface, BLACK, [(0, 0), (CAR_WIDTH // 2, 0), (0, CAR_HEIGHT)])
        rotated_car = pygame.transform.rotate(car_surface, -math.degrees(angle))
        self.screen.blit(rotated_car, rotated_car.get_rect(center=(x, y)))

    def draw_info_box(self, car_index, position, action, reward):
        """Draw a per-car stats box (action, reward, speed) above the car."""
        x, y = position
        x = (x + 1) * ROAD_WIDTH / 2 + (SCREEN_WIDTH - ROAD_WIDTH) // 2
        y = (y + 1) * ROAD_HEIGHT / 2 + (SCREEN_HEIGHT - ROAD_HEIGHT) // 2

        info_box = pygame.Surface((INFO_BOX_WIDTH, INFO_BOX_HEIGHT))
        info_box.fill(WHITE)
        pygame.draw.rect(info_box, BLACK, info_box.get_rect(), 2)

        texts = [
            f"Car {car_index + 1}",
            f"Acceleration: {action[0]:.2f}",
            f"Steering: {action[1]:.2f}",
            f"Reward: {reward:.2f}",
            f"Speed: {np.linalg.norm(self.unwrapped_env.cars[car_index]['velocity']):.2f}"
        ]

        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            info_box.blit(text_surface, (10, 10 + i * 25))

        self.screen.blit(info_box, (x - INFO_BOX_WIDTH // 2, y - INFO_BOX_HEIGHT - 30))

    def draw_rain(self):
        """Draw and advance every rain drop for one frame."""
        for drop in self.rain:
            pygame.draw.line(self.screen, (200, 200, 255), (drop.x, drop.y), (drop.x, drop.y + drop.size), drop.size)
            drop.fall()

    def draw_obstacles(self):
        """Draw the static yellow obstacles, offset into road coordinates."""
        for obstacle in self.obstacles:
            pygame.draw.rect(self.screen, YELLOW, ((SCREEN_WIDTH - ROAD_WIDTH) // 2 + obstacle.x,
                                                   (SCREEN_HEIGHT - ROAD_HEIGHT) // 2 + obstacle.y,
                                                   obstacle.width, obstacle.height))

    def draw_ui_panel(self):
        """Draw the right-hand status panel with state and key bindings."""
        panel = pygame.Surface((UI_PANEL_WIDTH, SCREEN_HEIGHT))
        panel.fill(WHITE)
        pygame.draw.rect(panel, BLACK, panel.get_rect(), 2)

        texts = [
            f"Time: {self.datetime_to_string(self.time_of_day)}",
            f"Rainy: {'Yes' if self.unwrapped_env.is_rainy else 'No'}",
            f"Weekday: {'Yes' if self.unwrapped_env.is_weekday else 'No'}",
            "Press keys to change:",
            "T: Time +1 hour",
            "R: Toggle Rain",
            "W: Toggle Weekday"
        ]

        for i, text in enumerate(texts):
            text_surface = self.font.render(text, True, BLACK)
            panel.blit(text_surface, (10, 10 + i * 30))

        self.screen.blit(panel, (SCREEN_WIDTH - UI_PANEL_WIDTH, 0))

    def handle_events(self):
        """Process pygame events; returns False when the window is closed.

        T/R/W keys mutate the underlying environment's time/weather/day state.
        """
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_t:
                    self.time_of_day += timedelta(hours=1)
                    # Keep the env's float clock in sync, wrapped to 24h.
                    self.unwrapped_env.time_of_day = (self.time_of_day.hour + self.time_of_day.minute / 60) % 24
                elif event.key == pygame.K_r:
                    self.unwrapped_env.is_rainy = not self.unwrapped_env.is_rainy
                elif event.key == pygame.K_w:
                    self.unwrapped_env.is_weekday = not self.unwrapped_env.is_weekday
        return True

    class RainDrop:
        """A single falling rain streak with random speed and thickness."""

        def __init__(self):
            self.x = random.randint(0, SCREEN_WIDTH)
            self.y = random.randint(0, SCREEN_HEIGHT)
            self.speed = random.randint(5, 15)
            self.size = random.randint(1, 3)

        def fall(self):
            # Wrap back to the top once the drop leaves the screen.
            self.y += self.speed
            if self.y > SCREEN_HEIGHT:
                self.y = 0
                self.x = random.randint(0, SCREEN_WIDTH)

    class Obstacle:
        """A static rectangular obstacle placed randomly within the road."""

        def __init__(self):
            self.width = random.randint(30, 60)
            self.height = random.randint(30, 60)
            self.x = random.randint(0, ROAD_WIDTH - self.width)
            self.y = random.randint(0, ROAD_HEIGHT - self.height)

    def run_visualization(self, num_episodes=5):
        """Roll out the policy for num_episodes, rendering each step at 30 FPS.

        Expects a VecEnv-style interface: reset() returns obs only, step()
        returns array-valued (obs, reward, done, info).
        """
        for episode in range(num_episodes):
            obs = self.env.reset()
            done = False
            total_reward = 0
            step = 0

            while not done:
                if not self.handle_events():
                    return

                self.screen.fill(WHITE)
                self.draw_road()
                self.draw_obstacles()
                if self.unwrapped_env.is_rainy:
                    self.draw_rain()

                action, _ = self.model.predict(obs, deterministic=True)
                obs, reward, done, info = self.env.step(action)
                total_reward += reward[0]

                for i, car in enumerate(self.unwrapped_env.cars):
                    position = car["position"].numpy()
                    angle = car["angle"].item()
                    color = (RED, GREEN, BLUE)[i % 3]  # Cycle through colors for different cars
                    self.draw_car(position, angle, color)
                    self.draw_info_box(i, position, action[0][i*2:(i+1)*2], reward[0])

                self.draw_ui_panel()
                pygame.display.flip()
                self.clock.tick(30)
                step += 1

                if done[0]:
                    break

            print(f"Episode {episode + 1} finished. Total reward: {total_reward:.2f}")

        pygame.quit()
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def main():
    """Train a quick PPO baseline and launch the interactive visualization."""
    # Create and train the model (you might want to load a pre-trained model instead)
    environment = CarRLEnvironment(num_cars=3, time_of_day="08:00", is_rainy=False, is_weekday=True)
    agent = PPO("MlpPolicy", environment, verbose=1)
    agent.learn(total_timesteps=10000)  # Adjust as needed

    # Create and run the visualization
    visualizer = CarVisualization(environment, agent)
    visualizer.run_visualization()

if __name__ == "__main__":
    main()
|
isopro/car_simulator/llm_main.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from stable_baselines3 import PPO
|
| 4 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 5 |
+
from .car_llm_agent import LLMCarRLWrapper
|
| 6 |
+
from .car_rl_environment import CarRLEnvironment
|
| 7 |
+
from .carviz import CarVisualization
|
| 8 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables from .env file
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
def parse_arguments():
    """Parse command-line options for the LLM-integrated car simulation."""
    parser = argparse.ArgumentParser(
        description="Car RL Simulation with LLM Integration and Visualization"
    )
    add = parser.add_argument
    add("--num_cars", type=int, default=3, help="Number of cars in the simulation")
    add("--time_of_day", type=str, default="08:00", help="Initial time of day (HH:MM format)")
    add("--is_rainy", action="store_true", help="Set initial weather to rainy")
    add("--is_weekday", action="store_true", help="Set initial day to weekday")
    add("--train_steps", type=int, default=100000, help="Number of training steps")
    add("--visualize_episodes", type=int, default=5, help="Number of episodes to visualize")
    add("--load_model", type=str, help="Path to a pre-trained model to load")
    add("--llm_call_limit", type=int, default=1000, help="Maximum number of LLM API calls")
    add("--llm_call_frequency", type=int, default=100, help="Frequency of LLM calls (in steps)")
    return parser.parse_args()
|
| 26 |
+
|
| 27 |
+
def make_env(num_cars, time_of_day, is_rainy, is_weekday, llm_call_limit, llm_call_frequency):
    """Return a zero-argument factory producing an LLM-wrapped car environment."""
    def _init():
        return LLMCarRLWrapper(
            num_cars=num_cars,
            time_of_day=time_of_day,
            is_rainy=is_rainy,
            is_weekday=is_weekday,
            llm_call_limit=llm_call_limit,
            llm_call_frequency=llm_call_frequency,
        )
    return _init
|
| 33 |
+
|
| 34 |
+
def train_and_evaluate(env, total_timesteps, eval_episodes=10):
    """Train a fresh PPO agent on *env* and report its evaluation score.

    Returns:
        (model, mean_reward) — the trained agent and its mean evaluation reward.
    """
    agent = PPO(
        "MlpPolicy", env, verbose=1, learning_rate=0.0003, n_steps=2048,
        batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2,
    )
    agent.learn(total_timesteps=total_timesteps, progress_bar=True)

    mean_reward, std_reward = evaluate_policy(agent, env, n_eval_episodes=eval_episodes)
    print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

    return agent, mean_reward
|
| 44 |
+
|
| 45 |
+
def main():
    """Entry point: build the LLM-wrapped env, train or load a PPO model,
    save it, and run the pygame visualization.

    Raises:
        ValueError: if ANTHROPIC_API_KEY is not set in the environment.
    """
    args = parse_arguments()

    # Ensure the ANTHROPIC_API_KEY is set
    if not os.getenv('ANTHROPIC_API_KEY'):
        raise ValueError("ANTHROPIC_API_KEY not found in environment variables")

    # Create the vectorized environment with LLM integration
    env = DummyVecEnv([make_env(args.num_cars, args.time_of_day, args.is_rainy, args.is_weekday,
                                args.llm_call_limit, args.llm_call_frequency)])

    # Create or load the RL agent.
    # BUG FIX: mean_reward was previously printed unconditionally and raised a
    # NameError when a pre-trained model was loaded (never assigned).
    mean_reward = None
    if args.load_model and os.path.exists(args.load_model):
        print(f"Loading pre-trained model from {args.load_model}")
        model = PPO.load(args.load_model, env=env)
    else:
        print("Creating and training a new model")
        model, mean_reward = train_and_evaluate(env, total_timesteps=args.train_steps)

    # Save the trained model
    model.save("car_rl_llm_model")
    print("Model saved as car_rl_llm_model")
    if mean_reward is not None:
        print(f"Final mean reward: {mean_reward:.2f}")

    # Run the visualization
    viz = CarVisualization(env, model)
    viz.run_visualization(num_episodes=args.visualize_episodes)

if __name__ == "__main__":
    main()
|
isopro/car_simulator/main.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from stable_baselines3 import PPO
|
| 4 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 5 |
+
from .car_rl_environment import CarRLEnvironment
|
| 6 |
+
from .carviz import CarVisualization
|
| 7 |
+
|
| 8 |
+
def parse_arguments():
    """Parse command-line options for the plain (non-LLM) car simulation."""
    parser = argparse.ArgumentParser(description="Car RL Simulation and Visualization")
    add = parser.add_argument
    add("--num_cars", type=int, default=3, help="Number of cars in the simulation")
    add("--time_of_day", type=str, default="08:00", help="Initial time of day (HH:MM format)")
    add("--is_rainy", action="store_true", help="Set initial weather to rainy")
    add("--is_weekday", action="store_true", help="Set initial day to weekday")
    add("--train_steps", type=int, default=10000, help="Number of training steps")
    add("--visualize_episodes", type=int, default=5, help="Number of episodes to visualize")
    add("--load_model", type=str, help="Path to a pre-trained model to load")
    return parser.parse_args()
|
| 18 |
+
|
| 19 |
+
def make_env(num_cars, time_of_day, is_rainy, is_weekday):
    """Return a zero-argument factory that builds a fresh CarRLEnvironment.

    DummyVecEnv expects a list of callables rather than environment
    instances, so construction is deferred to the returned closure.
    """
    def _init():
        return CarRLEnvironment(
            num_cars=num_cars,
            time_of_day=time_of_day,
            is_rainy=is_rainy,
            is_weekday=is_weekday,
        )

    return _init
|
| 23 |
+
|
| 24 |
+
def main():
    """Entry point: train or load a PPO agent on the car environment, then visualize it."""
    args = parse_arguments()

    # PPO requires a vectorized environment; wrap our single factory.
    vec_env = DummyVecEnv(
        [make_env(args.num_cars, args.time_of_day, args.is_rainy, args.is_weekday)]
    )

    # Either resume from a saved model on disk or train a new one from scratch.
    if args.load_model and os.path.exists(args.load_model):
        print(f"Loading pre-trained model from {args.load_model}")
        model = PPO.load(args.load_model, env=vec_env)
    else:
        print("Creating and training a new model")
        model = PPO("MlpPolicy", vec_env, verbose=1)
        model.learn(total_timesteps=args.train_steps)

    # Persist the agent (even a just-loaded one is re-saved under this name).
    model.save("car_rl_model")
    print("Model saved as car_rl_model")

    # Replay the trained policy visually for the requested number of episodes.
    visualizer = CarVisualization(vec_env, model)
    visualizer.run_visualization(num_episodes=args.visualize_episodes)

if __name__ == "__main__":
    main()
|
isopro/conversation_simulation/README.md
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Conversation Simulator
|
| 2 |
+
|
| 3 |
+
This module is part of the `isopro` package and simulates conversations between an AI assistant (either Claude or GPT-4) and various user personas. It's designed to test and demonstrate how the AI handles different types of customer service scenarios.
|
| 4 |
+
|
| 5 |
+
## Project Structure
|
| 6 |
+
|
| 7 |
+
The Conversation Simulator is located in the `conversation_simulation` folder within the `isopro` package:
|
| 8 |
+
|
| 9 |
+
```
|
| 10 |
+
isopro/
|
| 11 |
+
└── conversation_simulation/
|
| 12 |
+
├── main.py
|
| 13 |
+
├── conversation_simulator.ipynb
|
| 14 |
+
├── conversation_agent.py
|
| 15 |
+
├── conversation_environment.py
|
| 16 |
+
├── custom_persona.py
|
| 17 |
+
└── user_personas.py
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## Prerequisites
|
| 21 |
+
|
| 22 |
+
Before you begin, ensure you have met the following requirements:
|
| 23 |
+
|
| 24 |
+
* You have installed Python 3.7 or later.
|
| 25 |
+
* You have an Anthropic API key (for Claude) and/or an OpenAI API key (for GPT-4).
|
| 26 |
+
* You have installed the `isopro` package.
|
| 27 |
+
* For the Jupyter notebook, you have Jupyter Notebook or JupyterLab installed.
|
| 28 |
+
|
| 29 |
+
## Setting up the Conversation Simulator
|
| 30 |
+
|
| 31 |
+
1. If you haven't already, install the `isopro` package:
|
| 32 |
+
```
|
| 33 |
+
pip install isopro
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
2. Create a `.env` file in your project root and add your API keys:
|
| 37 |
+
```
|
| 38 |
+
ANTHROPIC_API_KEY=your_anthropic_api_key_here
|
| 39 |
+
OPENAI_API_KEY=your_openai_api_key_here
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## Running the Conversation Simulator
|
| 43 |
+
|
| 44 |
+
You can run the Conversation Simulator either as a Python script or interactively using a Jupyter notebook.
|
| 45 |
+
|
| 46 |
+
### Using the Python Script
|
| 47 |
+
|
| 48 |
+
1. Basic usage:
|
| 49 |
+
```python
|
| 50 |
+
from isopro.conversation_simulation.main import main
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
main()
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
2. Running from the command line:
|
| 57 |
+
```
|
| 58 |
+
python -m isopro.conversation_simulation.main
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
### Using the Jupyter Notebook
|
| 62 |
+
|
| 63 |
+
Navigate to the `isopro/conversation_simulation/` directory and open the `conversation_simulator.ipynb` file using Jupyter Notebook or JupyterLab. Here's what you'll find in the notebook:
|
| 64 |
+
|
| 65 |
+
```python
|
| 66 |
+
# Conversation Simulator Jupyter Notebook
|
| 67 |
+
|
| 68 |
+
## Setup
|
| 69 |
+
|
| 70 |
+
import logging
|
| 71 |
+
from logging.handlers import RotatingFileHandler
|
| 72 |
+
import os
|
| 73 |
+
from datetime import datetime
|
| 74 |
+
from dotenv import load_dotenv
|
| 75 |
+
from isopro.conversation_simulation.conversation_simulator import ConversationSimulator
|
| 76 |
+
from isopro.conversation_simulation.custom_persona import create_custom_persona
|
| 77 |
+
|
| 78 |
+
# Load environment variables
|
| 79 |
+
load_dotenv()
|
| 80 |
+
|
| 81 |
+
# Set up logging
|
| 82 |
+
log_directory = "logs"
|
| 83 |
+
os.makedirs(log_directory, exist_ok=True)
|
| 84 |
+
log_file = os.path.join(log_directory, "conversation_simulator.log")
|
| 85 |
+
|
| 86 |
+
# Create a rotating file handler
|
| 87 |
+
file_handler = RotatingFileHandler(log_file, maxBytes=1024*1024, backupCount=5)
|
| 88 |
+
file_handler.setLevel(logging.DEBUG)
|
| 89 |
+
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
| 90 |
+
file_handler.setFormatter(file_formatter)
|
| 91 |
+
|
| 92 |
+
# Create a console handler
|
| 93 |
+
console_handler = logging.StreamHandler()
|
| 94 |
+
console_handler.setLevel(logging.INFO)
|
| 95 |
+
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
| 96 |
+
console_handler.setFormatter(console_formatter)
|
| 97 |
+
|
| 98 |
+
# Set up the logger
|
| 99 |
+
logger = logging.getLogger()
|
| 100 |
+
logger.setLevel(logging.DEBUG)
|
| 101 |
+
logger.addHandler(file_handler)
|
| 102 |
+
logger.addHandler(console_handler)
|
| 103 |
+
|
| 104 |
+
print("Setup complete.")
|
| 105 |
+
|
| 106 |
+
## Helper Functions
|
| 107 |
+
|
| 108 |
+
def save_output(content, filename):
|
| 109 |
+
"""Save the output content to a file."""
|
| 110 |
+
with open(filename, 'w', encoding='utf-8') as f:
|
| 111 |
+
f.write(content)
|
| 112 |
+
|
| 113 |
+
def get_user_choice():
|
| 114 |
+
"""Get user's choice of AI model."""
|
| 115 |
+
while True:
|
| 116 |
+
choice = input("Choose AI model (claude/openai): ").lower()
|
| 117 |
+
if choice in ['claude', 'openai']:
|
| 118 |
+
return choice
|
| 119 |
+
print("Invalid choice. Please enter 'claude' or 'openai'.")
|
| 120 |
+
|
| 121 |
+
print("Helper functions defined.")
|
| 122 |
+
|
| 123 |
+
## Main Simulation Function
|
| 124 |
+
|
| 125 |
+
def run_simulation():
|
| 126 |
+
# Get user's choice of AI model
|
| 127 |
+
ai_choice = get_user_choice()
|
| 128 |
+
|
| 129 |
+
# Set up the appropriate model and API key
|
| 130 |
+
if ai_choice == 'claude':
|
| 131 |
+
model = "claude-3-opus-20240229"
|
| 132 |
+
os.environ["ANTHROPIC_API_KEY"] = os.getenv("ANTHROPIC_API_KEY")
|
| 133 |
+
ai_name = "Claude"
|
| 134 |
+
else: # openai
|
| 135 |
+
model = "gpt-4-1106-preview"
|
| 136 |
+
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
|
| 137 |
+
ai_name = "GPT-4 Turbo"
|
| 138 |
+
|
| 139 |
+
# Initialize the ConversationSimulator
|
| 140 |
+
simulator = ConversationSimulator(
|
| 141 |
+
ai_prompt=f"You are {ai_name}, an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally."
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
output_content = f"Conversation Simulator using {ai_name} model: {model}\n\n"
|
| 145 |
+
|
| 146 |
+
# Run simulations with different personas
|
| 147 |
+
personas = ["upset", "human_request", "inappropriate", "incomplete_info"]
|
| 148 |
+
|
| 149 |
+
for persona in personas:
|
| 150 |
+
logger.info(f"Running simulation with {persona} persona using {ai_name}")
|
| 151 |
+
conversation_history = simulator.run_simulation(persona, num_turns=3)
|
| 152 |
+
|
| 153 |
+
output_content += f"\nConversation with {persona} persona:\n"
|
| 154 |
+
for message in conversation_history:
|
| 155 |
+
output_line = f"{message['role'].capitalize()}: {message['content']}\n"
|
| 156 |
+
output_content += output_line
|
| 157 |
+
logger.debug(output_line.strip())
|
| 158 |
+
output_content += "\n" + "-"*50 + "\n"
|
| 159 |
+
|
| 160 |
+
# Create and run a simulation with a custom persona
|
| 161 |
+
custom_persona_name = "Techie Customer"
|
| 162 |
+
custom_characteristics = ["tech-savvy", "impatient", "detail-oriented"]
|
| 163 |
+
custom_message_templates = [
|
| 164 |
+
"I've tried rebooting my device, but the error persists. Can you help?",
|
| 165 |
+
"What's the latest update on the cloud service outage?",
|
| 166 |
+
"I need specifics on the API rate limits for the enterprise plan.",
|
| 167 |
+
"The latency on your servers is unacceptable. What's being done about it?",
|
| 168 |
+
"Can you explain the technical details of your encryption method?"
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
logger.info(f"Running simulation with custom persona: {custom_persona_name} using {ai_name}")
|
| 172 |
+
custom_conversation = simulator.run_custom_simulation(
|
| 173 |
+
custom_persona_name,
|
| 174 |
+
custom_characteristics,
|
| 175 |
+
custom_message_templates,
|
| 176 |
+
num_turns=3
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
output_content += f"\nConversation with {custom_persona_name}:\n"
|
| 180 |
+
for message in custom_conversation:
|
| 181 |
+
output_line = f"{message['role'].capitalize()}: {message['content']}\n"
|
| 182 |
+
output_content += output_line
|
| 183 |
+
logger.debug(output_line.strip())
|
| 184 |
+
|
| 185 |
+
# Save the output to a file
|
| 186 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 187 |
+
output_directory = "output"
|
| 188 |
+
os.makedirs(output_directory, exist_ok=True)
|
| 189 |
+
output_file = os.path.join(output_directory, f"{ai_name.lower()}_conversation_output_{timestamp}.txt")
|
| 190 |
+
save_output(output_content, output_file)
|
| 191 |
+
logger.info(f"Output saved to {output_file}")
|
| 192 |
+
|
| 193 |
+
return output_content
|
| 194 |
+
|
| 195 |
+
print("Main simulation function defined.")
|
| 196 |
+
|
| 197 |
+
## Run the Simulation
|
| 198 |
+
|
| 199 |
+
simulation_output = run_simulation()
|
| 200 |
+
print(simulation_output)
|
| 201 |
+
|
| 202 |
+
## Analyze the Results
|
| 203 |
+
|
| 204 |
+
# Example analysis: Count the number of apologies
|
| 205 |
+
apology_count = simulation_output.lower().count("sorry") + simulation_output.lower().count("apologi")
|
| 206 |
+
print(f"Number of apologies: {apology_count}")
|
| 207 |
+
|
| 208 |
+
# Example analysis: Average length of AI responses
|
| 209 |
+
ai_responses = [line.split(": ", 1)[1] for line in simulation_output.split("\n") if line.startswith("Assistant: ")]
|
| 210 |
+
avg_response_length = sum(len(response.split()) for response in ai_responses) / len(ai_responses)
|
| 211 |
+
print(f"Average length of AI responses: {avg_response_length:.2f} words")
|
| 212 |
+
|
| 213 |
+
## Conclusion
|
| 214 |
+
|
| 215 |
+
# This notebook demonstrates how to use the Conversation Simulator from the isopro package.
|
| 216 |
+
# You can modify the personas, adjust the number of turns, or add your own analysis to
|
| 217 |
+
# further explore the capabilities of the AI models in customer service scenarios.
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
## Output and Logs
|
| 221 |
+
|
| 222 |
+
- Simulation outputs are saved in the `output` directory within your current working directory.
|
| 223 |
+
- Logs are saved in the `logs` directory within your current working directory.
|
| 224 |
+
|
| 225 |
+
## Customizing the Simulation
|
| 226 |
+
|
| 227 |
+
You can customize the simulation by modifying the `main.py` file or the Jupyter notebook:
|
| 228 |
+
|
| 229 |
+
- To change the predefined personas, modify the `personas` list.
|
| 230 |
+
- To adjust the custom persona, modify the `custom_persona_name`, `custom_characteristics`, and `custom_message_templates` variables.
|
| 231 |
+
- To change the number of turns in each conversation, modify the `num_turns` parameter in the `run_simulation` and `run_custom_simulation` method calls.
|
| 232 |
+
|
| 233 |
+
In the Jupyter notebook, you can also add new cells for additional analysis or visualization of the results.
|
| 234 |
+
|
| 235 |
+
## Troubleshooting
|
| 236 |
+
|
| 237 |
+
If you encounter any issues:
|
| 238 |
+
|
| 239 |
+
1. Make sure your API keys are correctly set in the `.env` file or environment variables.
|
| 240 |
+
2. Check the logs in the `logs` directory for detailed error messages.
|
| 241 |
+
3. Ensure you have the latest version of the `isopro` package installed.
|
| 242 |
+
4. For Jupyter notebook issues, make sure you have Jupyter installed and are running the notebook from the correct directory.
|
| 243 |
+
|
| 244 |
+
If problems persist, please open an issue in the project repository.
|
| 245 |
+
|
| 246 |
+
## Contributing
|
| 247 |
+
|
| 248 |
+
Contributions to the Conversation Simulator are welcome. Please feel free to submit a Pull Request to the `isopro` repository.
|
| 249 |
+
|
| 250 |
+
## License
|
| 251 |
+
|
| 252 |
+
This project is licensed under the MIT License - see the LICENSE file in the `isopro` package for details.
|
isopro/conversation_simulation/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation Simulation Module
|
| 3 |
+
|
| 4 |
+
This module provides tools for simulating conversations with AI agents.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from .conversation_environment import ConversationEnvironment
|
| 8 |
+
from .conversation_agent import ConversationAgent
|
| 9 |
+
from .user_personas import UserPersona
|
| 10 |
+
from .custom_persona import create_custom_persona
|
| 11 |
+
from .conversation_simulator import ConversationSimulator
|
| 12 |
+
|
| 13 |
+
__all__ = [
|
| 14 |
+
"ConversationEnvironment",
|
| 15 |
+
"ConversationAgent",
|
| 16 |
+
"UserPersona",
|
| 17 |
+
"create_custom_persona",
|
| 18 |
+
"ConversationSimulator",
|
| 19 |
+
]
|
isopro/conversation_simulation/conversation_agent.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation Agent
|
| 3 |
+
|
| 4 |
+
This module defines the AI agent used in the conversation simulation, using Anthropic's Claude API.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import anthropic
|
| 8 |
+
import os
|
| 9 |
+
import logging
|
| 10 |
+
from ..agents.ai_agent import AI_Agent
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
load_dotenv()
|
| 16 |
+
|
| 17 |
+
class ConversationAgent(AI_Agent):
    """Claude-backed conversational AI agent.

    Wraps the Anthropic Messages API behind a simple
    ``generate_response(conversation_history)`` interface.
    """

    def __init__(self, name, prompt, model="claude-3-opus-20240229"):
        """Create the agent and its Anthropic client.

        Args:
            name (str): Display name forwarded to the AI_Agent base class.
            prompt (str): System prompt guiding the assistant's behavior.
            model (str): Claude model identifier to query.
        """
        super().__init__(name)
        self.prompt = prompt
        self.model = model
        # API key comes from the environment (loaded via dotenv at import time).
        self.client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
        logger.info(f"Initialized ConversationAgent '{name}' with Claude model {model}")

    def generate_response(self, conversation_history):
        """Produce the assistant's next message for the given history.

        Args:
            conversation_history (list): Dicts with "role" and "content" keys.

        Returns:
            str: The generated reply, or a fixed apology string if any
            error occurs while calling the API.
        """
        try:
            # The Messages API accepts only "user"/"assistant" roles; fold
            # any other role into "user".
            messages = []
            for msg in conversation_history:
                role = "assistant" if msg["role"] == "assistant" else "user"
                messages.append({"role": role, "content": msg["content"]})

            response = self.client.messages.create(
                model=self.model,
                max_tokens=1000,
                system=self.prompt,
                messages=messages,
            )
            ai_message = response.content[0].text.strip()
            logger.debug(f"Generated response: {ai_message}")
            return ai_message
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            return "I apologize, but I'm having trouble responding at the moment."
|
isopro/conversation_simulation/conversation_environment.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation Environment
|
| 3 |
+
|
| 4 |
+
This module defines the environment for simulating conversations between a Claude-based AI agent and users with various personas.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from ..environments.simulation_environment import SimulationEnvironment
|
| 9 |
+
from .conversation_agent import ConversationAgent
|
| 10 |
+
from .user_personas import UserPersona
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class ConversationEnvironment(SimulationEnvironment):
    """Simulation environment pairing a Claude-based agent with a user persona.

    Owns exactly one AI agent and one user persona and alternates turns
    between them to produce a conversation transcript.
    """

    def __init__(self, ai_prompt="You are a helpful customer service agent. Respond politely and professionally."):
        """Initialize the environment.

        Args:
            ai_prompt (str): System prompt guiding the AI agent's behavior.
        """
        super().__init__()
        self.ai_prompt = ai_prompt
        # Both participants start unset; configure via set_ai_agent /
        # set_user_persona before calling run_conversation.
        self.ai_agent = None
        self.user_persona = None
        logger.info("Initialized ConversationEnvironment")

    def set_ai_agent(self, model="claude-3-opus-20240229"):
        """Create the Claude-backed agent used in the conversation.

        Args:
            model (str): Claude model identifier to use.
        """
        self.ai_agent = ConversationAgent("Customer Service AI", self.ai_prompt, model)
        logger.info(f"Set AI agent with Claude model: {model}")

    def set_user_persona(self, persona_type, **kwargs):
        """Select the simulated user for the conversation.

        Args:
            persona_type (str): Registered persona type name.
            **kwargs: Extra options forwarded to the persona factory.
        """
        self.user_persona = UserPersona.create(persona_type, **kwargs)
        logger.info(f"Set user persona: {persona_type}")

    def run_conversation(self, num_turns=5):
        """Alternate persona and agent messages for *num_turns* exchanges.

        Args:
            num_turns (int): Number of user/assistant exchange rounds.

        Returns:
            list: Conversation history as {"role", "content"} dicts.

        Raises:
            ValueError: If either participant has not been configured.
        """
        if not self.ai_agent or not self.user_persona:
            raise ValueError("Both AI agent and user persona must be set before running a conversation.")

        history = []
        for _ in range(num_turns):
            # Persona speaks first each round, seeing the full history so far.
            user_message = self.user_persona.generate_message(history)
            history.append({"role": "user", "content": user_message})
            logger.debug(f"User: {user_message}")

            ai_response = self.ai_agent.generate_response(history)
            history.append({"role": "assistant", "content": ai_response})
            logger.debug(f"AI: {ai_response}")

        logger.info("Completed conversation simulation")
        return history
|
isopro/conversation_simulation/conversation_simulator.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation Simulator
|
| 3 |
+
|
| 4 |
+
This module provides a high-level interface for running conversation simulations
|
| 5 |
+
with different personas and analyzing the results using Anthropic's Claude API.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from .conversation_environment import ConversationEnvironment
|
| 10 |
+
from .custom_persona import create_custom_persona
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class ConversationSimulator:
    """High-level driver for conversation simulations using Claude.

    Thin facade over ConversationEnvironment: it configures the agent and
    a persona, runs the exchange, and returns the transcript.
    """

    def __init__(self, ai_prompt="You are a helpful customer service agent. Respond politely and professionally."):
        """Initialize the simulator.

        Args:
            ai_prompt (str): System prompt guiding the Claude-based agent.
        """
        self.environment = ConversationEnvironment(ai_prompt)
        logger.info("Initialized ConversationSimulator with Claude")

    def run_simulation(self, persona_type, num_turns=5, claude_model="claude-3-opus-20240229", **persona_kwargs):
        """Simulate a conversation with a predefined persona type.

        Args:
            persona_type (str): Registered persona type to simulate.
            num_turns (int): Number of conversation turns.
            claude_model (str): Claude model identifier to use.
            **persona_kwargs: Extra options for the persona factory.

        Returns:
            list: Conversation history as {"role", "content"} dicts.
        """
        self.environment.set_ai_agent(model=claude_model)
        self.environment.set_user_persona(persona_type, **persona_kwargs)
        history = self.environment.run_conversation(num_turns)
        logger.info(f"Completed simulation with {persona_type} persona using Claude model {claude_model}")
        return history

    def run_custom_simulation(self, name, characteristics, message_templates, num_turns=5, claude_model="claude-3-opus-20240229"):
        """Simulate a conversation with a caller-defined persona.

        Args:
            name (str): Name of the custom persona.
            characteristics (list): Traits defining the persona.
            message_templates (list): Messages the persona may send.
            num_turns (int): Number of conversation turns.
            claude_model (str): Claude model identifier to use.

        Returns:
            list: Conversation history as {"role", "content"} dicts.
        """
        persona = create_custom_persona(name, characteristics, message_templates)
        self.environment.set_ai_agent(model=claude_model)
        # Assigned directly: custom personas are instances rather than
        # registered type names, so set_user_persona does not apply.
        self.environment.user_persona = persona
        history = self.environment.run_conversation(num_turns)
        logger.info(f"Completed simulation with custom persona: {name} using Claude model {claude_model}")
        return history
|
isopro/conversation_simulation/custom_persona.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom Persona
|
| 3 |
+
|
| 4 |
+
This module allows users to create custom personas for the conversation simulation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
from .user_personas import UserPersona
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class CustomPersona(UserPersona):
    """User persona defined entirely by caller-supplied data.

    Unlike the built-in personas, behavior comes from the characteristics
    and message templates provided at construction time.
    """

    def __init__(self, name, characteristics, message_templates):
        """Initialize the persona.

        Args:
            name (str): Persona display name.
            characteristics (list): Traits describing the persona.
            message_templates (list): Candidate messages to send.
        """
        super().__init__(name)
        self.characteristics = characteristics
        self.message_templates = message_templates
        logger.info(f"Created CustomPersona: {name}")

    def generate_message(self, conversation_history):
        """Pick the persona's next message.

        Args:
            conversation_history (list): Prior messages (currently unused;
                a template is chosen independently of context).

        Returns:
            str: A message drawn uniformly at random from the templates.
        """
        import random
        chosen = random.choice(self.message_templates)
        logger.debug(f"CustomPersona '{self.name}' generated message: {chosen}")
        return chosen
|
| 45 |
+
|
| 46 |
+
def create_custom_persona(name, characteristics, message_templates):
    """Factory helper for building a CustomPersona.

    Args:
        name (str): Persona display name.
        characteristics (list): Traits describing the persona.
        message_templates (list): Candidate messages the persona may send.

    Returns:
        CustomPersona: The newly constructed persona instance.
    """
    persona = CustomPersona(name, characteristics, message_templates)
    return persona
|
isopro/conversation_simulation/main.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from logging.handlers import RotatingFileHandler
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from .conversation_simulator import ConversationSimulator
|
| 7 |
+
from .custom_persona import create_custom_persona
|
| 8 |
+
|
| 9 |
+
# Load environment variables (API keys are expected in a .env file)
load_dotenv()

# Set up logging: everything goes to a rotating file; INFO+ also to console
log_directory = "logs"
os.makedirs(log_directory, exist_ok=True)
log_file = os.path.join(log_directory, "conversation_simulator.log")

# Create a rotating file handler (1 MB per file, 5 backups kept)
file_handler = RotatingFileHandler(log_file, maxBytes=1024*1024, backupCount=5)
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(file_formatter)

# Create a console handler (INFO and above only, to keep terminal output readable)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler.setFormatter(console_formatter)

# Set up the root logger so module-level loggers inherit both handlers
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
|
| 34 |
+
|
| 35 |
+
def save_output(content, filename):
    """Write *content* to *filename*, encoded as UTF-8 (overwrites existing files)."""
    with open(filename, 'w', encoding='utf-8') as out_file:
        out_file.write(content)
|
| 39 |
+
|
| 40 |
+
def get_user_choice():
    """Prompt repeatedly until the user names a supported AI model.

    Returns:
        str: Either 'claude' or 'openai' (lowercased).
    """
    valid_choices = ('claude', 'openai')
    while True:
        selection = input("Choose AI model (claude/openai): ").lower()
        if selection in valid_choices:
            return selection
        print("Invalid choice. Please enter 'claude' or 'openai'.")
|
| 47 |
+
|
| 48 |
+
def main():
    """Run conversation simulations with built-in and custom personas.

    Prompts for the AI backend (Claude or OpenAI), verifies the matching API
    key is present, runs one 3-turn simulation per built-in persona plus one
    custom "Techie Customer" persona, and writes a timestamped transcript
    under ./output.

    Raises:
        RuntimeError: If the API key for the chosen backend is not set.
    """
    # Get user's choice of AI model
    ai_choice = get_user_choice()

    # Resolve the model name and the environment variable holding its key.
    if ai_choice == 'claude':
        model = "claude-3-opus-20240229"
        key_name = "ANTHROPIC_API_KEY"
        ai_name = "Claude"
    else:  # openai
        model = "gpt-4-1106-preview"
        key_name = "OPENAI_API_KEY"
        ai_name = "GPT-4 Turbo"

    # Fail fast with a clear message. The previous code did
    # `os.environ[key] = os.getenv(key)`, which raised a confusing TypeError
    # when the variable was missing (os.environ values must be str) and was
    # a pointless self-assignment when it was present.
    if not os.getenv(key_name):
        raise RuntimeError(
            f"{key_name} is not set. Add it to your environment or .env file."
        )

    # Initialize the ConversationSimulator
    simulator = ConversationSimulator(
        ai_prompt=f"You are {ai_name}, an AI assistant created to be helpful, harmless, and honest. You are a customer service agent for a tech company. Respond politely and professionally."
    )

    output_content = f"Conversation Simulator using {ai_name} model: {model}\n\n"

    # Run simulations with each built-in persona.
    personas = ["upset", "human_request", "inappropriate", "incomplete_info"]

    for persona in personas:
        logger.info(f"Running simulation with {persona} persona using {ai_name}")
        conversation_history = simulator.run_simulation(persona, num_turns=3)

        output_content += f"\nConversation with {persona} persona:\n"
        for message in conversation_history:
            output_line = f"{message['role'].capitalize()}: {message['content']}\n"
            output_content += output_line
            logger.debug(output_line.strip())
        output_content += "\n" + "-"*50 + "\n"

    # Create and run a simulation with a custom persona.
    custom_persona_name = "Techie Customer"
    custom_characteristics = ["tech-savvy", "impatient", "detail-oriented"]
    custom_message_templates = [
        "I've tried rebooting my device, but the error persists. Can you help?",
        "What's the latest update on the cloud service outage?",
        "I need specifics on the API rate limits for the enterprise plan.",
        "The latency on your servers is unacceptable. What's being done about it?",
        "Can you explain the technical details of your encryption method?"
    ]

    logger.info(f"Running simulation with custom persona: {custom_persona_name} using {ai_name}")
    custom_conversation = simulator.run_custom_simulation(
        custom_persona_name,
        custom_characteristics,
        custom_message_templates,
        num_turns=3
    )

    output_content += f"\nConversation with {custom_persona_name}:\n"
    for message in custom_conversation:
        output_line = f"{message['role'].capitalize()}: {message['content']}\n"
        output_content += output_line
        logger.debug(output_line.strip())

    # Save the transcript to a timestamped file under ./output.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_directory = "output"
    os.makedirs(output_directory, exist_ok=True)
    output_file = os.path.join(output_directory, f"{ai_name.lower()}_conversation_output_{timestamp}.txt")
    save_output(output_content, output_file)
    logger.info(f"Output saved to {output_file}")

if __name__ == "__main__":
    main()
|
isopro/conversation_simulation/user_personas.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User Personas
|
| 3 |
+
|
| 4 |
+
This module defines various user personas for the conversation simulation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import random
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class UserPersona:
    """Base class for user personas in the conversation simulation.

    Subclasses supply concrete message-generation behavior; this base class
    also exposes a factory (:meth:`create`) mapping persona-type strings to
    their implementing classes.
    """

    def __init__(self, name):
        # Human-readable persona name, e.g. "Upset Customer".
        self.name = name

    def generate_message(self, conversation_history):
        """Produce the persona's next message.

        Args:
            conversation_history (list): Dictionaries describing the
                conversation so far.

        Returns:
            str: The generated message.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses must implement generate_message method")

    @staticmethod
    def create(persona_type, **kwargs):
        """Factory method that builds a persona from its type string.

        Args:
            persona_type (str): One of "upset", "human_request",
                "inappropriate" or "incomplete_info".
            **kwargs: Forwarded to the persona's constructor.

        Returns:
            UserPersona: An instance of the requested persona class.

        Raises:
            ValueError: If *persona_type* is not a known persona.
        """
        registry = {
            "upset": UpsetCustomer,
            "human_request": HumanRequestCustomer,
            "inappropriate": InappropriateCustomer,
            "incomplete_info": IncompleteInfoCustomer,
        }
        persona_cls = registry.get(persona_type)
        if persona_cls is None:
            raise ValueError(f"Unknown persona type: {persona_type}")
        return persona_cls(**kwargs)
|
| 55 |
+
|
| 56 |
+
class UpsetCustomer(UserPersona):
    """Persona simulating an angry customer who vents frustration."""

    def __init__(self):
        super().__init__("Upset Customer")
        # Canned complaints; one is drawn at random each turn.
        self.complaints = [
            "This is unacceptable!",
            "I've been waiting for hours!",
            "I want to speak to your manager!",
            "This is the worst service I've ever experienced!",
            "I'm extremely disappointed with your company!",
        ]

    def generate_message(self, conversation_history):
        """Return a randomly chosen complaint (the history is ignored)."""
        chosen = random.choice(self.complaints)
        logger.debug(f"UpsetCustomer generated message: {chosen}")
        return chosen
|
| 71 |
+
|
| 72 |
+
class HumanRequestCustomer(UserPersona):
    """Persona that repeatedly asks to be transferred to a human agent."""

    def __init__(self):
        super().__init__("Human Request Customer")
        # Canned escalation requests; one is drawn at random each turn.
        self.requests = [
            "Can I speak to a human representative?",
            "I don't want to talk to a bot. Get me a real person.",
            "Is there a way to talk to an actual employee?",
            "I need to speak with a human agent, not an AI.",
            "Please transfer me to a live representative.",
        ]

    def generate_message(self, conversation_history):
        """Return a randomly chosen escalation request (the history is ignored)."""
        chosen = random.choice(self.requests)
        logger.debug(f"HumanRequestCustomer generated message: {chosen}")
        return chosen
|
| 87 |
+
|
| 88 |
+
class InappropriateCustomer(UserPersona):
    """Persona that uses placeholder profanity to probe moderation handling."""

    def __init__(self):
        super().__init__("Inappropriate Customer")
        # Placeholder tokens stand in for actual inappropriate language.
        self.inappropriate_words = ["[INAPPROPRIATE1]", "[INAPPROPRIATE2]", "[INAPPROPRIATE3]"]

    def generate_message(self, conversation_history):
        """Return an insulting message built from two random placeholder tokens."""
        first = random.choice(self.inappropriate_words)
        second = random.choice(self.inappropriate_words)
        message = f"You're a {first} and this service is {second}!"
        logger.debug(f"InappropriateCustomer generated message: {message}")
        return message
|
| 97 |
+
|
| 98 |
+
class IncompleteInfoCustomer(UserPersona):
    """Persona that gives vague requests lacking actionable detail."""

    def __init__(self):
        super().__init__("Incomplete Info Customer")
        # Deliberately underspecified requests; one is drawn at random each turn.
        self.vague_requests = [
            "I need help with my account.",
            "There's a problem with my order.",
            "Something's not working right.",
            "I have a question about your service.",
            "Can you check on the status of my thing?",
        ]

    def generate_message(self, conversation_history):
        """Return a randomly chosen vague request (the history is ignored)."""
        chosen = random.choice(self.vague_requests)
        logger.debug(f"IncompleteInfoCustomer generated message: {chosen}")
        return chosen
|
isopro/environments/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Environment classes for the isopro package.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .simulation_environment import SimulationEnvironment
|
| 6 |
+
from .custom_environment import CustomEnvironment
|
| 7 |
+
from .llm_orchestrator import LLMOrchestrator
|
| 8 |
+
|
| 9 |
+
__all__ = ["SimulationEnvironment", "CustomEnvironment", "LLMOrchestrator"]
|
isopro/environments/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (432 Bytes). View file
|
|
|
isopro/environments/__pycache__/custom_environment.cpython-38.pyc
ADDED
|
Binary file (4.18 kB). View file
|
|
|
isopro/environments/__pycache__/llm_orchestrator.cpython-38.pyc
ADDED
|
Binary file (7.06 kB). View file
|
|
|
isopro/environments/__pycache__/simulation_environment.cpython-38.pyc
ADDED
|
Binary file (2.04 kB). View file
|
|
|
isopro/environments/custom_environment.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Custom Environment for creating user-defined simulation environments."""
|
| 2 |
+
from ..environments.simulation_environment import SimulationEnvironment
|
| 3 |
+
from ..agents.ai_agent import AI_Agent
|
| 4 |
+
from ..base.base_component import BaseComponent, agent_component
|
| 5 |
+
|
| 6 |
+
class CustomAgent(AI_Agent):
    """Template agent for user-defined behavior.

    Extend this class and override :meth:`run` to implement a custom agent.
    """

    def __init__(self, name, custom_param):
        """Initialize the agent.

        Args:
            name (str): The agent's name.
            custom_param: Arbitrary user-supplied configuration value.
        """
        super().__init__(name)
        self.custom_param = custom_param

    def run(self, input_data):
        """Log the invocation, then delegate processing to the base agent.

        Args:
            input_data (dict): Input payload for the agent.

        Returns:
            dict: Whatever the base class implementation returns.
        """
        self.logger.info(f"Running custom agent: {self.name} with parameter: {self.custom_param}")
        # Hook point: subclasses insert custom behavior here before delegating.
        return super().run(input_data)
|
| 36 |
+
|
| 37 |
+
@agent_component
class CustomComponent(BaseComponent):
    """Template component for user-defined behavior.

    Extend this class and override :meth:`run` to implement a custom
    processing step.
    """

    def __init__(self, name, custom_param):
        """Initialize the component.

        Args:
            name (str): The component's name.
            custom_param: Arbitrary user-supplied configuration value.
        """
        super().__init__(name)
        self.custom_param = custom_param

    def run(self, input_data):
        """Log the invocation and pass *input_data* through unchanged.

        Args:
            input_data (dict): Input payload for the component.

        Returns:
            dict: The same *input_data* object (identity pass-through).
        """
        self.logger.info(f"Running custom component: {self.name} with parameter: {self.custom_param}")
        # Hook point: subclasses insert custom behavior here.
        return input_data
|
| 68 |
+
|
| 69 |
+
class CustomEnvironment(SimulationEnvironment):
    """Template simulation environment pre-populated with custom agents.

    On construction the environment registers *num_agents* CustomAgent
    instances, each carrying one CustomComponent; additional agents can be
    added later via :meth:`add_custom_agent`.
    """

    def __init__(self, num_agents=1, custom_param=None):
        """Initialize the environment.

        Args:
            num_agents (int): How many agents to create up front.
            custom_param: Shared configuration value handed to every agent
                and component.
        """
        super().__init__()
        self.num_agents = num_agents
        self.custom_param = custom_param
        self._create_custom_agents()

    def _create_custom_agents(self):
        """Build the initial agents (one component each) and register them."""
        # NOTE: initial components are named "Custom Component N", whereas
        # add_custom_agent names them "Component for <agent>" — kept as-is.
        for index in range(1, self.num_agents + 1):
            agent = CustomAgent(name=f"Custom Agent {index}", custom_param=self.custom_param)
            agent.add_component(
                CustomComponent(name=f"Custom Component {index}", custom_param=self.custom_param)
            )
            self.add_agent(agent)

    def add_custom_agent(self, agent_name, custom_param):
        """Create and register one additional agent with its own component.

        Args:
            agent_name (str): Name for the new agent.
            custom_param: Configuration value for the agent and its component.
        """
        agent = CustomAgent(name=agent_name, custom_param=custom_param)
        agent.add_component(
            CustomComponent(name=f"Component for {agent_name}", custom_param=custom_param)
        )
        self.add_agent(agent)
|
isopro/environments/llm_orchestrator.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Orchestrator for managing and executing LLM components in various modes.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import heapq
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
from typing import List, Any, Optional, Callable
|
| 9 |
+
from ..base.base_component import BaseComponent
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
class ComponentException(Exception):
    """Raised when an orchestrated component is invalid or fails to run."""
|
| 16 |
+
|
| 17 |
+
class LLMOrchestrator:
    """Coordinates execution of LLM components.

    Three execution modes are supported: 'sequence' (each component's output
    feeds the next), 'parallel' (all components receive the same input
    concurrently), and 'node' (components run in priority order, with
    priorities supplied by a user-provided function).
    """

    def __init__(self):
        """Create an orchestrator with no components and no priority function."""
        # Registered components, executed in registration order (sequence /
        # parallel) or by priority (node mode). Holds BaseComponent instances.
        self.components = []
        # Optional callable (component, data) -> int; lower values run first.
        self.priority_function = None

    def add_component(self, component: "BaseComponent") -> None:
        """Register *component* for orchestration.

        Args:
            component (BaseComponent): The component to register.

        Raises:
            ValueError: If *component* is None or not a BaseComponent instance.
        """
        if component is None:
            raise ValueError("Cannot add None as a component")
        if not isinstance(component, BaseComponent):
            raise ValueError(f"Only BaseComponent instances can be added, got {type(component)}")
        self.components.append(component)

    def set_priority_function(self, priority_func: "Callable[[BaseComponent, Any], int]") -> None:
        """Install the priority function used by node-mode execution.

        Args:
            priority_func (Callable[[BaseComponent, Any], int]): Maps a
                component and its input data to an integer priority.
        """
        self.priority_function = priority_func

    def run_orchestration(self, mode: str = 'sequence', input_data: "Optional[Any]" = None) -> "List[Any]":
        """Execute every registered component in the requested mode.

        Args:
            mode (str): 'sequence', 'parallel', or 'node'.
            input_data (Any, optional): Initial input for the components.

        Returns:
            List[Any]: Results (or error strings) from the executions; empty
            when no components are registered.

        Raises:
            ValueError: If *mode* is not a recognized execution mode.
        """
        if not self.components:
            logger.warning("No components to run")
            return []

        dispatch = {
            'sequence': self._run_in_sequence,
            'parallel': self._run_in_parallel,
            'node': self._run_as_node,
        }
        runner = dispatch.get(mode)
        if runner is None:
            raise ValueError("Invalid execution mode")
        return runner(input_data)

    def _run_in_sequence(self, input_data: "Any") -> "List[Any]":
        """Run components one after another, piping each output into the next.

        A failing component contributes its error text to the results; the
        pipeline then continues, with the next component receiving the last
        successful output (or the original input).

        Args:
            input_data (Any): Input for the first component.

        Returns:
            List[Any]: One entry per component, in registration order.
        """
        logger.info("Running in sequence mode")
        outputs = []
        payload = input_data

        for component in self.components:
            try:
                produced = self._run_component(component, payload)
            except ComponentException as exc:
                logger.error(f"Error: {exc}")
                outputs.append(str(exc))
            else:
                outputs.append(produced)
                payload = produced  # feed this output into the next component

        return outputs

    def _run_in_parallel(self, input_data: "Any") -> "List[Any]":
        """Run every component concurrently with the same input.

        Results are collected in registration order; a failing component
        contributes its error text instead of a result.

        Args:
            input_data (Any): Input handed to every component.

        Returns:
            List[Any]: One entry per component, in registration order.
        """
        logger.info("Running in parallel mode")
        outputs = []

        with ThreadPoolExecutor() as pool:
            pending = [pool.submit(self._run_component, component, input_data)
                       for component in self.components]

            for task in pending:
                try:
                    outputs.append(task.result())
                except ComponentException as exc:
                    logger.error(f"Error: {exc}")
                    outputs.append(str(exc))

        return outputs

    def _run_as_node(self, input_data: "Any") -> "List[Any]":
        """Run components ordered by priority (lowest value first).

        Priorities come from the installed priority function; without one,
        every component gets priority 0 and registration order is kept via
        the heap tiebreaker. After a component runs, its priority is
        re-evaluated against its result; if the value changed, the component
        is re-queued and will execute again.

        Args:
            input_data (Any): Input handed to every component on each run.

        Returns:
            List[Any]: Results (or error strings), in execution order.
        """
        logger.info("Running in node mode (priority-based)")
        outputs = []

        if self.priority_function is None:
            logger.warning("No priority function set. Using default priority (0) for all components.")
            queue = [(0, index, component) for index, component in enumerate(self.components)]
        else:
            queue = [(self.priority_function(component, input_data), index, component)
                     for index, component in enumerate(self.components)]

        heapq.heapify(queue)

        while queue:
            priority, _, component = heapq.heappop(queue)
            logger.info(f"Running component {component} with priority {priority}")
            try:
                produced = self._run_component(component, input_data)
                outputs.append(produced)

                # Re-evaluate the priority against the result; a changed value
                # re-queues the component (which will therefore run again).
                if self.priority_function:
                    new_priority = self.priority_function(component, produced)
                    if new_priority != priority:
                        heapq.heappush(queue, (new_priority, len(outputs), component))
                        logger.info(f"Updated priority for component {component}: {priority} -> {new_priority}")

            except ComponentException as exc:
                logger.error(f"Error: {exc}")
                outputs.append(str(exc))

        return outputs

    def _run_component(self, component: "BaseComponent", input_data: "Any") -> "Any":
        """Invoke a single component's ``run`` method.

        Args:
            component (BaseComponent): The component to execute.
            input_data (Any): Input for the component.

        Returns:
            Any: Whatever the component's ``run`` returns.

        Raises:
            ComponentException: If the component lacks a callable ``run``.
        """
        run_method = getattr(component, 'run', None)
        if not callable(run_method):
            raise ComponentException(f"Component {component} does not have a callable 'run' method")
        return run_method(input_data)
|