From 6049aada71bae0cc5515a20d2f1a35b395c91a14 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 11 Sep 2024 13:57:39 -0700
Subject: [PATCH] migrate agentic system

---
 llama_toolchain/agentic_system/api/api.py  | 11 +++++++++-
 llama_toolchain/agentic_system/client.py   |  6 ++----
 .../meta_reference/agent_instance.py       |  2 +-
 .../meta_reference/agentic_system.py       | 20 ++++++++++++++++++-
 .../inference/meta_reference/inference.py  |  7 +++++++
 5 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/llama_toolchain/agentic_system/api/api.py b/llama_toolchain/agentic_system/api/api.py
index b8be54861..95af3727b 100644
--- a/llama_toolchain/agentic_system/api/api.py
+++ b/llama_toolchain/agentic_system/api/api.py
@@ -416,7 +416,16 @@ class AgenticSystem(Protocol):
     @webmethod(route="/agentic_system/turn/create")
     async def create_agentic_system_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        agent_id: str,
+        session_id: str,
+        messages: List[
+            Union[
+                UserMessage,
+                ToolResponseMessage,
+            ]
+        ],
+        attachments: Optional[List[Attachment]] = None,
+        stream: Optional[bool] = False,
     ) -> AgenticSystemTurnResponseStreamChunk: ...
 
     @webmethod(route="/agentic_system/turn/get")
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index b47e402f0..a1ba6cb48 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -73,9 +73,7 @@ class AgenticSystemClient(AgenticSystem):
         async with client.stream(
             "POST",
             f"{self.base_url}/agentic_system/turn/create",
-            json={
-                "request": encodable_dict(request),
-            },
+            json=encodable_dict(request),
             headers={"Content-Type": "application/json"},
             timeout=20,
         ) as response:
@@ -134,7 +132,7 @@ async def run_main(host: str, port: int):
     api = AgenticSystemClient(f"http://{host}:{port}")
 
     tool_definitions = [
-        SearchToolDefinition(engine=SearchEngineType.bing),
+        SearchToolDefinition(engine=SearchEngineType.brave),
         WolframAlphaToolDefinition(),
         CodeInterpreterToolDefinition(),
     ]
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 36c3d19e8..f9a2d20dd 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -400,7 +400,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_calls = []
         content = ""
         stop_reason = None
-        async for chunk in self.inference_api.chat_completion(req):
+        async for chunk in self.inference_api.chat_completion_impl(req):
             event = chunk.event
             if event.event_type == ChatCompletionResponseEventType.start:
                 continue
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index 9caa3a75b..3990ab58a 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -114,8 +114,26 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
 
     async def create_agentic_system_turn(
         self,
-        request: AgenticSystemTurnCreateRequest,
+        agent_id: str,
+        session_id: str,
+        messages: List[
+            Union[
+                UserMessage,
+                ToolResponseMessage,
+            ]
+        ],
+        attachments: Optional[List[Attachment]] = None,
+        stream: Optional[bool] = False,
     ) -> AsyncGenerator:
+        # wrapper request to make it easier to pass around (internal only, not exposed to API)
+        request = AgenticSystemTurnCreateRequest(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=messages,
+            attachments=attachments,
+            stream=stream,
+        )
+
         agent_id = request.agent_id
         assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
         agent = AGENT_INSTANCES_BY_ID[agent_id]
diff --git a/llama_toolchain/inference/meta_reference/inference.py b/llama_toolchain/inference/meta_reference/inference.py
index c86c0db8b..b54e2f3f4 100644
--- a/llama_toolchain/inference/meta_reference/inference.py
+++ b/llama_toolchain/inference/meta_reference/inference.py
@@ -77,6 +77,13 @@ class MetaReferenceInferenceImpl(Inference):
             logprobs=logprobs,
         )
 
+        return self.chat_completion_impl(request)
+
+    async def chat_completion_impl(
+        self, request: ChatCompletionRequest
+    ) -> AsyncIterator[
+        Union[ChatCompletionResponseStreamChunk, ChatCompletionResponse]
+    ]:
         messages = prepare_messages(request)
         model = resolve_model(request.model)
         if model is None:
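
Usage sketch (not part of the patch): after this change, callers pass the turn
parameters flat instead of wrapping them in an AgenticSystemTurnCreateRequest;
the meta-reference implementation rebuilds the wrapper internally. A minimal
example, assuming the client mirrors the protocol's flattened signature, that
an agent and session were already created, and that the host, port, IDs, and
the UserMessage import path are placeholders rather than values confirmed by
this patch:

    import asyncio

    from llama_models.llama3.api.datatypes import UserMessage  # assumed import path

    from llama_toolchain.agentic_system.client import AgenticSystemClient


    async def run_turn(host: str, port: int, agent_id: str, session_id: str) -> None:
        api = AgenticSystemClient(f"http://{host}:{port}")
        # Arguments are now passed flat, matching the updated protocol method.
        async for chunk in api.create_agentic_system_turn(
            agent_id=agent_id,
            session_id=session_id,
            messages=[UserMessage(content="Hello!")],
            stream=True,
        ):
            print(chunk)


    asyncio.run(run_turn("localhost", 5000, "<agent-id>", "<session-id>"))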