From 8efe61471972650c94a909371cd2a8827bec735a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 24 Aug 2024 22:07:06 -0700 Subject: [PATCH] re-work tool definitions, fix FastAPI issues, fix tool regressions --- .../agentic_system/api/datatypes.py | 1 + llama_toolchain/agentic_system/client.py | 69 ++++++++----------- .../meta_reference/agent_instance.py | 44 ++++++++---- .../meta_reference/agentic_system.py | 69 ++++++++++--------- llama_toolchain/distribution/registry.py | 29 +++++++- llama_toolchain/distribution/server.py | 4 +- llama_toolchain/inference/client.py | 9 ++- .../memory/meta_reference/faiss/faiss.py | 1 - llama_toolchain/safety/client.py | 9 ++- llama_toolchain/tools/builtin.py | 7 +- llama_toolchain/tools/custom/datatypes.py | 6 -- 11 files changed, 144 insertions(+), 104 deletions(-) diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py index cb99d80fc..c22d71635 100644 --- a/llama_toolchain/agentic_system/api/datatypes.py +++ b/llama_toolchain/agentic_system/api/datatypes.py @@ -77,6 +77,7 @@ class FunctionCallToolDefinition(ToolDefinitionCommon): type: Literal[AgenticSystemTool.function_call.value] = ( AgenticSystemTool.function_call.value ) + function_name: str description: str parameters: Dict[str, ToolParamDefinition] remote_execution: Optional[RestAPIExecutionConfig] = None diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py index e3f7add44..e2adac495 100644 --- a/llama_toolchain/agentic_system/client.py +++ b/llama_toolchain/agentic_system/client.py @@ -6,49 +6,42 @@ import asyncio import json - from typing import AsyncGenerator import fire import httpx -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - SamplingParams, - ToolParamDefinition, - ToolPromptFormat, - UserMessage, -) +from pydantic import BaseModel from termcolor import cprint -from llama_toolchain.agentic_system.event_logger import EventLogger -from .api import ( - AgentConfig, - AgenticSystem, - AgenticSystemCreateResponse, - AgenticSystemSessionCreateResponse, - AgenticSystemToolDefinition, - AgenticSystemTurnCreateRequest, - AgenticSystemTurnResponseStreamChunk, -) +from llama_models.llama3.api.datatypes import * # noqa: F403 +from .api import * # noqa: F403 + +from .event_logger import EventLogger async def get_client_impl(base_url: str): return AgenticSystemClient(base_url) +def encodable_dict(d: BaseModel): + return json.loads(d.json()) + + class AgenticSystemClient(AgenticSystem): def __init__(self, base_url: str): self.base_url = base_url async def create_agentic_system( - self, request: AgenticSystemCreateRequest + self, agent_config: AgentConfig ) -> AgenticSystemCreateResponse: async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/agentic_system/create", - data=request.json(), + json={ + "agent_config": encodable_dict(agent_config), + }, headers={"Content-Type": "application/json"}, ) response.raise_for_status() @@ -56,12 +49,16 @@ class AgenticSystemClient(AgenticSystem): async def create_agentic_system_session( self, - request: AgenticSystemSessionCreateRequest, + agent_id: str, + session_name: str, ) -> AgenticSystemSessionCreateResponse: async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/agentic_system/session/create", - data=request.json(), + json={ + "agent_id": agent_id, + "session_name": session_name, + }, headers={"Content-Type": "application/json"}, ) response.raise_for_status() @@ -75,7 +72,9 @@ class AgenticSystemClient(AgenticSystem): async with client.stream( "POST", f"{self.base_url}/agentic_system/turn/create", - data=request.json(), + json={ + "request": encodable_dict(request), + }, headers={"Content-Type": "application/json"}, timeout=20, ) as response: @@ -96,19 +95,13 @@ async def run_main(host: str, port: int): api = AgenticSystemClient(f"http://{host}:{port}") tool_definitions = [ - AgenticSystemToolDefinition( - tool_name=BuiltinTool.brave_search, - ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.wolfram_alpha, - ), - AgenticSystemToolDefinition( - tool_name=BuiltinTool.code_interpreter, - ), + BraveSearchToolDefinition(), + WolframAlphaToolDefinition(), + CodeInterpreterToolDefinition(), ] tool_definitions += [ - AgenticSystemToolDefinition( - tool_name="get_boiling_point", + FunctionCallToolDefinition( + function_name="get_boiling_point", description="Get the boiling point of a imaginary liquids (eg. polyjuice)", parameters={ "liquid_name": ToolParamDefinition( @@ -128,12 +121,10 @@ async def run_main(host: str, port: int): agent_config = AgentConfig( model="Meta-Llama3.1-8B-Instruct", instructions="You are a helpful assistant", - sampling_params=SamplingParams(), + sampling_params=SamplingParams(temperature=1.0, top_p=0.9), tools=tool_definitions, - input_shields=[], - output_shields=[], - debug_prefix_messages=[], - tool_prompt_format=ToolPromptFormat.json, + tool_choice=ToolChoice.auto, + tool_prompt_format=ToolPromptFormat.function_tag, ) create_response = await api.create_agentic_system(agent_config) diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py index 37d05e8a2..427f96be8 100644 --- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py +++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py @@ -10,8 +10,6 @@ import uuid from datetime import datetime from typing import AsyncGenerator, List -from llama_models.llama3.api.datatypes import ToolPromptFormat - from termcolor import cprint from llama_toolchain.agentic_system.api import * # noqa: F403 @@ -20,7 +18,10 @@ from llama_toolchain.memory.api import * # noqa: F403 from llama_toolchain.safety.api import * # noqa: F403 from llama_toolchain.tools.base import BaseTool -from llama_toolchain.tools.builtin import SingleMessageBuiltinTool +from llama_toolchain.tools.builtin import ( + interpret_content_as_attachment, + SingleMessageBuiltinTool, +) from .safety import SafetyException, ShieldRunnerMixin @@ -192,7 +193,7 @@ class ChatAgent(ShieldRunnerMixin): yield res async for res in self._run( - turn_id, session, input_messages, attachments, sampling_params, stream + session, turn_id, input_messages, attachments, sampling_params, stream ): if isinstance(res, bool): return @@ -358,7 +359,7 @@ class ChatAgent(ShieldRunnerMixin): req = ChatCompletionRequest( model=self.agent_config.model, messages=input_messages, - tools=self.agent_config.tools, + tools=self._get_tools(), tool_prompt_format=self.agent_config.tool_prompt_format, stream=True, sampling_params=sampling_params, @@ -555,17 +556,13 @@ class ChatAgent(ShieldRunnerMixin): yield False return - if isinstance(result_message.content, Attachment): + if out_attachment := interpret_content_as_attachment( + result_message.content + ): # NOTE: when we push this message back to the model, the model may ignore the # attached file path etc. since the model is trained to only provide a user message # with the summary. We keep all generated attachments and then attach them to final message - output_attachments.append(result_message.content) - elif isinstance(result_message.content, list) or isinstance( - result_message.content, tuple - ): - for c in result_message.content: - if isinstance(c, Attachment): - output_attachments.append(c) + output_attachments.append(out_attachment) input_messages = input_messages + [message, result_message] @@ -667,6 +664,27 @@ class ChatAgent(ShieldRunnerMixin): "\n=== END-RETRIEVED-CONTEXT ===\n", ] + def _get_tools(self) -> List[ToolDefinition]: + ret = [] + for t in self.agent_config.tools: + if isinstance(t, BraveSearchToolDefinition): + ret.append(ToolDefinition(tool_name=BuiltinTool.brave_search)) + elif isinstance(t, WolframAlphaToolDefinition): + ret.append(ToolDefinition(tool_name=BuiltinTool.wolfram_alpha)) + elif isinstance(t, PhotogenToolDefinition): + ret.append(ToolDefinition(tool_name=BuiltinTool.photogen)) + elif isinstance(t, CodeInterpreterToolDefinition): + ret.append(ToolDefinition(tool_name=BuiltinTool.code_interpreter)) + elif isinstance(t, FunctionCallToolDefinition): + ret.append( + ToolDefinition( + tool_name=t.function_name, + description=t.description, + parameters=t.parameters, + ) + ) + return ret + def attachment_message(urls: List[URL]) -> ToolResponseMessage: content = [] diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py index d89058abe..52ebd1ec7 100644 --- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py +++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py @@ -12,7 +12,6 @@ from typing import AsyncGenerator, Dict from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.inference.api import Inference -from llama_toolchain.inference.api.datatypes import BuiltinTool from llama_toolchain.memory.api import Memory from llama_toolchain.safety.api import Safety from llama_toolchain.agentic_system.api import * # noqa: F403 @@ -42,6 +41,7 @@ async def get_provider_impl( impl = MetaReferenceAgenticSystemImpl( config, deps[Api.inference], + deps[Api.memory], deps[Api.safety], ) await impl.initialize() @@ -56,54 +56,55 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): self, config: MetaReferenceImplConfig, inference_api: Inference, - safety_api: Safety, memory_api: Memory, + safety_api: Safety, ): self.config = config self.inference_api = inference_api - self.safety_api = safety_api self.memory_api = memory_api + self.safety_api = safety_api async def initialize(self) -> None: pass async def create_agentic_system( self, - request: AgenticSystemCreateRequest, + agent_config: AgentConfig, ) -> AgenticSystemCreateResponse: agent_id = str(uuid.uuid4()) builtin_tools = [] - cfg = request.agent_config - for dfn in cfg.tools: - if isinstance(dfn.tool_name, BuiltinTool): - if dfn.tool_name == BuiltinTool.wolfram_alpha: - key = self.config.wolfram_api_key - if not key: - raise ValueError("Wolfram API key not defined in config") - tool = WolframAlphaTool(key) - elif dfn.tool_name == BuiltinTool.brave_search: - key = self.config.brave_search_api_key - if not key: - raise ValueError("Brave API key not defined in config") - tool = BraveSearchTool(key) - elif dfn.tool_name == BuiltinTool.code_interpreter: - tool = CodeInterpreterTool() - elif dfn.tool_name == BuiltinTool.photogen: - tool = PhotogenTool( - dump_dir="/tmp/photogen_dump_" + os.environ["USER"], - ) - else: - raise ValueError(f"Unknown builtin tool: {dfn.tool_name}") - - builtin_tools.append( - with_safety( - tool, self.safety_api, dfn.input_shields, dfn.output_shields - ) + for tool_defn in agent_config.tools: + if isinstance(tool_defn, WolframAlphaToolDefinition): + key = self.config.wolfram_api_key + if not key: + raise ValueError("Wolfram API key not defined in config") + tool = WolframAlphaTool(key) + elif isinstance(tool_defn, BraveSearchToolDefinition): + key = self.config.brave_search_api_key + if not key: + raise ValueError("Brave API key not defined in config") + tool = BraveSearchTool(key) + elif isinstance(tool_defn, CodeInterpreterToolDefinition): + tool = CodeInterpreterTool() + elif isinstance(tool_defn, PhotogenToolDefinition): + tool = PhotogenTool( + dump_dir="/tmp/photogen_dump_" + os.environ["USER"], ) + else: + continue + + builtin_tools.append( + with_safety( + tool, + self.safety_api, + tool_defn.input_shields, + tool_defn.output_shields, + ) + ) AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent( - agent_config=cfg, + agent_config=agent_config, inference_api=self.inference_api, safety_api=self.safety_api, memory_api=self.memory_api, @@ -116,13 +117,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem): async def create_agentic_system_session( self, - request: AgenticSystemSessionCreateRequest, + agent_id: str, + session_name: str, ) -> AgenticSystemSessionCreateResponse: - agent_id = request.agent_id assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found" agent = AGENT_INSTANCES_BY_ID[agent_id] - session = agent.create_session(request.session_name) + session = agent.create_session(session_name) return AgenticSystemSessionCreateResponse( session_id=session.session_id, ) diff --git a/llama_toolchain/distribution/registry.py b/llama_toolchain/distribution/registry.py index 33d6e8e2a..296ee3103 100644 --- a/llama_toolchain/distribution/registry.py +++ b/llama_toolchain/distribution/registry.py @@ -52,12 +52,37 @@ def available_distribution_specs() -> List[DistributionSpec]: }, ), DistributionSpec( - spec_id="test-memory", - description="Just a test distribution spec for testing memory bank APIs", + spec_id="test-agentic", + description="Test agentic with others as remote", provider_specs={ + Api.agentic_system: providers[Api.agentic_system]["meta-reference"], + Api.inference: remote_spec(Api.inference), + Api.memory: remote_spec(Api.memory), + Api.safety: remote_spec(Api.safety), + }, + ), + DistributionSpec( + spec_id="test-inference", + description="Test inference provider", + provider_specs={ + Api.inference: providers[Api.inference]["meta-reference"], + }, + ), + DistributionSpec( + spec_id="test-memory", + description="Test memory provider", + provider_specs={ + Api.inference: providers[Api.inference]["meta-reference"], Api.memory: providers[Api.memory]["meta-reference-faiss"], }, ), + DistributionSpec( + spec_id="test-safety", + description="Test safety provider", + provider_specs={ + Api.safety: providers[Api.safety]["meta-reference"], + }, + ), ] diff --git a/llama_toolchain/distribution/server.py b/llama_toolchain/distribution/server.py index 90c5a9a0f..dd92fd43e 100644 --- a/llama_toolchain/distribution/server.py +++ b/llama_toolchain/distribution/server.py @@ -214,7 +214,9 @@ def create_dynamic_typed_route(func: Any, method: str): # and some in the body endpoint.__signature__ = sig.replace( parameters=[ - param.replace(annotation=Annotated[param.annotation, Body()]) + param.replace( + annotation=Annotated[param.annotation, Body(..., embed=True)] + ) for param in sig.parameters.values() ] ) diff --git a/llama_toolchain/inference/client.py b/llama_toolchain/inference/client.py index aa84f906d..ec7ed859b 100644 --- a/llama_toolchain/inference/client.py +++ b/llama_toolchain/inference/client.py @@ -10,6 +10,7 @@ from typing import AsyncGenerator import fire import httpx +from pydantic import BaseModel from termcolor import cprint from .api import ( @@ -27,6 +28,10 @@ async def get_client_impl(base_url: str): return InferenceClient(base_url) +def encodable_dict(d: BaseModel): + return json.loads(d.json()) + + class InferenceClient(Inference): def __init__(self, base_url: str): print(f"Initializing client for {base_url}") @@ -46,7 +51,9 @@ class InferenceClient(Inference): async with client.stream( "POST", f"{self.base_url}/inference/chat_completion", - data=request.json(), + json={ + "request": encodable_dict(request), + }, headers={"Content-Type": "application/json"}, timeout=20, ) as response: diff --git a/llama_toolchain/memory/meta_reference/faiss/faiss.py b/llama_toolchain/memory/meta_reference/faiss/faiss.py index 0558a6eda..85a92f35f 100644 --- a/llama_toolchain/memory/meta_reference/faiss/faiss.py +++ b/llama_toolchain/memory/meta_reference/faiss/faiss.py @@ -160,7 +160,6 @@ class FaissMemoryImpl(Memory): config: MemoryBankConfig, url: Optional[URL] = None, ) -> MemoryBank: - print("Creating memory bank") assert url is None, "URL is not supported for this implementation" assert ( config.type == MemoryBankType.vector.value diff --git a/llama_toolchain/safety/client.py b/llama_toolchain/safety/client.py index 5d86f9291..0fbc4c7c0 100644 --- a/llama_toolchain/safety/client.py +++ b/llama_toolchain/safety/client.py @@ -10,6 +10,7 @@ import fire import httpx from llama_models.llama3.api.datatypes import UserMessage +from pydantic import BaseModel from termcolor import cprint from .api import ( @@ -25,6 +26,10 @@ async def get_client_impl(base_url: str): return SafetyClient(base_url) +def encodable_dict(d: BaseModel): + return json.loads(d.json()) + + class SafetyClient(Safety): def __init__(self, base_url: str): print(f"Initializing client for {base_url}") @@ -40,7 +45,9 @@ class SafetyClient(Safety): async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/safety/run_shields", - data=request.json(), + json={ + "request": encodable_dict(request), + }, headers={"Content-Type": "application/json"}, timeout=20, ) diff --git a/llama_toolchain/tools/builtin.py b/llama_toolchain/tools/builtin.py index c13af125f..e5e71187f 100644 --- a/llama_toolchain/tools/builtin.py +++ b/llama_toolchain/tools/builtin.py @@ -22,6 +22,7 @@ from .ipython_tool.code_execution import ( ) from llama_toolchain.inference.api import * # noqa: F403 +from llama_toolchain.agentic_system.api import * # noqa: F403 from .base import BaseTool @@ -55,9 +56,6 @@ class SingleMessageBuiltinTool(BaseTool): tool_name=tool_call.tool_name, content=response, ) - if attachment := interpret_content_as_attachment(response): - message.content = attachment - return [message] @abstractmethod @@ -316,7 +314,4 @@ class CodeInterpreterTool(BaseTool): tool_name=tool_call.tool_name, content="\n".join(pieces), ) - if attachment := interpret_content_as_attachment(res["stdout"]): - message.content = attachment - return [message] diff --git a/llama_toolchain/tools/custom/datatypes.py b/llama_toolchain/tools/custom/datatypes.py index d2a97376d..a7fe34e9b 100644 --- a/llama_toolchain/tools/custom/datatypes.py +++ b/llama_toolchain/tools/custom/datatypes.py @@ -12,9 +12,6 @@ from typing import Dict, List from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_toolchain.agentic_system.api import * # noqa: F403 -# TODO: this is symptomatic of us needing to pull more tooling related utilities -from llama_toolchain.tools.builtin import interpret_content_as_attachment - class CustomTool: """ @@ -94,9 +91,6 @@ class SingleMessageCustomTool(CustomTool): tool_name=tool_call.tool_name, content=response_str, ) - if attachment := interpret_content_as_attachment(response_str): - message.content = attachment - return [message] @abstractmethod