<WIP> memory changes

- inlined AgenticSystemInstanceConfig so the API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to input and `output_attachments` to the response

- some naming changes
Ashwin Bharambe 2024-08-14 13:46:44 -07:00
parent 5655266d58
commit 48c6a32edd
12 changed files with 149 additions and 163 deletions
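
Taken together, the change collapses the old `(instance_config, model, shields, ...)` constructor surface into a single `AgentConfig`. A rough sketch of the post-change config, using only fields the hunks below actually reference (`model`, `sampling_params`, `available_tools`, `tool_prompt_format`, `input_shields`, `output_shields`); the values are illustrative, the exact constructor may differ, and the `memory` field from the commit message is not visible in these hunks, so it is omitted:

# Rough sketch, not verified against the actual API: field names come from
# the hunks below; the model id and sampling values are illustrative only.
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_toolchain.agentic_system.api import *  # noqa: F403 (AgentConfig)
from llama_toolchain.inference.api import *  # noqa: F403 (SamplingParams)

agent_config = AgentConfig(
    model="Meta-Llama3.1-8B-Instruct",  # illustrative model id
    sampling_params=SamplingParams(temperature=0.7, top_p=0.9),
    available_tools=[],  # List[ToolDefinition]
    tool_prompt_format=ToolPromptFormat.json,
    input_shields=[],  # List[ShieldDefinition]
    output_shields=[],
)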


@@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat
 from termcolor import cprint
-from llama_toolchain.agentic_system.api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-from llama_toolchain.inference.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Message,
-    Role,
-    SamplingParams,
-    StopReason,
-    ToolCallDelta,
-    ToolCallParseStatus,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    URL,
-)
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.api.datatypes import (
-    BuiltinShield,
-    ShieldDefinition,
-    ShieldResponse,
-)
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 from llama_toolchain.tools.base import BaseTool
 from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
@@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
 from .safety import SafetyException, ShieldRunnerMixin


-class AgentInstance(ShieldRunnerMixin):
+class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
         system_id: int,
-        instance_config: AgenticSystemInstanceConfig,
-        model: str,
+        agent_config: AgentConfig,
         inference_api: Inference,
+        memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
         custom_tool_definitions: List[ToolDefinition],
-        input_shields: List[ShieldDefinition],
-        output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
-        prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
         self.system_id = system_id
-        self.instance_config = instance_config
-        self.model = model
+        self.agent_config = agent_config
         self.inference_api = inference_api
+        self.memory_api = memory_api
         self.safety_api = safety_api
         self.max_infer_iters = max_infer_iters
@@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
         ShieldRunnerMixin.__init__(
             self,
             safety_api,
-            input_shields=input_shields,
-            output_shields=output_shields,
+            input_shields=agent_config.input_shields,
+            output_shields=agent_config.output_shields,
         )

     def create_session(self, name: str) -> Session:
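
The constructor hunks above amount to plain dependency injection: `ChatAgent` holds only what it is handed (the Inference, Memory, and Safety implementations) and pulls shields out of the config. A hypothetical wiring call matching the signature shown above; the `*_impl` variables are stand-ins for whatever providers get resolved at startup and are not part of this diff:

# Hypothetical wiring sketch; inference_impl / memory_impl / safety_impl
# are assumed provider objects, not names that appear in this commit.
agent = ChatAgent(
    system_id=1,
    agent_config=agent_config,
    inference_api=inference_impl,
    memory_api=memory_impl,
    safety_api=safety_impl,
    builtin_tools=[],
    custom_tool_definitions=[],
)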
@@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
         # print_dialog(messages)

         turn_id = str(uuid.uuid4())
-        params = self.instance_config.sampling_params
+        params = self.agent_config.sampling_params
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
@@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             yield chunk

+    async def run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. we simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        async for res in self._run(
+            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # for output shields run on the full input and output combination
+        messages = input_messages + [final_response]
+
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
     async def run_shields_wrapper(
         self,
         turn_id: str,
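
The comment inside the new `run` method describes the one awkward part of this design: an async generator cannot return a value to its caller, so the inner generators yield a sentinel instead (a bare bool for "a shield tripped, abort", or the `CompletionMessage` itself as the final value) and the consumer dispatches on type. A minimal self-contained illustration of that sentinel pattern, standard library only and with invented names:

import asyncio
from typing import AsyncGenerator, Union

async def run_shields() -> AsyncGenerator[Union[bool, str], None]:
    # Yields progress events; on a violation it yields a bool sentinel and
    # stops, because an async generator cannot `return` a value.
    yield "shield started"
    violation = False  # pretend the check passed
    if violation:
        yield True  # sentinel: abort, nothing more will be streamed
        return
    yield "shield passed"

async def run() -> AsyncGenerator[str, None]:
    async for res in run_shields():
        if isinstance(res, bool):  # sentinel observed: abort the whole turn
            return
        yield res  # otherwise forward the event downstream
    yield "assistant response"  # reached only when no sentinel was seen

async def main() -> None:
    async for event in run():
        print(event)

asyncio.run(main())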
@@ -276,52 +279,7 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             )

-    async def run(
-        self,
-        turn_id: str,
-        input_messages: List[Message],
-        temperature: float,
-        top_p: float,
-        stream: bool = False,
-        max_gen_len: Optional[int] = None,
-    ) -> AsyncGenerator:
-        # Doing async generators makes downstream code much simpler and everything amenable to
-        # stremaing. However, it also makes things complicated here because AsyncGenerators cannot
-        # return a "final value" for the `yield from` statement. we simulate that by yielding a
-        # final boolean (to see whether an exception happened) and then explicitly testing for it.
-        async for res in self.run_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
-        ):
-            if isinstance(res, bool):
-                return
-            elif isinstance(res, CompletionMessage):
-                final_response = res
-                break
-            else:
-                yield res
-
-        assert final_response is not None
-        # for output shields run on the full input and output combination
-        messages = input_messages + [final_response]
-
-        async for res in self.run_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        yield final_response
+    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...

     async def _run(
         self,
@@ -332,6 +290,11 @@ class AgentInstance(ShieldRunnerMixin):
         stream: bool = False,
         max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
+        need_context = await self._should_retrieve_context(input_messages)
+        if need_context:
+            context = await self._retrieve_context(input_messages)
+        # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+        # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)

         attachments = []
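
Note the WIP seams in this hunk: `_should_retrieve_context` is declared as an ellipsis stub, `_retrieve_context` is awaited but never defined anywhere in this commit, and the retrieved `context` is computed and then dropped; the two commented-out lines mark where it would presumably be spliced into the dialog. A purely speculative sketch of how the pair might eventually fit together; nothing in this commit defines these bodies, so all of it is an assumption:

# Speculative completion of the two WIP hooks above, wrapped in a stub class
# so the sketch is importable; _Sketch stands in for ChatAgent.
from typing import List

class _Sketch:
    def __init__(self, agent_config):
        self.agent_config = agent_config

    async def _should_retrieve_context(self, messages: List) -> bool:
        # Only consult memory when the agent was configured with the new
        # `memory` parameter mentioned in the commit message (assumed field).
        return getattr(self.agent_config, "memory", None) is not None

    async def _retrieve_context(self, messages: List) -> List:
        # Would query the Memory API with the latest user message and wrap
        # the hits as extra dialog context; the concrete call is still TBD.
        raise NotImplementedError("WIP in this commit")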
@@ -359,10 +322,10 @@ class AgentInstance(ShieldRunnerMixin):
         # where are the available tools?
         req = ChatCompletionRequest(
-            model=self.model,
+            model=self.agent_config.model,
             messages=input_messages,
-            tools=self.instance_config.available_tools,
-            tool_prompt_format=self.instance_config.tool_prompt_format,
+            tools=self.agent_config.available_tools,
+            tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=SamplingParams(
                 temperature=temperature,