# Spaces:
# Running
# Running
import asyncio
from typing import AsyncGenerator, Awaitable, Callable

from pydantic import BaseModel, ConfigDict, Field

from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
class SubscriptionRunner(BaseModel):
    """A simple wrapper to manage subscription tasks for different roles using asyncio.

    Each subscription is an ``asyncio.Task`` stored in ``tasks`` under its role.
    The task feeds every message yielded by the trigger to ``role.run`` and
    awaits the callback with the response.

    Example:
        >>> import asyncio
        >>> from metagpt.address import SubscriptionRunner
        >>> from metagpt.roles import Searcher
        >>> from metagpt.schema import Message

        >>> async def trigger():
        ...     while True:
        ...         yield Message(content="the latest news about OpenAI")
        ...         await asyncio.sleep(3600 * 24)

        >>> async def callback(msg: Message):
        ...     print(msg.content)

        >>> async def main():
        ...     pb = SubscriptionRunner()
        ...     await pb.subscribe(Searcher(), trigger(), callback)
        ...     await pb.run()

        >>> asyncio.run(main())
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Maps each subscribed role to the asyncio task that drives its subscription.
    tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)

    async def subscribe(
        self,
        role: Role,
        trigger: AsyncGenerator[Message, None],
        callback: Callable[
            [
                Message,
            ],
            Awaitable[None],
        ],
    ):
        """Subscribes a role to a trigger and sets up a callback to be called with the role's response.

        A task is created on the running event loop and stored in ``tasks``
        under ``role``. If ``role`` already has an entry, it is overwritten;
        the previously stored task is not cancelled by this method.

        Args:
            role: The role to subscribe.
            trigger: An asynchronous generator that yields Messages to be processed by the role.
            callback: An asynchronous function to be called with the response from the role.
        """
        loop = asyncio.get_running_loop()

        async def _start_role():
            # Messages are handled one at a time: the next message is not
            # pulled from the trigger until the callback has finished.
            async for msg in trigger:
                resp = await role.run(msg)
                await callback(resp)

        self.tasks[role] = loop.create_task(_start_role(), name=f"Subscription-{role}")

    async def unsubscribe(self, role: Role):
        """Unsubscribes a role from its trigger and cancels the associated task.

        Args:
            role: The role to unsubscribe.

        Raises:
            KeyError: If ``role`` has no entry in ``tasks``.
        """
        # Remove the entry before cancelling so that run() never observes the
        # cancelled task.
        task = self.tasks.pop(role)
        task.cancel()

    async def run(self, raise_exception: bool = True):
        """Runs all subscribed tasks and handles their completion or exception.

        Polls ``tasks`` in an endless loop. A finished task is logged and
        removed from ``tasks``; when no task has finished, the loop sleeps for
        one second before polling again. This coroutine has no normal return
        path: it ends only when an exception propagates out of it.

        Args:
            raise_exception: If True, the exception of a failed task is
                re-raised from this method (the task is left in ``tasks``).
                If False, the failure is logged and the task is removed.
                Defaults to True.

        Raises:
            BaseException: The exception stored on a failed task, when
                ``raise_exception`` is True.
        """
        while True:
            for role, task in self.tasks.items():
                if not task.done():
                    continue

                if task.cancelled():
                    # Task.exception() raises CancelledError for a cancelled
                    # task; check first so a task cancelled from outside does
                    # not abort the whole runner.
                    logger.warning(f"Task {task.get_name()} was cancelled.")
                else:
                    exc = task.exception()
                    if exc is not None:
                        if raise_exception:
                            raise exc
                        logger.opt(exception=exc).error(f"Task {task.get_name()} run error")
                    else:
                        logger.warning(
                            f"Task {task.get_name()} has completed. "
                            "If this is unexpected behavior, please check the trigger function."
                        )

                # The dict is mutated during iteration, so leave the for-loop
                # immediately and rescan from the top of the while-loop.
                self.tasks.pop(role)
                break
            else:
                # No task finished during this pass; wait before polling again.
                await asyncio.sleep(1)