<WIP> memory changes

- inlined AgenticSystemInstanceConfig so API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response

- some naming changes
This commit is contained in:
Ashwin Bharambe 2024-08-14 13:46:44 -07:00
parent 5655266d58
commit 48c6a32edd
12 changed files with 149 additions and 163 deletions

View file

@ -96,6 +96,8 @@ class Turn(BaseModel):
]
steps: List[Step]
output_message: CompletionMessage
output_attachments: List[Attachment] = Field(default_factory=list)
started_at: datetime
completed_at: Optional[datetime] = None
@ -111,13 +113,22 @@ class Session(BaseModel):
@json_schema_type
class AgenticSystemInstanceConfig(BaseModel):
instructions: str
class MemoryConfig(BaseModel):
memory_bank_id: str
# this configuration can hold other information we may want to configure
# how will the agent use the memory bank API?
#
#
class AgentConfigCommon(BaseModel):
sampling_params: Optional[SamplingParams] = SamplingParams()
# zero-shot or built-in tool configurations as input to the model
available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
default_factory=list
)
memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@ -130,6 +141,16 @@ class AgenticSystemInstanceConfig(BaseModel):
)
@json_schema_type
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: Optional[str] = None
class AgenticSystemTurnResponseEventType(Enum):
step_start = "step_start"
step_complete = "step_complete"

View file

