From 31289e3f4786a28d89fdcd1641e601364aba2703 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 22 Aug 2024 17:44:56 -0700
Subject: [PATCH] InterleavedTextAttachment -> InterleavedTextMedia, introduce
 memory tool

---
 .../agentic_system/api/datatypes.py          | 44 +++++++++++++++----
 .../agentic_system/api/endpoints.py          |  2 +-
 .../meta_reference/agent_instance.py         | 45 +++++++++----------
 .../meta_reference/agentic_system.py         | 12 +++---
 llama_toolchain/memory/api/endpoints.py      |  4 +-
 5 files changed, 65 insertions(+), 42 deletions(-)

diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
index 176a1f467..e529e7e6a 100644
--- a/llama_toolchain/agentic_system/api/datatypes.py
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -10,7 +10,8 @@
+from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union
 
 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, validator
 from typing_extensions import Annotated
 
 from llama_toolchain.common.deployment_types import *  # noqa: F403
@@ -20,7 +21,37 @@ from llama_toolchain.memory.api.datatypes import *  # noqa: F403
 
 
 @json_schema_type
-class AgenticSystemToolDefinition(ToolDefinition):
+class Attachment(BaseModel):
+    url: URL
+    mime_type: str
+
+
+# BuiltinTool is an Enum with members, and Python enums with members cannot
+# be subclassed, so the new memory tool gets its own enum alongside it.
+class AgenticSystemBuiltinTool(Enum):
+    memory = "memory"
+
+
+@json_schema_type
+class AgenticSystemToolDefinition(BaseModel):
+    tool_name: Union[BuiltinTool, AgenticSystemBuiltinTool, str]
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, ToolParamDefinition]] = None
+
+    @validator("tool_name", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return BuiltinTool(v)
+            except ValueError:
+                pass
+            try:
+                return AgenticSystemBuiltinTool(v)
+            except ValueError:
+                return v
+        return v
+
     execution_config: Optional[RestAPIExecutionConfig] = None
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -124,11 +155,7 @@ class MemoryConfig(BaseModel):
 
 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
-    # zero-shot or built-in tool configurations as input to the model
-    available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
-        default_factory=list
-    )
-    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
+    memory_configs: Optional[List[MemoryConfig]] = Field(default_factory=list)
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -136,6 +163,9 @@ class AgentConfigCommon(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+
+    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py
index 10acce07e..6ef35c30e 100644
--- a/llama_toolchain/agentic_system/api/endpoints.py
+++ b/llama_toolchain/agentic_system/api/endpoints.py
@@ -46,7 +46,7 @@ class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: List[Attachment]
+    attachments: Optional[List[Attachment]] = None
 
     stream: Optional[bool] = False
 
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 8e90cdd47..770bd8d1a 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -33,7 +33,6 @@ class ChatAgent(ShieldRunnerMixin):
         memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
-        custom_tool_definitions: List[ToolDefinition],
         max_infer_iters: int = 10,
     ):
         self.agent_config = agent_config
@@ -108,7 +107,6 @@ class ChatAgent(ShieldRunnerMixin):
         # print_dialog(messages)
 
         turn_id = str(uuid.uuid4())
-        params = self.agent_config.sampling_params
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
@@ -123,10 +121,9 @@ class ChatAgent(ShieldRunnerMixin):
         async for chunk in self.run(
             turn_id=turn_id,
             input_messages=messages,
-            temperature=params.temperature,
-            top_p=params.top_p,
+            attachments=request.attachments or [],
+            sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
-            max_gen_len=params.max_tokens,
         ):
             if isinstance(chunk, CompletionMessage):
                 cprint(
@@ -175,10 +172,9 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         turn_id: str,
         input_messages: List[Message],
-        temperature: float,
-        top_p: float,
+        attachments: List[Attachment],
+        sampling_params: SamplingParams,
         stream: bool = False,
-        max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to
         # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -194,7 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
             yield res
 
         async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+            turn_id, input_messages, attachments, sampling_params, stream
         ):
             if isinstance(res, bool):
                 return
@@ -279,20 +275,27 @@ class ChatAgent(ShieldRunnerMixin):
             )
         )
 
-    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
+    async def _should_retrieve_context(
+        self, messages: List[Message], attachments: List[Attachment]
+    ) -> bool:
+        return bool(self.agent_config.memory_configs) or len(attachments) > 0
+
+    async def _retrieve_context(
+        self, messages: List[Message], attachments: List[Attachment]
+    ) -> List[Message]:
+        return []
 
     async def _run(
         self,
         turn_id: str,
         input_messages: List[Message],
-        temperature: float,
-        top_p: float,
+        attachments: List[Attachment],
+        sampling_params: SamplingParams,
         stream: bool = False,
-        max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
-        need_context = await self._should_retrieve_context(input_messages)
+        need_context = await self._should_retrieve_context(input_messages, attachments)
         if need_context:
-            context = await self._retrieve_context(input_messages)
+            context_messages = await self._retrieve_context(input_messages, attachments)
         # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
         # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)
@@ -320,18 +323,13 @@ class ChatAgent(ShieldRunnerMixin):
             )
         )
 
-        # where are the available tools?
         req = ChatCompletionRequest(
             model=self.agent_config.model,
             messages=input_messages,
-            tools=self.agent_config.available_tools,
+            tools=self.agent_config.tools,
             tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
-            sampling_params=SamplingParams(
-                temperature=temperature,
-                top_p=top_p,
-                max_tokens=max_gen_len,
-            ),
+            sampling_params=sampling_params,
         )
 
         tool_calls = []
@@ -554,6 +552,7 @@ def attachment_message(url: URL) -> ToolResponseMessage:
 
 
 def preprocess_dialog(messages: List[Message]) -> List[Message]:
+    # remove system messages since those are added back separately
     """
     Preprocesses the dialog by removing the system message and
     adding the system message to the beginning of the dialog.
@@ -565,7 +564,7 @@
             continue
 
         # NOTE: the ideal behavior is to use `file_path = ...` but that
-        # means we need to have stateful execution o f code which we currently
+        # means we need to have stateful execution of code which we currently
         # do not have.
         if isinstance(m.content, Attachment):
             ret.append(attachment_message(m.content.url))
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index c410a928c..9078b5222 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -16,9 +16,6 @@ from llama_toolchain.inference.api.datatypes import BuiltinTool
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
 from llama_toolchain.agentic_system.api import *  # noqa: F403
-from .agent_instance import ChatAgent
-from .config import MetaReferenceImplConfig
-
 from llama_toolchain.tools.builtin import (
     BraveSearchTool,
     CodeInterpreterTool,
@@ -27,7 +24,8 @@
 )
 from llama_toolchain.tools.safety import with_safety
 
-from .agent_instance import AgentInstance
+from .agent_instance import AgentInstance, ChatAgent
+from .config import MetaReferenceImplConfig
 
 logger = logging.getLogger()
 
@@ -76,7 +74,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
         system_id = str(uuid.uuid4())
 
         builtin_tools = []
-        custom_tool_definitions = []
         cfg = request.agent_config
-        for dfn in cfg.available_tools:
+        for dfn in cfg.tools:
             if isinstance(dfn.tool_name, BuiltinTool):
@@ -104,8 +101,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
                     tool, self.safety_api, dfn.input_shields, dfn.output_shields
                 )
             )
-        else:
-            custom_tool_definitions.append(dfn)
 
         AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
             agent_config=cfg,
@@ -113,7 +108,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             safety_api=self.safety_api,
             memory_api=self.memory_api,
             builtin_tools=builtin_tools,
-            custom_tool_definitions=custom_tool_definitions,
         )
 
         return AgenticSystemCreateResponse(
diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py
index 5b6989a8f..810d821f7 100644
--- a/llama_toolchain/memory/api/endpoints.py
+++ b/llama_toolchain/memory/api/endpoints.py
@@ -6,7 +6,7 @@
 
 from typing import List, Protocol
 
-from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
+from llama_models.llama3.api.datatypes import InterleavedTextMedia
 from llama_models.schema_utils import webmethod
 
 from .datatypes import *  # noqa: F403
@@ -14,7 +14,7 @@
 
 @json_schema_type
 class RetrieveMemoryDocumentsRequest(BaseModel):
-    query: InterleavedTextAttachment
+    query: InterleavedTextMedia
     bank_ids: str
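
A minimal usage sketch of the datatypes this patch introduces, assuming the
patched llama_toolchain is importable. URL is taken to be the pydantic model
that datatypes.py picks up via its star imports (its exact home module is an
assumption here), and lookup_ticket is a hypothetical custom tool name:

    from llama_toolchain.inference.api.datatypes import BuiltinTool
    from llama_toolchain.agentic_system.api.datatypes import (
        URL,  # assumed to be re-exported via the star imports
        AgenticSystemBuiltinTool,
        AgenticSystemToolDefinition,
        Attachment,
    )

    # The pre-validator coerces plain strings into the matching enum member;
    # unrecognized names stay as strings and are treated as custom tools.
    memory_tool = AgenticSystemToolDefinition(tool_name="memory")
    assert memory_tool.tool_name is AgenticSystemBuiltinTool.memory

    search_tool = AgenticSystemToolDefinition(tool_name="brave_search")
    assert search_tool.tool_name is BuiltinTool.brave_search

    custom_tool = AgenticSystemToolDefinition(
        tool_name="lookup_ticket",  # hypothetical custom tool
        description="Look up a support ticket by id",
    )
    assert custom_tool.tool_name == "lookup_ticket"

    # Turn attachments are now optional and default to None; the agent's
    # run loop receives `request.attachments or []`.
    note = Attachment(url=URL(uri="file://notes.txt"), mime_type="text/plain")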