flesh out memory banks API

This commit is contained in:
Ashwin Bharambe 2024-08-23 06:38:15 -07:00
parent 31289e3f47
commit 77d6055d9f
11 changed files with 1792 additions and 974 deletions

View file

@ -14,12 +14,7 @@ from llama_models.llama3.api.datatypes import (
ToolPromptFormat,
)
from llama_toolchain.agentic_system.api import (
AgentConfig,
AgenticSystemCreateRequest,
AgenticSystemSessionCreateRequest,
AgenticSystemToolDefinition,
)
from llama_toolchain.agentic_system.api import AgentConfig, AgenticSystemToolDefinition
from llama_toolchain.agentic_system.client import AgenticSystemClient
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
@ -32,9 +27,9 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
class AgenticSystemClientWrapper:
def __init__(self, api, system_id, custom_tools):
def __init__(self, api, agent_id, custom_tools):
self.api = api
self.system_id = system_id
self.agent_id = agent_id
self.custom_tools = custom_tools
self.session_id = None
@ -43,10 +38,8 @@ class AgenticSystemClientWrapper:
name = f"Session-{uuid.uuid4()}"
response = await self.api.create_agentic_system_session(
AgenticSystemSessionCreateRequest(
system_id=self.system_id,
session_name=name,
)
agent_id=self.agent_id,
session_name=name,
)
self.session_id = response.session_id
return self.session_id
@ -54,7 +47,7 @@ class AgenticSystemClientWrapper:
async def run(self, messages: List[Message], stream: bool = True):
async for chunk in execute_with_custom_tools(
self.api,
self.system_id,
self.agent_id,
self.session_id,
messages,
self.custom_tools,
@ -98,29 +91,27 @@ async def get_agent_system_instance(
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
]
create_request = AgenticSystemCreateRequest(
agent_config = AgentConfig(
model=model,
agent_config=AgentConfig(
instructions="You are a helpful assistant",
available_tools=tool_definitions,
input_shields=(
[]
if disable_safety
else [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
]
),
output_shields=(
[]
if disable_safety
else [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
]
),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
instructions="You are a helpful assistant",
tools=tool_definitions,
input_shields=(
[]
if disable_safety
else [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
]
),
output_shields=(
[]
if disable_safety
else [
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
]
),
sampling_params=SamplingParams(),
tool_prompt_format=tool_prompt_format,
)
create_response = await api.create_agentic_system(create_request)
return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
create_response = await api.create_agentic_system(agent_config)
return AgenticSystemClientWrapper(api, create_response.agent_id, custom_tools)