@ -7,18 +7,17 @@
from .datatypes import * # noqa: F403
from typing import Protocol
# this dependency is annoying and we need a forked up version anyway
from llama_models.schema_utils import json_schema_type, webmethod
@json_schema_type
class AgenticSystemCreateRequest(BaseModel):
model: str
instance_config: AgenticSystemInstanceConfig
agent_config: AgentConfig
@json_schema_type
class AgenticSystemCreateResponse(BaseModel):
# TODO: rename this to agent_id
system_id: str
@ -34,20 +33,22 @@ class AgenticSystemSessionCreateResponse(BaseModel):
@json_schema_type
# what's the URI?
class AgenticSystemTurnCreateRequest(BaseModel):
class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
system_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: List[
Union[
UserMessage,
ToolResponseMessage,
]
]
attachments: List[Attachment]
stream: Optional[bool] = False
override_config: Optional[AgenticSystemInstanceConfig] = None
@json_schema_type(
@ -93,22 +94,6 @@ class AgenticSystem(Protocol):
request: AgenticSystemSessionCreateRequest,
) -> AgenticSystemSessionCreateResponse: ...
@webmethod(route="/agentic_system/memory_bank/attach")
async def attach_memory_bank_to_agentic_system(
self,
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
@webmethod(route="/agentic_system/memory_bank/detach")
async def detach_memory_bank_from_agentic_system(
self,
agent_id: str,
session_id: str,
memory_bank_ids: List[str],
) -> None: ...
@webmethod(route="/agentic_system/session/get")
async def get_agentic_system_session(
self,

View file

@ -24,10 +24,10 @@ from termcolor import cprint
from llama_toolchain.agentic_system.event_logger import EventLogger
from .api import (
AgentConfig,
AgenticSystem,
AgenticSystemCreateRequest,
AgenticSystemCreateResponse,
AgenticSystemInstanceConfig,
AgenticSystemSessionCreateRequest,
AgenticSystemSessionCreateResponse,
AgenticSystemToolDefinition,
@ -129,7 +129,7 @@ async def run_main(host: str, port: int):
create_request = AgenticSystemCreateRequest(
model="Meta-Llama3.1-8B-Instruct",
instance_config=AgenticSystemInstanceConfig(
agent_config=AgentConfig(
instructions="You are a helpful assistant",
sampling_params=SamplingParams(),
available_tools=tool_definitions,

View file

@ -5,4 +5,4 @@
# the root directory of this source tree.
from .agentic_system import get_provider_impl # noqa
from .config import AgenticSystemConfig # noqa
from .config import MetaReferenceImplConfig # noqa

View file

@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat
from termcolor import cprint
from llama_toolchain.agentic_system.api.datatypes import (
AgenticSystemInstanceConfig,
AgenticSystemTurnResponseEvent,
AgenticSystemTurnResponseEventType,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
InferenceStep,
Session,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
)
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
from llama_toolchain.inference.api.datatypes import (
Attachment,
BuiltinTool,
ChatCompletionResponseEventType,
CompletionMessage,
Message,
Role,
SamplingParams,
StopReason,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
URL,
)
from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.api.datatypes import (
BuiltinShield,
ShieldDefinition,
ShieldResponse,
)
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.tools.base import BaseTool
from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
from .safety import SafetyException, ShieldRunnerMixin
class AgentInstance(ShieldRunnerMixin):
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
system_id: int,
instance_config: AgenticSystemInstanceConfig,
model: str,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
custom_tool_definitions: List[ToolDefinition],
input_shields: List[ShieldDefinition],
output_shields: List[ShieldDefinition],
max_infer_iters: int = 10,
prefix_messages: Optional[List[Message]] = None,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
):
self.system_id = system_id
self.instance_config = instance_config
self.model = model
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
ShieldRunnerMixin.__init__(
self,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
input_shields=agent_config.input_shields,
output_shields=agent_config.output_shields,
)
def create_session(self, name: str) -> Session:
@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
# print_dialog(messages)
turn_id = str(uuid.uuid4())
params = self.instance_config.sampling_params
params = self.agent_config.sampling_params
start_time = datetime.now()
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
)
yield chunk
async def run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def run_shields_wrapper(
self,
turn_id: str,
@ -276,52 +279,7 @@ class AgentInstance(ShieldRunnerMixin):
)
)
async def run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# stremaing. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
async def _run(
self,
@ -332,6 +290,11 @@ class AgentInstance(ShieldRunnerMixin):
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
need_context = await self._should_retrieve_context(input_messages)
if need_context:
context = await self._retrieve_context(input_messages)
# input_messages = preprocess_dialog(input_messages, self.prefix_messages)
# input_messages = input_messages + context
input_messages = preprocess_dialog(input_messages)
attachments = []
@ -359,10 +322,10 @@ class AgentInstance(ShieldRunnerMixin):
# where are the available tools?
req = ChatCompletionRequest(
model=self.model,
model=self.agent_config.model,
messages=input_messages,
tools=self.instance_config.available_tools,
tool_prompt_format=self.instance_config.tool_prompt_format,
tools=self.agent_config.available_tools,
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=SamplingParams(
temperature=temperature,

View file

@ -13,16 +13,11 @@ from typing import AsyncGenerator, Dict
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.api.datatypes import BuiltinTool
from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from llama_toolchain.agentic_system.api import (
AgenticSystem,
AgenticSystemCreateRequest,
AgenticSystemCreateResponse,
AgenticSystemSessionCreateRequest,
AgenticSystemSessionCreateResponse,
AgenticSystemTurnCreateRequest,
)
from llama_toolchain.agentic_system.api import * # noqa: F403
from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig
from llama_toolchain.tools.builtin import (
BraveSearchTool,
@ -34,16 +29,16 @@ from llama_toolchain.tools.safety import with_safety
from .agent_instance import AgentInstance
from .config import AgenticSystemConfig
logger = logging.getLogger()
logger.setLevel(logging.INFO)
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, AgenticSystemConfig
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
@ -60,11 +55,16 @@ AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgenticSystemImpl(AgenticSystem):
def __init__(
self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
self,
config: MetaReferenceImplConfig,
inference_api: Inference,
safety_api: Safety,
memory_api: Memory,
):
self.config = config
self.inference_api = inference_api
self.safety_api = safety_api
self.memory_api = memory_api
async def initialize(self) -> None:
pass
@ -77,7 +77,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
builtin_tools = []
custom_tool_definitions = []
cfg = request.instance_config
cfg = request.agent_config
for dfn in cfg.available_tools:
if isinstance(dfn.tool_name, BuiltinTool):
if dfn.tool_name == BuiltinTool.wolfram_alpha:
@ -107,18 +107,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
else:
custom_tool_definitions.append(dfn)
AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
system_id=system_id,
instance_config=request.instance_config,
model=request.model,
AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
agent_config=cfg,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
builtin_tools=builtin_tools,
custom_tool_definitions=custom_tool_definitions,
safety_api=self.safety_api,
input_shields=cfg.input_shields,
output_shields=cfg.output_shields,
prefix_messages=cfg.debug_prefix_messages,
tool_prompt_format=cfg.tool_prompt_format,
)
return AgenticSystemCreateResponse(

View file

@ -9,6 +9,6 @@ from typing import Optional
from pydantic import BaseModel
class AgenticSystemConfig(BaseModel):
class MetaReferenceImplConfig(BaseModel):
brave_search_api_key: Optional[str] = None
wolfram_api_key: Optional[str] = None

View file

@ -21,10 +21,11 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
"transformers",
],
module="llama_toolchain.agentic_system.meta_reference",
config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
api_dependencies=[
Api.inference,
Api.safety,
Api.memory,
],
),
]

View file

@ -15,8 +15,8 @@ from llama_models.llama3.api.datatypes import (
)
from llama_toolchain.agentic_system.api import (
AgentConfig,
AgenticSystemCreateRequest,
AgenticSystemInstanceConfig,
AgenticSystemSessionCreateRequest,
AgenticSystemToolDefinition,
)
@ -100,7 +100,7 @@ async def get_agent_system_instance(
create_request = AgenticSystemCreateRequest(
model=model,
instance_config=AgenticSystemInstanceConfig(
agent_config=AgentConfig(
instructions="You are a helpful assistant",
available_tools=tool_definitions,
input_shields=(

View file

@ -17,6 +17,7 @@ class Api(Enum):
inference = "inference"
safety = "safety"
agentic_system = "agentic_system"
memory = "memory"
@json_schema_type

View file

@ -6,17 +6,31 @@
from typing import List, Protocol
from llama_models.schema_utils import webmethod
from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
from llama_models.schema_utils import webmethod
from .datatypes import * # noqa: F403
class MemoryBanks(Protocol):
@json_schema_type
class RetrieveMemoryDocumentsRequest(BaseModel):
query: InterleavedTextAttachment
bank_ids: str
@json_schema_type
class RetrieveMemoryDocumentsResponse(BaseModel):
documents: List[MemoryBankDocument]
scores: List[float]
class Memory(Protocol):
@webmethod(route="/memory_banks/create")
def create_memory_bank(
self,
bank_id: str,
bank_name: str,
embedding_model: str,
documents: List[MemoryBankDocument],
) -> None: ...
@ -46,6 +60,12 @@ class MemoryBanks(Protocol):
documents: List[MemoryBankDocument],
) -> None: ...
@webmethod(route="/memory_bank/get")
def retrieve_memory_documents(
self,
request: RetrieveMemoryDocumentsRequest,
) -> List[MemoryBankDocument]: ...
@webmethod(route="/memory_bank/get")
def get_memory_documents(
self,

View file

@ -24,7 +24,7 @@ class LlamaStack(
Datasets,
Observability,
PostTraining,
MemoryBanks,
Memory,
Evaluations,
):
pass