diff --git a/llama_toolchain/agentic_system/api/datatypes.py b/llama_toolchain/agentic_system/api/datatypes.py
index 689abeceb..176a1f467 100644
--- a/llama_toolchain/agentic_system/api/datatypes.py
+++ b/llama_toolchain/agentic_system/api/datatypes.py
@@ -96,6 +96,8 @@ class Turn(BaseModel):
     ]
     steps: List[Step]
     output_message: CompletionMessage
+    output_attachments: List[Attachment] = Field(default_factory=list)
+
     started_at: datetime
     completed_at: Optional[datetime] = None
 
@@ -111,13 +113,22 @@ class Session(BaseModel):
 
 
 @json_schema_type
-class AgenticSystemInstanceConfig(BaseModel):
-    instructions: str
+class MemoryConfig(BaseModel):
+    memory_bank_id: str
+
+    # this configuration can hold other information we may want to configure,
+    # e.g., how the agent should use the memory bank API
+
+
+class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
     # zero-shot or built-in tool configurations as input to the model
     available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
        default_factory=list
    )
+    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
 
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -130,6 +141,16 @@ class AgenticSystemInstanceConfig(BaseModel):
     )
 
 
+@json_schema_type
+class AgentConfig(AgentConfigCommon):
+    model: str
+    instructions: str
+
+
+class AgentConfigOverridablePerTurn(AgentConfigCommon):
+    instructions: Optional[str] = None
+
+
 class AgenticSystemTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
diff --git a/llama_toolchain/agentic_system/api/endpoints.py b/llama_toolchain/agentic_system/api/endpoints.py
index 1f6bdcc9d..10acce07e 100644
--- a/llama_toolchain/agentic_system/api/endpoints.py
+++ b/llama_toolchain/agentic_system/api/endpoints.py
@@ -7,18 +7,17 @@
 from .datatypes import *  # noqa: F403
 from typing import Protocol
 
-# this dependency is annoying and we need a forked up version anyway
 from llama_models.schema_utils import json_schema_type, webmethod
 
 
 @json_schema_type
 class AgenticSystemCreateRequest(BaseModel):
-    model: str
-    instance_config: AgenticSystemInstanceConfig
+    agent_config: AgentConfig
 
 
 @json_schema_type
 class AgenticSystemCreateResponse(BaseModel):
+    # TODO: rename this to agent_id
     system_id: str
 
 
@@ -34,20 +33,22 @@ class AgenticSystemSessionCreateResponse(BaseModel):
 
 
 @json_schema_type
-# what's the URI?
-class AgenticSystemTurnCreateRequest(BaseModel):
+class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
     system_id: str
     session_id: str
 
+    # TODO: figure out how we can simplify this and explain why a
+    # ToolResponseMessage needs to be here (it is a function call
+    # execution from outside the system)
     messages: List[
         Union[
             UserMessage,
             ToolResponseMessage,
         ]
     ]
+    attachments: List[Attachment]
 
     stream: Optional[bool] = False
-    override_config: Optional[AgenticSystemInstanceConfig] = None
 
 
 @json_schema_type(
@@ -93,22 +94,6 @@ class AgenticSystem(Protocol):
         request: AgenticSystemSessionCreateRequest,
     ) -> AgenticSystemSessionCreateResponse: ...
 
-    @webmethod(route="/agentic_system/memory_bank/attach")
-    async def attach_memory_bank_to_agentic_system(
-        self,
-        agent_id: str,
-        session_id: str,
-        memory_bank_ids: List[str],
-    ) -> None: ...
-
-    @webmethod(route="/agentic_system/memory_bank/detach")
-    async def detach_memory_bank_from_agentic_system(
-        self,
-        agent_id: str,
-        session_id: str,
-        memory_bank_ids: List[str],
-    ) -> None: ...
-
     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
         self,
diff --git a/llama_toolchain/agentic_system/client.py b/llama_toolchain/agentic_system/client.py
index 690a8499b..3f8e63245 100644
--- a/llama_toolchain/agentic_system/client.py
+++ b/llama_toolchain/agentic_system/client.py
@@ -24,10 +24,10 @@ from termcolor import cprint
 from llama_toolchain.agentic_system.event_logger import EventLogger
 
 from .api import (
+    AgentConfig,
     AgenticSystem,
     AgenticSystemCreateRequest,
     AgenticSystemCreateResponse,
-    AgenticSystemInstanceConfig,
     AgenticSystemSessionCreateRequest,
     AgenticSystemSessionCreateResponse,
     AgenticSystemToolDefinition,
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int):
 
     create_request = AgenticSystemCreateRequest(
-        model="Meta-Llama3.1-8B-Instruct",
-        instance_config=AgenticSystemInstanceConfig(
+        agent_config=AgentConfig(
+            model="Meta-Llama3.1-8B-Instruct",
             instructions="You are a helpful assistant",
             sampling_params=SamplingParams(),
             available_tools=tool_definitions,
diff --git a/llama_toolchain/agentic_system/meta_reference/__init__.py b/llama_toolchain/agentic_system/meta_reference/__init__.py
index 22b1f788a..11dc98333 100644
--- a/llama_toolchain/agentic_system/meta_reference/__init__.py
+++ b/llama_toolchain/agentic_system/meta_reference/__init__.py
@@ -5,4 +5,4 @@
 # the root directory of this source tree.
 
 from .agentic_system import get_provider_impl  # noqa
-from .config import AgenticSystemConfig  # noqa
+from .config import MetaReferenceImplConfig  # noqa
diff --git a/llama_toolchain/agentic_system/meta_reference/agent_instance.py b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
index 2c769a5e1..8e90cdd47 100644
--- a/llama_toolchain/agentic_system/meta_reference/agent_instance.py
+++ b/llama_toolchain/agentic_system/meta_reference/agent_instance.py
@@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat
 
 from termcolor import cprint
 
-from llama_toolchain.agentic_system.api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-
-from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-from llama_toolchain.inference.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Message,
-    Role,
-    SamplingParams,
-    StopReason,
-    ToolCallDelta,
-    ToolCallParseStatus,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    URL,
-)
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.api.datatypes import (
-    BuiltinShield,
-    ShieldDefinition,
-    ShieldResponse,
-)
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 
 from llama_toolchain.tools.base import BaseTool
 from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
@@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
 
 from .safety import SafetyException, ShieldRunnerMixin
 
 
-class AgentInstance(ShieldRunnerMixin):
+class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
-        system_id: int,
-        instance_config: AgenticSystemInstanceConfig,
-        model: str,
+        agent_config: AgentConfig,
         inference_api: Inference,
+        memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
         custom_tool_definitions: List[ToolDefinition],
-        input_shields: List[ShieldDefinition],
-        output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
-        prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
-        self.system_id = system_id
-        self.instance_config = instance_config
-
-        self.model = model
+        self.agent_config = agent_config
         self.inference_api = inference_api
+        self.memory_api = memory_api
         self.safety_api = safety_api
         self.max_infer_iters = max_infer_iters
@@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
         ShieldRunnerMixin.__init__(
             self,
             safety_api,
-            input_shields=input_shields,
-            output_shields=output_shields,
+            input_shields=agent_config.input_shields,
+            output_shields=agent_config.output_shields,
         )
 
     def create_session(self, name: str) -> Session:
@@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
         # print_dialog(messages)
 
         turn_id = str(uuid.uuid4())
-        params = self.instance_config.sampling_params
+        params = self.agent_config.sampling_params
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
@@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
             )
             yield chunk
 
+    async def run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. We simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        # initialize so the assert below fails (instead of raising NameError)
+        # if _run() finishes without yielding a CompletionMessage
+        final_response = None
+        async for res in self._run(
+            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # for output shields run on the full input and output combination
+        messages = input_messages + [final_response]
+
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
     async def run_shields_wrapper(
         self,
         turn_id: str,
@@ -276,52 +279,7 @@
             )
         )
 
-    async def run(
-        self,
-        turn_id: str,
-        input_messages: List[Message],
-        temperature: float,
-        top_p: float,
-        stream: bool = False,
-        max_gen_len: Optional[int] = None,
-    ) -> AsyncGenerator:
-        # Doing async generators makes downstream code much simpler and everything amenable to
-        # stremaing. However, it also makes things complicated here because AsyncGenerators cannot
-        # return a "final value" for the `yield from` statement. we simulate that by yielding a
-        # final boolean (to see whether an exception happened) and then explicitly testing for it.
-
-        async for res in self.run_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
-        ):
-            if isinstance(res, bool):
-                return
-            elif isinstance(res, CompletionMessage):
-                final_response = res
-                break
-            else:
-                yield res
-
-        assert final_response is not None
-        # for output shields run on the full input and output combination
-        messages = input_messages + [final_response]
-
-        async for res in self.run_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        yield final_response
+    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
+
+    # TODO: retrieval is not implemented yet; stub added so the call
+    # in _run() below resolves
+    async def _retrieve_context(self, messages: List[Message]) -> List[Message]: ...
 
     async def _run(
         self,
@@ -332,6 +290,11 @@
         stream: bool = False,
         max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
+        need_context = await self._should_retrieve_context(input_messages)
+        if need_context:
+            context = await self._retrieve_context(input_messages)
+            # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+            # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)
 
         attachments = []
@@ -359,10 +322,10 @@
 
         # where are the available tools?
         req = ChatCompletionRequest(
-            model=self.model,
+            model=self.agent_config.model,
             messages=input_messages,
-            tools=self.instance_config.available_tools,
-            tool_prompt_format=self.instance_config.tool_prompt_format,
+            tools=self.agent_config.available_tools,
+            tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=SamplingParams(
                 temperature=temperature,
diff --git a/llama_toolchain/agentic_system/meta_reference/agentic_system.py b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
index 0d3f33507..c410a928c 100644
--- a/llama_toolchain/agentic_system/meta_reference/agentic_system.py
+++ b/llama_toolchain/agentic_system/meta_reference/agentic_system.py
@@ -13,16 +13,11 @@
 from typing import AsyncGenerator, Dict
 
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.inference.api.datatypes import BuiltinTool
+from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
-from llama_toolchain.agentic_system.api import (
-    AgenticSystem,
-    AgenticSystemCreateRequest,
-    AgenticSystemCreateResponse,
-    AgenticSystemSessionCreateRequest,
-    AgenticSystemSessionCreateResponse,
-    AgenticSystemTurnCreateRequest,
-)
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from .agent_instance import ChatAgent
+from .config import MetaReferenceImplConfig
 
 from llama_toolchain.tools.builtin import (
     BraveSearchTool,
@@ -34,16 +29,16 @@
 from llama_toolchain.tools.safety import with_safety
 
-from .agent_instance import AgentInstance
-from .config import AgenticSystemConfig
-
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
 
-async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
     assert isinstance(
-        config, AgenticSystemConfig
+        config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"
 
     impl = MetaReferenceAgenticSystemImpl(
@@ -60,11 +55,16 @@ AGENT_INSTANCES_BY_ID = {}
 
 class MetaReferenceAgenticSystemImpl(AgenticSystem):
     def __init__(
-        self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
+        self,
+        config: MetaReferenceImplConfig,
+        inference_api: Inference,
+        safety_api: Safety,
+        memory_api: Memory,
     ):
         self.config = config
         self.inference_api = inference_api
         self.safety_api = safety_api
+        self.memory_api = memory_api
 
     async def initialize(self) -> None:
         pass
@@ -77,7 +77,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
 
         builtin_tools = []
         custom_tool_definitions = []
-        cfg = request.instance_config
+        cfg = request.agent_config
         for dfn in cfg.available_tools:
             if isinstance(dfn.tool_name, BuiltinTool):
                 if dfn.tool_name == BuiltinTool.wolfram_alpha:
@@ -107,18 +107,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             else:
                 custom_tool_definitions.append(dfn)
 
-        AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
-            system_id=system_id,
-            instance_config=request.instance_config,
-            model=request.model,
+        AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
+            agent_config=cfg,
             inference_api=self.inference_api,
+            safety_api=self.safety_api,
+            memory_api=self.memory_api,
             builtin_tools=builtin_tools,
             custom_tool_definitions=custom_tool_definitions,
-            safety_api=self.safety_api,
-            input_shields=cfg.input_shields,
-            output_shields=cfg.output_shields,
-            prefix_messages=cfg.debug_prefix_messages,
-            tool_prompt_format=cfg.tool_prompt_format,
         )
 
         return AgenticSystemCreateResponse(
diff --git a/llama_toolchain/agentic_system/meta_reference/config.py b/llama_toolchain/agentic_system/meta_reference/config.py
index cff22d03d..367ab17a5 100644
--- a/llama_toolchain/agentic_system/meta_reference/config.py
+++ b/llama_toolchain/agentic_system/meta_reference/config.py
@@ -9,6 +9,6 @@ from typing import Optional
 
 from pydantic import BaseModel
 
 
-class AgenticSystemConfig(BaseModel):
+class MetaReferenceImplConfig(BaseModel):
     brave_search_api_key: Optional[str] = None
     wolfram_api_key: Optional[str] = None
diff --git a/llama_toolchain/agentic_system/providers.py b/llama_toolchain/agentic_system/providers.py
index 463c2976e..7d49fd004 100644
--- a/llama_toolchain/agentic_system/providers.py
+++ b/llama_toolchain/agentic_system/providers.py
@@ -21,10 +21,11 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
                 "transformers",
             ],
             module="llama_toolchain.agentic_system.meta_reference",
-            config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
+            config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
+                Api.memory,
             ],
         ),
     ]
diff --git a/llama_toolchain/agentic_system/utils.py b/llama_toolchain/agentic_system/utils.py
index 73fe9f918..f146402c7 100644
--- a/llama_toolchain/agentic_system/utils.py
+++ b/llama_toolchain/agentic_system/utils.py
@@ -15,8 +15,8 @@ from llama_models.llama3.api.datatypes import (
 )
 
 from llama_toolchain.agentic_system.api import (
+    AgentConfig,
     AgenticSystemCreateRequest,
-    AgenticSystemInstanceConfig,
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
@@ -100,7 +100,7 @@ async def get_agent_system_instance(
 
     create_request = AgenticSystemCreateRequest(
-        model=model,
-        instance_config=AgenticSystemInstanceConfig(
+        agent_config=AgentConfig(
+            model=model,
             instructions="You are a helpful assistant",
             available_tools=tool_definitions,
             input_shields=(
diff --git a/llama_toolchain/distribution/datatypes.py b/llama_toolchain/distribution/datatypes.py
index 480024223..9ae148aed 100644
--- a/llama_toolchain/distribution/datatypes.py
+++ b/llama_toolchain/distribution/datatypes.py
@@ -17,6 +17,7 @@ class Api(Enum):
     inference = "inference"
     safety = "safety"
     agentic_system = "agentic_system"
+    memory = "memory"
 
 
 @json_schema_type
diff --git a/llama_toolchain/memory/api/endpoints.py b/llama_toolchain/memory/api/endpoints.py
index d8ac0e90c..5b6989a8f 100644
--- a/llama_toolchain/memory/api/endpoints.py
+++ b/llama_toolchain/memory/api/endpoints.py
@@ -6,17 +6,31 @@
 from typing import List, Protocol
 
-from llama_models.schema_utils import webmethod
+from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
+from llama_models.schema_utils import json_schema_type, webmethod
 
 from .datatypes import *  # noqa: F403
 
 
-class MemoryBanks(Protocol):
+@json_schema_type
+class RetrieveMemoryDocumentsRequest(BaseModel):
+    query: InterleavedTextAttachment
+    bank_ids: List[str]
+
+
+@json_schema_type
+class RetrieveMemoryDocumentsResponse(BaseModel):
+    documents: List[MemoryBankDocument]
+    scores: List[float]
+
+
+class Memory(Protocol):
     @webmethod(route="/memory_banks/create")
     def create_memory_bank(
         self,
         bank_id: str,
         bank_name: str,
+        embedding_model: str,
         documents: List[MemoryBankDocument],
     ) -> None: ...
 
@@ -46,6 +60,12 @@
         documents: List[MemoryBankDocument],
     ) -> None: ...
 
+    @webmethod(route="/memory_bank/retrieve")
+    def retrieve_memory_documents(
+        self,
+        request: RetrieveMemoryDocumentsRequest,
+    ) -> RetrieveMemoryDocumentsResponse: ...
+
     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
diff --git a/llama_toolchain/stack.py b/llama_toolchain/stack.py
index 88a54976c..afea66a0c 100644
--- a/llama_toolchain/stack.py
+++ b/llama_toolchain/stack.py
@@ -24,7 +24,7 @@ class LlamaStack(
     Datasets,
     Observability,
     PostTraining,
-    MemoryBanks,
+    Memory,
     Evaluations,
 ):
     pass
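Illustrative usage (not part of the patch): after this change, `model` lives on AgentConfig rather than on AgenticSystemCreateRequest, so a client builds the create request roughly as below. This is a minimal sketch against the classes shown in the diff above; the empty tool, shield, and memory lists are placeholder assumptions.

    from llama_toolchain.agentic_system.api import (
        AgentConfig,
        AgenticSystemCreateRequest,
    )

    # AgentConfig now carries the model and instructions; shared knobs
    # (tools, shields, memory banks) come from AgentConfigCommon.
    create_request = AgenticSystemCreateRequest(
        agent_config=AgentConfig(
            model="Meta-Llama3.1-8B-Instruct",
            instructions="You are a helpful assistant",
            available_tools=[],   # placeholder AgenticSystemToolDefinition entries
            input_shields=[],     # placeholder ShieldDefinition entries
            output_shields=[],
            memory=[],            # placeholder MemoryConfig(memory_bank_id=...) entries
        ),
    )

Similarly, a retrieval round-trip against the new Memory protocol would be shaped like the sketch below, assuming the request model is re-exported from llama_toolchain.memory.api (the agent code star-imports that package) and that a plain string is an acceptable InterleavedTextAttachment; the query text and bank id are made up.

    from llama_toolchain.memory.api import RetrieveMemoryDocumentsRequest

    request = RetrieveMemoryDocumentsRequest(
        query="What deadlines did the user mention?",
        bank_ids=["bank-0"],
    )
    # An implementation of the Memory protocol serves this via
    # retrieve_memory_documents(request) -> RetrieveMemoryDocumentsResponse,
    # i.e., the matching documents plus per-document scores.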