<WIP> memory changes

- inlined AgenticSystemInstanceConfig so API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response

- some naming changes
This commit is contained in:
Ashwin Bharambe 2024-08-14 13:46:44 -07:00
parent 5655266d58
commit 48c6a32edd
12 changed files with 149 additions and 163 deletions

View file

@ -5,4 +5,4 @@
# the root directory of this source tree.
from .agentic_system import get_provider_impl # noqa
from .config import AgenticSystemConfig # noqa
from .config import MetaReferenceImplConfig # noqa

View file

@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat
from termcolor import cprint
from llama_toolchain.agentic_system.api.datatypes import (
AgenticSystemInstanceConfig,
AgenticSystemTurnResponseEvent,
AgenticSystemTurnResponseEventType,
AgenticSystemTurnResponseStepCompletePayload,
AgenticSystemTurnResponseStepProgressPayload,
AgenticSystemTurnResponseStepStartPayload,
AgenticSystemTurnResponseTurnCompletePayload,
AgenticSystemTurnResponseTurnStartPayload,
InferenceStep,
Session,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
)
from llama_toolchain.inference.api import ChatCompletionRequest, Inference
from llama_toolchain.inference.api.datatypes import (
Attachment,
BuiltinTool,
ChatCompletionResponseEventType,
CompletionMessage,
Message,
Role,
SamplingParams,
StopReason,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
URL,
)
from llama_toolchain.safety.api import Safety
from llama_toolchain.safety.api.datatypes import (
BuiltinShield,
ShieldDefinition,
ShieldResponse,
)
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from llama_toolchain.agentic_system.api import * # noqa: F403
from llama_toolchain.inference.api import * # noqa: F403
from llama_toolchain.memory.api import * # noqa: F403
from llama_toolchain.safety.api import * # noqa: F403
from llama_toolchain.tools.base import BaseTool
from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
from .safety import SafetyException, ShieldRunnerMixin
class AgentInstance(ShieldRunnerMixin):
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
system_id: int,
instance_config: AgenticSystemInstanceConfig,
model: str,
agent_config: AgentConfig,
inference_api: Inference,
memory_api: Memory,
safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool],
custom_tool_definitions: List[ToolDefinition],
input_shields: List[ShieldDefinition],
output_shields: List[ShieldDefinition],
max_infer_iters: int = 10,
prefix_messages: Optional[List[Message]] = None,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
):
self.system_id = system_id
self.instance_config = instance_config
self.model = model
self.agent_config = agent_config
self.inference_api = inference_api
self.memory_api = memory_api
self.safety_api = safety_api
self.max_infer_iters = max_infer_iters
@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
ShieldRunnerMixin.__init__(
self,
safety_api,
input_shields=input_shields,
output_shields=output_shields,
input_shields=agent_config.input_shields,
output_shields=agent_config.output_shields,
)
def create_session(self, name: str) -> Session:
@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
# print_dialog(messages)
turn_id = str(uuid.uuid4())
params = self.instance_config.sampling_params
params = self.agent_config.sampling_params
start_time = datetime.now()
yield AgenticSystemTurnResponseStreamChunk(
event=AgenticSystemTurnResponseEvent(
@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
)
yield chunk
async def run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def run_shields_wrapper(
self,
turn_id: str,
@ -276,52 +279,7 @@ class AgentInstance(ShieldRunnerMixin):
)
)
async def run(
self,
turn_id: str,
input_messages: List[Message],
temperature: float,
top_p: float,
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# stremaing. However, it also makes things complicated here because AsyncGenerators cannot
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
async for res in self.run_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
):
if isinstance(res, bool):
return
else:
yield res
async for res in self._run(
turn_id, input_messages, temperature, top_p, stream, max_gen_len
):
if isinstance(res, bool):
return
elif isinstance(res, CompletionMessage):
final_response = res
break
else:
yield res
assert final_response is not None
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
async for res in self.run_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
else:
yield res
yield final_response
async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
async def _run(
self,
@ -332,6 +290,11 @@ class AgentInstance(ShieldRunnerMixin):
stream: bool = False,
max_gen_len: Optional[int] = None,
) -> AsyncGenerator:
need_context = await self._should_retrieve_context(input_messages)
if need_context:
context = await self._retrieve_context(input_messages)
# input_messages = preprocess_dialog(input_messages, self.prefix_messages)
# input_messages = input_messages + context
input_messages = preprocess_dialog(input_messages)
attachments = []
@ -359,10 +322,10 @@ class AgentInstance(ShieldRunnerMixin):
# where are the available tools?
req = ChatCompletionRequest(
model=self.model,
model=self.agent_config.model,
messages=input_messages,
tools=self.instance_config.available_tools,
tool_prompt_format=self.instance_config.tool_prompt_format,
tools=self.agent_config.available_tools,
tool_prompt_format=self.agent_config.tool_prompt_format,
stream=True,
sampling_params=SamplingParams(
temperature=temperature,

View file

@ -13,16 +13,11 @@ from typing import AsyncGenerator, Dict
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.inference.api import Inference
from llama_toolchain.inference.api.datatypes import BuiltinTool
from llama_toolchain.memory.api import Memory
from llama_toolchain.safety.api import Safety
from llama_toolchain.agentic_system.api.endpoints import * # noqa
from llama_toolchain.agentic_system.api import (
AgenticSystem,
AgenticSystemCreateRequest,
AgenticSystemCreateResponse,
AgenticSystemSessionCreateRequest,
AgenticSystemSessionCreateResponse,
AgenticSystemTurnCreateRequest,
)
from llama_toolchain.agentic_system.api import * # noqa: F403
from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig
from llama_toolchain.tools.builtin import (
BraveSearchTool,
@ -34,16 +29,16 @@ from llama_toolchain.tools.safety import with_safety
from .agent_instance import AgentInstance
from .config import AgenticSystemConfig
logger = logging.getLogger()
logger.setLevel(logging.INFO)
async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
):
assert isinstance(
config, AgenticSystemConfig
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgenticSystemImpl(
@ -60,11 +55,16 @@ AGENT_INSTANCES_BY_ID = {}
class MetaReferenceAgenticSystemImpl(AgenticSystem):
def __init__(
self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
self,
config: MetaReferenceImplConfig,
inference_api: Inference,
safety_api: Safety,
memory_api: Memory,
):
self.config = config
self.inference_api = inference_api
self.safety_api = safety_api
self.memory_api = memory_api
async def initialize(self) -> None:
pass
@ -77,7 +77,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
builtin_tools = []
custom_tool_definitions = []
cfg = request.instance_config
cfg = request.agent_config
for dfn in cfg.available_tools:
if isinstance(dfn.tool_name, BuiltinTool):
if dfn.tool_name == BuiltinTool.wolfram_alpha:
@ -107,18 +107,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
else:
custom_tool_definitions.append(dfn)
AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
system_id=system_id,
instance_config=request.instance_config,
model=request.model,
AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
agent_config=cfg,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,
builtin_tools=builtin_tools,
custom_tool_definitions=custom_tool_definitions,
safety_api=self.safety_api,
input_shields=cfg.input_shields,
output_shields=cfg.output_shields,
prefix_messages=cfg.debug_prefix_messages,
tool_prompt_format=cfg.tool_prompt_format,
)
return AgenticSystemCreateResponse(

View file

@ -9,6 +9,6 @@ from typing import Optional
from pydantic import BaseModel
class AgenticSystemConfig(BaseModel):
class MetaReferenceImplConfig(BaseModel):
brave_search_api_key: Optional[str] = None
wolfram_api_key: Optional[str] = None