mirror of https://github.com/meta-llama/llama-stack.git
<WIP> memory changes
- inlined AgenticSystemInstanceConfig so the API feels more ergonomic
- renamed it to AgentConfig, AgentInstance -> Agent
- added a MemoryConfig and `memory` parameter
- added `attachments` to the input and `output_attachments` to the response
- some naming changes
parent 5655266d58
commit 48c6a32edd

12 changed files with 149 additions and 163 deletions
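For orientation before the diff: the commit inlines the old AgenticSystemInstanceConfig into an AgentConfig that the agent reads everything from. The sketch below is assembled only from the fields this diff actually dereferences on `self.agent_config` (model, sampling_params, available_tools, tool_prompt_format, input_shields, output_shields); the pydantic base class, defaults, and any other fields are assumptions, not the real definition.

# Hypothetical sketch of the inlined AgentConfig, inferred from the
# attributes this diff reads off `self.agent_config`. Field types, defaults
# and the pydantic base class are assumptions, not the actual definition.
from typing import List, Optional

from pydantic import BaseModel

from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_toolchain.inference.api.datatypes import SamplingParams, ToolDefinition
from llama_toolchain.safety.api.datatypes import ShieldDefinition


class AgentConfigSketch(BaseModel):
    model: str
    sampling_params: Optional[SamplingParams] = None
    available_tools: List[ToolDefinition] = []
    tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json
    input_shields: List[ShieldDefinition] = []
    output_shields: List[ShieldDefinition] = []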
@@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat
 from termcolor import cprint

-from llama_toolchain.agentic_system.api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-
-from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-from llama_toolchain.inference.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Message,
-    Role,
-    SamplingParams,
-    StopReason,
-    ToolCallDelta,
-    ToolCallParseStatus,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    URL,
-)
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.api.datatypes import (
-    BuiltinShield,
-    ShieldDefinition,
-    ShieldResponse,
-)
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403

 from llama_toolchain.tools.base import BaseTool
 from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
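A note on the import change above: the explicit symbol lists are replaced by wildcard re-exports, which is why names such as Memory and AgentConfig can appear in annotations later in this file without a dedicated import. A minimal illustration of the same pattern; the two symbols chosen here are just examples of names these API packages are expected to re-export.

# Star imports pull in every public symbol from each API package; the
# `# noqa: F403` comment silences flake8's "unable to detect undefined
# names" warning for this style.
from llama_toolchain.memory.api import *  # noqa: F403
from llama_toolchain.safety.api import *  # noqa: F403


def wire_up(memory_api: Memory, safety_api: Safety) -> None:
    # Memory and Safety resolve through the star imports above, without
    # being named in an explicit import list.
    ...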
@@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
 from .safety import SafetyException, ShieldRunnerMixin


-class AgentInstance(ShieldRunnerMixin):
+class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
-        system_id: int,
-        instance_config: AgenticSystemInstanceConfig,
-        model: str,
+        agent_config: AgentConfig,
         inference_api: Inference,
+        memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
         custom_tool_definitions: List[ToolDefinition],
-        input_shields: List[ShieldDefinition],
-        output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
-        prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
-        self.system_id = system_id
-        self.instance_config = instance_config
-
-        self.model = model
+        self.agent_config = agent_config
         self.inference_api = inference_api
+        self.memory_api = memory_api
         self.safety_api = safety_api

         self.max_infer_iters = max_infer_iters
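Taken together, the rename and the config inlining change how an agent instance gets built: everything that used to arrive as separate constructor arguments (model, shields, tool prompt format) now travels inside agent_config, while the API implementations and tools are still passed alongside it. A hedged construction sketch follows; the AgentConfig field names mirror the attributes this diff reads, but the model string, the empty defaults, and where the API implementations come from are illustrative assumptions.

# Hypothetical wiring, assuming concrete Inference/Memory/Safety
# implementations are already available as `inference_api`, `memory_api`
# and `safety_api`, and that ChatAgent/AgentConfig are importable from
# this module and the agentic_system API package respectively.
agent_config = AgentConfig(
    model="Meta-Llama3.1-8B-Instruct",  # illustrative model identifier
    input_shields=[],
    output_shields=[],
    available_tools=[],
    tool_prompt_format=ToolPromptFormat.json,
)

agent = ChatAgent(
    agent_config=agent_config,
    inference_api=inference_api,
    memory_api=memory_api,
    safety_api=safety_api,
    builtin_tools=[],
    custom_tool_definitions=[],
    max_infer_iters=10,
)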
@@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
         ShieldRunnerMixin.__init__(
             self,
             safety_api,
-            input_shields=input_shields,
-            output_shields=output_shields,
+            input_shields=agent_config.input_shields,
+            output_shields=agent_config.output_shields,
         )

     def create_session(self, name: str) -> Session:
@@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
         # print_dialog(messages)

         turn_id = str(uuid.uuid4())
-        params = self.instance_config.sampling_params
+        params = self.agent_config.sampling_params
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
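The hunk above is cut off mid-expression by the diff context, but the event types imported in the old version of the file indicate the shape of what is being yielded: a stream chunk wrapping an event whose payload marks the start of the turn. A hedged reconstruction of that expression, wrapped in a stub generator so the fragment stands on its own; the payload constructor arguments are inferred from the type names and are assumptions, not confirmed by this diff.

# Assumed continuation of the truncated yield above. The import style
# follows the new wildcard re-exports; the payload kwargs are guesses.
from typing import AsyncGenerator

from llama_toolchain.agentic_system.api import *  # noqa: F403


async def _emit_turn_start(turn_id: str) -> AsyncGenerator:
    yield AgenticSystemTurnResponseStreamChunk(
        event=AgenticSystemTurnResponseEvent(
            payload=AgenticSystemTurnResponseTurnStartPayload(
                turn_id=turn_id,
            )
        )
    )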
@@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             yield chunk

+    async def run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. we simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        async for res in self._run(
+            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # for output shields run on the full input and output combination
+        messages = input_messages + [final_response]
+
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
     async def run_shields_wrapper(
         self,
         turn_id: str,
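The comment inside the new run() explains the awkward part: an async generator cannot both stream events and return a final value, so the method yields sentinels instead (a bool when a shield raised, a CompletionMessage as the final answer) and the consumer type-checks every yielded item. A small self-contained sketch of that pattern, independent of the llama_toolchain types, just to make the control flow concrete:

# Minimal, runnable illustration of the "yield a sentinel instead of
# returning" pattern used by run() above. Strings stand in for streamed
# events, an int stands in for the final CompletionMessage, and a bool
# stands in for the shield-violation signal.
import asyncio
from typing import AsyncGenerator, Union


async def inner() -> AsyncGenerator[Union[str, bool, int], None]:
    yield "step 1"
    yield "step 2"
    yield 42  # the "final value" an async generator cannot `return`


async def outer() -> AsyncGenerator[str, None]:
    final = None
    async for item in inner():
        if isinstance(item, bool):
            return              # sentinel: abort, nothing more to stream
        elif isinstance(item, int):
            final = item        # capture the final value and stop iterating
            break
        else:
            yield item          # pass streamed events through
    assert final is not None
    yield f"final={final}"


async def main() -> None:
    async for ev in outer():
        print(ev)


asyncio.run(main())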
@@ -276,52 +279,7 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             )

-    async def run(
-        self,
-        turn_id: str,
-        input_messages: List[Message],
-        temperature: float,
-        top_p: float,
-        stream: bool = False,
-        max_gen_len: Optional[int] = None,
-    ) -> AsyncGenerator:
-        # Doing async generators makes downstream code much simpler and everything amenable to
-        # stremaing. However, it also makes things complicated here because AsyncGenerators cannot
-        # return a "final value" for the `yield from` statement. we simulate that by yielding a
-        # final boolean (to see whether an exception happened) and then explicitly testing for it.
-
-        async for res in self.run_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
-        ):
-            if isinstance(res, bool):
-                return
-            elif isinstance(res, CompletionMessage):
-                final_response = res
-                break
-            else:
-                yield res
-
-        assert final_response is not None
-        # for output shields run on the full input and output combination
-        messages = input_messages + [final_response]
-
-        async for res in self.run_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        yield final_response
+    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...

     async def _run(
         self,
@@ -332,6 +290,11 @@ class AgentInstance(ShieldRunnerMixin):
         stream: bool = False,
         max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
+        need_context = await self._should_retrieve_context(input_messages)
+        if need_context:
+            context = await self._retrieve_context(input_messages)
+            # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+            # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)

         attachments = []
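The new _run gates the turn on _should_retrieve_context, but the diff only adds a `...` stub for that method and never defines _retrieve_context, so how memory is actually consulted is still open in this WIP commit. A purely hypothetical sketch of what those two hooks could look like; the `query_documents` call, its arguments, and the shape of the returned documents are invented for illustration, and `agent_config.memory` only follows the MemoryConfig mentioned in the commit message.

# Hypothetical only: the Memory API method name and document shape are
# placeholders, not taken from this commit. These methods would be mixed
# into ChatAgent, which already holds `agent_config` and `memory_api`.
from typing import List


class _MemoryHooksSketch:
    async def _should_retrieve_context(self, messages) -> bool:
        # Only consult memory when the agent was configured with it and
        # there is a latest message to use as a query.
        return getattr(self.agent_config, "memory", None) is not None and bool(messages)

    async def _retrieve_context(self, messages) -> List[str]:
        query = messages[-1].content
        docs = await self.memory_api.query_documents(query=query, top_k=3)
        # Return plain text chunks; the caller decides how to splice them
        # into the dialog before running inference.
        return [d.content for d in docs]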
@@ -359,10 +322,10 @@ class AgentInstance(ShieldRunnerMixin):

         # where are the available tools?
         req = ChatCompletionRequest(
-            model=self.model,
+            model=self.agent_config.model,
             messages=input_messages,
-            tools=self.instance_config.available_tools,
-            tool_prompt_format=self.instance_config.tool_prompt_format,
+            tools=self.agent_config.available_tools,
+            tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=SamplingParams(
                 temperature=temperature,
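The request above is built with stream=True, and the types imported in the old version of the file (ChatCompletionResponseEventType, ToolCallDelta, ToolCallParseStatus) hint at what the consumer has to handle: incremental text deltas interleaved with partially parsed tool calls. A hedged sketch of draining such a stream; the `inference_api.chat_completion(req)` call shape and the attribute names on the streamed events are assumptions inferred from those type names, not confirmed by this diff.

# Hedged sketch of consuming the streaming chat completion built above.
# Event and delta attribute names are assumptions based on the imported
# type names.
async def drain(inference_api, req) -> str:
    text = ""
    async for chunk in inference_api.chat_completion(req):
        event = chunk.event
        if event.event_type == ChatCompletionResponseEventType.progress:
            if isinstance(event.delta, ToolCallDelta):
                # Tool-call fragments arrive with a parse status and are
                # only usable once parsing has succeeded.
                if event.delta.parse_status == ToolCallParseStatus.success:
                    ...  # hand the parsed tool call to the tool runner
            else:
                text += event.delta  # plain text delta
    return text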
|
Loading…
Add table
Add a link
Reference in a new issue