<WIP> memory changes

- inlined AgenticSystemInstanceConfig so the API feels more ergonomic
- renamed it to AgentConfig, and AgentInstance -> ChatAgent
- added a MemoryConfig and a `memory` parameter
- added `attachments` to the turn input and `output_attachments` to the response
- some other naming changes
Author: Ashwin Bharambe
Date:   2024-08-14 13:46:44 -07:00
Commit: 48c6a32edd (parent 5655266d58)

12 changed files with 149 additions and 163 deletions
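For orientation, the net effect on callers looks roughly like the sketch below (a minimal illustration pieced together from the diffs that follow; the model name and the empty tool/memory lists are placeholders):

    from llama_toolchain.agentic_system.api import (
        AgentConfig,
        AgenticSystemCreateRequest,
    )

    # Previously: AgenticSystemCreateRequest(model=..., instance_config=AgenticSystemInstanceConfig(...))
    # Now model and instructions live directly on the inlined AgentConfig.
    create_request = AgenticSystemCreateRequest(
        agent_config=AgentConfig(
            model="Meta-Llama3.1-8B-Instruct",
            instructions="You are a helpful assistant",
            available_tools=[],  # AgenticSystemToolDefinition entries
            memory=[],           # new: optional MemoryConfig entries
        ),
    )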

View file

@@ -96,6 +96,8 @@ class Turn(BaseModel):
     ]
     steps: List[Step]
     output_message: CompletionMessage
+    output_attachments: List[Attachment] = Field(default_factory=list)
+
     started_at: datetime
     completed_at: Optional[datetime] = None
@@ -111,13 +113,22 @@ class Session(BaseModel):
 @json_schema_type
-class AgenticSystemInstanceConfig(BaseModel):
-    instructions: str
+class MemoryConfig(BaseModel):
+    memory_bank_id: str
+
+    # this configuration can hold other information we may want to configure
+    # how will the agent use the memory bank API?
+    #
+    #
+
+
+class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
 
     # zero-shot or built-in tool configurations as input to the model
     available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
         default_factory=list
     )
+    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
 
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -130,6 +141,16 @@ class AgenticSystemInstanceConfig(BaseModel):
     )
 
+@json_schema_type
+class AgentConfig(AgentConfigCommon):
+    model: str
+    instructions: str
+
+
+class AgentConfigOverridablePerTurn(AgentConfigCommon):
+    instructions: Optional[str] = None
+
+
 class AgenticSystemTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
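The split above separates shared fields (AgentConfigCommon) from creation-time requirements (AgentConfig) and per-turn overrides. A hedged sketch of the intent, with illustrative values:

    # Creation-time config: model and instructions are required.
    config = AgentConfig(
        model="Meta-Llama3.1-8B-Instruct",
        instructions="You are a helpful assistant",
        memory=[MemoryConfig(memory_bank_id="bank-123")],  # hypothetical bank id
    )

    # Per-turn overrides inherit every AgentConfigCommon field but make
    # instructions optional, so a single turn can tweak, e.g., sampling only.
    per_turn = AgentConfigOverridablePerTurn(
        instructions=None,  # None -> keep the agent's creation-time instructions
        sampling_params=SamplingParams(temperature=0.0),
    )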

View file

@@ -7,18 +7,17 @@
 from .datatypes import *  # noqa: F403
 
 from typing import Protocol
 
-# this dependency is annoying and we need a forked up version anyway
 from llama_models.schema_utils import json_schema_type, webmethod
 
 
 @json_schema_type
 class AgenticSystemCreateRequest(BaseModel):
-    model: str
-    instance_config: AgenticSystemInstanceConfig
+    agent_config: AgentConfig
 
 
 @json_schema_type
 class AgenticSystemCreateResponse(BaseModel):
+    # TODO: rename this to agent_id
     system_id: str
@@ -34,20 +33,22 @@ class AgenticSystemSessionCreateResponse(BaseModel):
 @json_schema_type
-# what's the URI?
-class AgenticSystemTurnCreateRequest(BaseModel):
+class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
     system_id: str
     session_id: str
 
+    # TODO: figure out how we can simplify this and make why
+    # ToolResponseMessage needs to be here (it is function call
+    # execution from outside the system)
     messages: List[
         Union[
             UserMessage,
             ToolResponseMessage,
         ]
     ]
+    attachments: List[Attachment]
 
     stream: Optional[bool] = False
-    override_config: Optional[AgenticSystemInstanceConfig] = None
 
 
 @json_schema_type(
@@ -93,22 +94,6 @@ class AgenticSystem(Protocol):
         request: AgenticSystemSessionCreateRequest,
     ) -> AgenticSystemSessionCreateResponse: ...
 
-    @webmethod(route="/agentic_system/memory_bank/attach")
-    async def attach_memory_bank_to_agentic_system(
-        self,
-        agent_id: str,
-        session_id: str,
-        memory_bank_ids: List[str],
-    ) -> None: ...
-
-    @webmethod(route="/agentic_system/memory_bank/detach")
-    async def detach_memory_bank_from_agentic_system(
-        self,
-        agent_id: str,
-        session_id: str,
-        memory_bank_ids: List[str],
-    ) -> None: ...
-
     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
         self,
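Because AgenticSystemTurnCreateRequest now mixes in AgentConfigOverridablePerTurn, per-turn overrides become flat fields on the request instead of an override_config blob, and attachments ride alongside the messages; the standalone attach/detach endpoints above go away in favor of MemoryConfig on the agent. A rough usage sketch (ids are placeholders):

    request = AgenticSystemTurnCreateRequest(
        system_id="sys-123",
        session_id="sess-456",
        messages=[UserMessage(content="Summarize the attached notes")],
        attachments=[],  # Attachment objects from the inference datatypes
        instructions="Answer in one short paragraph",  # optional per-turn override
        stream=True,
    )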

View file

@@ -24,10 +24,10 @@ from termcolor import cprint
 from llama_toolchain.agentic_system.event_logger import EventLogger
 
 from .api import (
+    AgentConfig,
     AgenticSystem,
     AgenticSystemCreateRequest,
     AgenticSystemCreateResponse,
-    AgenticSystemInstanceConfig,
     AgenticSystemSessionCreateRequest,
     AgenticSystemSessionCreateResponse,
     AgenticSystemToolDefinition,
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int):
     create_request = AgenticSystemCreateRequest(
         model="Meta-Llama3.1-8B-Instruct",
-        instance_config=AgenticSystemInstanceConfig(
+        agent_config=AgentConfig(
             instructions="You are a helpful assistant",
             sampling_params=SamplingParams(),
             available_tools=tool_definitions,

View file

@@ -5,4 +5,4 @@
 # the root directory of this source tree.
 
 from .agentic_system import get_provider_impl  # noqa
-from .config import AgenticSystemConfig  # noqa
+from .config import MetaReferenceImplConfig  # noqa

View file

@@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat
 from termcolor import cprint
 
-from llama_toolchain.agentic_system.api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-from llama_toolchain.inference.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Message,
-    Role,
-    SamplingParams,
-    StopReason,
-    ToolCallDelta,
-    ToolCallParseStatus,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    URL,
-)
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.api.datatypes import (
-    BuiltinShield,
-    ShieldDefinition,
-    ShieldResponse,
-)
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 from llama_toolchain.tools.base import BaseTool
 from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
@@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
 from .safety import SafetyException, ShieldRunnerMixin
 
-class AgentInstance(ShieldRunnerMixin):
+class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
-        system_id: int,
-        instance_config: AgenticSystemInstanceConfig,
-        model: str,
+        agent_config: AgentConfig,
         inference_api: Inference,
+        memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
         custom_tool_definitions: List[ToolDefinition],
-        input_shields: List[ShieldDefinition],
-        output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
-        prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
-        self.system_id = system_id
-        self.instance_config = instance_config
-        self.model = model
+        self.agent_config = agent_config
         self.inference_api = inference_api
+        self.memory_api = memory_api
         self.safety_api = safety_api
         self.max_infer_iters = max_infer_iters
@@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
         ShieldRunnerMixin.__init__(
             self,
             safety_api,
-            input_shields=input_shields,
-            output_shields=output_shields,
+            input_shields=agent_config.input_shields,
+            output_shields=agent_config.output_shields,
         )
 
     def create_session(self, name: str) -> Session:
@@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
         # print_dialog(messages)
 
         turn_id = str(uuid.uuid4())
-        params = self.instance_config.sampling_params
+        params = self.agent_config.sampling_params
 
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
@@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
             )
             yield chunk
 
+    async def run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. we simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        async for res in self._run(
+            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # for output shields run on the full input and output combination
+        messages = input_messages + [final_response]
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
     async def run_shields_wrapper(
         self,
         turn_id: str,
@@ -276,52 +279,7 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             )
 
-    async def run(
-        self,
-        turn_id: str,
-        input_messages: List[Message],
-        temperature: float,
-        top_p: float,
-        stream: bool = False,
-        max_gen_len: Optional[int] = None,
-    ) -> AsyncGenerator:
-        # Doing async generators makes downstream code much simpler and everything amenable to
-        # stremaing. However, it also makes things complicated here because AsyncGenerators cannot
-        # return a "final value" for the `yield from` statement. we simulate that by yielding a
-        # final boolean (to see whether an exception happened) and then explicitly testing for it.
-        async for res in self.run_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
-        ):
-            if isinstance(res, bool):
-                return
-            elif isinstance(res, CompletionMessage):
-                final_response = res
-                break
-            else:
-                yield res
-
-        assert final_response is not None
-        # for output shields run on the full input and output combination
-        messages = input_messages + [final_response]
-        async for res in self.run_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        yield final_response
+    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
 
     async def _run(
         self,
@@ -332,6 +290,11 @@ class AgentInstance(ShieldRunnerMixin):
         stream: bool = False,
         max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
+        need_context = await self._should_retrieve_context(input_messages)
+        if need_context:
+            context = await self._retrieve_context(input_messages)
+        # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+        # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)
 
         attachments = []
@@ -359,10 +322,10 @@
         # where are the available tools?
         req = ChatCompletionRequest(
-            model=self.model,
+            model=self.agent_config.model,
             messages=input_messages,
-            tools=self.instance_config.available_tools,
-            tool_prompt_format=self.instance_config.tool_prompt_format,
+            tools=self.agent_config.available_tools,
+            tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=SamplingParams(
                 temperature=temperature,
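The relocated run() keeps the boolean-sentinel convention its comment describes: an async generator cannot return a value, so a terminal bool signals failure and the caller tests for it explicitly. A self-contained sketch of the pattern (simplified stand-ins, not the actual shield logic):

    import asyncio
    from typing import AsyncGenerator, List

    async def checked_stage(items: List[int]) -> AsyncGenerator:
        # Streams events; yields a bare bool as a sentinel when something goes wrong.
        for x in items:
            if x < 0:        # stand-in for a shield violation
                yield True   # sentinel: an error occurred, stop consuming
                return
            yield x          # normal streamed event

    async def pipeline(items: List[int]) -> AsyncGenerator:
        async for res in checked_stage(items):
            if isinstance(res, bool):
                return       # propagate early termination
            yield res

    async def main() -> None:
        async for ev in pipeline([1, 2, -1, 3]):
            print(ev)        # prints 1 then 2, then the stream ends early

    asyncio.run(main())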

View file

@@ -13,16 +13,11 @@ from typing import AsyncGenerator, Dict
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.inference.api.datatypes import BuiltinTool
+from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
-from llama_toolchain.agentic_system.api import (
-    AgenticSystem,
-    AgenticSystemCreateRequest,
-    AgenticSystemCreateResponse,
-    AgenticSystemSessionCreateRequest,
-    AgenticSystemSessionCreateResponse,
-    AgenticSystemTurnCreateRequest,
-)
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from .agent_instance import ChatAgent
+from .config import MetaReferenceImplConfig
 
 from llama_toolchain.tools.builtin import (
     BraveSearchTool,
@@ -34,16 +29,16 @@ from llama_toolchain.tools.safety import with_safety
 
-from .agent_instance import AgentInstance
-from .config import AgenticSystemConfig
-
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
 
-async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
     assert isinstance(
-        config, AgenticSystemConfig
+        config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"
 
     impl = MetaReferenceAgenticSystemImpl(
@@ -60,11 +55,16 @@ AGENT_INSTANCES_BY_ID = {}
 class MetaReferenceAgenticSystemImpl(AgenticSystem):
     def __init__(
-        self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
+        self,
+        config: MetaReferenceImplConfig,
+        inference_api: Inference,
+        safety_api: Safety,
+        memory_api: Memory,
     ):
         self.config = config
         self.inference_api = inference_api
         self.safety_api = safety_api
+        self.memory_api = memory_api
 
     async def initialize(self) -> None:
         pass
@@ -77,7 +77,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
         builtin_tools = []
         custom_tool_definitions = []
-        cfg = request.instance_config
+        cfg = request.agent_config
         for dfn in cfg.available_tools:
             if isinstance(dfn.tool_name, BuiltinTool):
                 if dfn.tool_name == BuiltinTool.wolfram_alpha:
@@ -107,18 +107,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
         else:
             custom_tool_definitions.append(dfn)
 
-        AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
-            system_id=system_id,
-            instance_config=request.instance_config,
-            model=request.model,
+        AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
+            agent_config=cfg,
             inference_api=self.inference_api,
+            safety_api=self.safety_api,
+            memory_api=self.memory_api,
             builtin_tools=builtin_tools,
             custom_tool_definitions=custom_tool_definitions,
-            safety_api=self.safety_api,
-            input_shields=cfg.input_shields,
-            output_shields=cfg.output_shields,
-            prefix_messages=cfg.debug_prefix_messages,
-            tool_prompt_format=cfg.tool_prompt_format,
         )
 
         return AgenticSystemCreateResponse(
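With Api.memory declared as a dependency (see the provider spec further below), the resolved Memory implementation is expected to be threaded into the impl alongside inference and safety. A hedged sketch of the constructor call, with hypothetical resolved providers:

    impl = MetaReferenceAgenticSystemImpl(
        config=MetaReferenceImplConfig(),
        inference_api=inference_impl,  # hypothetical resolved Inference provider
        safety_api=safety_impl,        # hypothetical resolved Safety provider
        memory_api=memory_impl,        # new: resolved Memory provider
    )
    await impl.initialize()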

View file

@@ -9,6 +9,6 @@ from typing import Optional
 from pydantic import BaseModel
 
-class AgenticSystemConfig(BaseModel):
+class MetaReferenceImplConfig(BaseModel):
     brave_search_api_key: Optional[str] = None
     wolfram_api_key: Optional[str] = None

View file

@@ -21,10 +21,11 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
                 "transformers",
             ],
             module="llama_toolchain.agentic_system.meta_reference",
-            config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
+            config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
+                Api.memory,
             ],
         ),
     ]

View file

@@ -15,8 +15,8 @@ from llama_models.llama3.api.datatypes import (
 )
 
 from llama_toolchain.agentic_system.api import (
+    AgentConfig,
     AgenticSystemCreateRequest,
-    AgenticSystemInstanceConfig,
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
@@ -100,7 +100,7 @@ async def get_agent_system_instance(
     create_request = AgenticSystemCreateRequest(
         model=model,
-        instance_config=AgenticSystemInstanceConfig(
+        agent_config=AgentConfig(
             instructions="You are a helpful assistant",
             available_tools=tool_definitions,
             input_shields=(

View file

@@ -17,6 +17,7 @@ class Api(Enum):
     inference = "inference"
     safety = "safety"
     agentic_system = "agentic_system"
+    memory = "memory"
 
 
 @json_schema_type

View file

@@ -6,17 +6,31 @@
 from typing import List, Protocol
 
-from llama_models.schema_utils import webmethod
+from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
+from llama_models.schema_utils import webmethod
 
 from .datatypes import *  # noqa: F403
 
-class MemoryBanks(Protocol):
+@json_schema_type
+class RetrieveMemoryDocumentsRequest(BaseModel):
+    query: InterleavedTextAttachment
+    bank_ids: str
+
+
+@json_schema_type
+class RetrieveMemoryDocumentsResponse(BaseModel):
+    documents: List[MemoryBankDocument]
+    scores: List[float]
+
+
+class Memory(Protocol):
     @webmethod(route="/memory_banks/create")
     def create_memory_bank(
         self,
         bank_id: str,
         bank_name: str,
+        embedding_model: str,
         documents: List[MemoryBankDocument],
     ) -> None: ...
@@ -46,6 +60,12 @@ class MemoryBanks(Protocol):
         documents: List[MemoryBankDocument],
     ) -> None: ...
 
+    @webmethod(route="/memory_bank/get")
+    def retrieve_memory_documents(
+        self,
+        request: RetrieveMemoryDocumentsRequest,
+    ) -> List[MemoryBankDocument]: ...
+
     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
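A sketch of a retrieval call against the new protocol (memory stands in for some concrete Memory implementation; the plain-string query assumes InterleavedTextAttachment accepts bare text, and note that bank_ids is a single str in this WIP revision):

    request = RetrieveMemoryDocumentsRequest(
        query="what did we decide about caching?",
        bank_ids="bank-123",  # hypothetical bank id
    )
    docs = memory.retrieve_memory_documents(request)
    for doc in docs:
        print(doc)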

View file

@@ -24,7 +24,7 @@ class LlamaStack(
     Datasets,
     Observability,
     PostTraining,
-    MemoryBanks,
+    Memory,
     Evaluations,
 ):
     pass