mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
flesh out memory banks API
This commit is contained in:
parent
31289e3f47
commit
77d6055d9f
11 changed files with 1792 additions and 974 deletions
|
@ -10,13 +10,13 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, validator
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_toolchain.common.deployment_types import * # noqa: F403
|
from llama_toolchain.common.deployment_types import * # noqa: F403
|
||||||
from llama_toolchain.inference.api import * # noqa: F403
|
from llama_toolchain.inference.api import * # noqa: F403
|
||||||
from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
from llama_toolchain.safety.api import * # noqa: F403
|
||||||
from llama_toolchain.memory.api.datatypes import * # noqa: F403
|
from llama_toolchain.memory.api import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -25,31 +25,81 @@ class Attachment(BaseModel):
|
||||||
mime_type: str
|
mime_type: str
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemBuiltinTool(BuiltinTool):
|
class AgenticSystemTool(Enum):
|
||||||
|
brave_search = "brave_search"
|
||||||
|
wolfram_alpha = "wolfram_alpha"
|
||||||
|
photogen = "photogen"
|
||||||
|
code_interpreter = "code_interpreter"
|
||||||
|
|
||||||
|
function_call = "function_call"
|
||||||
memory = "memory"
|
memory = "memory"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
class ToolDefinitionCommon(BaseModel):
|
||||||
class AgenticSystemToolDefinition(BaseModel):
|
|
||||||
tool_name: Union[AgenticSystemBuiltinTool, str]
|
|
||||||
description: Optional[str] = None
|
|
||||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
|
||||||
|
|
||||||
@validator("tool_name", pre=True)
|
|
||||||
@classmethod
|
|
||||||
def validate_field(cls, v):
|
|
||||||
if isinstance(v, str):
|
|
||||||
try:
|
|
||||||
return AgenticSystemBuiltinTool(v)
|
|
||||||
except ValueError:
|
|
||||||
return v
|
|
||||||
return v
|
|
||||||
|
|
||||||
execution_config: Optional[RestAPIExecutionConfig] = None
|
|
||||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class BraveSearchToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.brave_search.value] = (
|
||||||
|
AgenticSystemTool.brave_search.value
|
||||||
|
)
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class WolframAlphaToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.wolfram_alpha.value] = (
|
||||||
|
AgenticSystemTool.wolfram_alpha.value
|
||||||
|
)
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PhotogenToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.photogen.value] = AgenticSystemTool.photogen.value
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class CodeInterpreterToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.code_interpreter.value] = (
|
||||||
|
AgenticSystemTool.code_interpreter.value
|
||||||
|
)
|
||||||
|
enable_inline_code_execution: bool = True
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class FunctionCallToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.function_call.value] = (
|
||||||
|
AgenticSystemTool.function_call.value
|
||||||
|
)
|
||||||
|
description: str
|
||||||
|
parameters: Dict[str, ToolParamDefinition]
|
||||||
|
remote_execution: Optional[RestAPIExecutionConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryToolDefinition(ToolDefinitionCommon):
|
||||||
|
type: Literal[AgenticSystemTool.memory.value] = AgenticSystemTool.memory.value
|
||||||
|
memory_banks: List[MemoryBank] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
AgenticSystemToolDefinition = Annotated[
|
||||||
|
Union[
|
||||||
|
BraveSearchToolDefinition,
|
||||||
|
WolframAlphaToolDefinition,
|
||||||
|
PhotogenToolDefinition,
|
||||||
|
CodeInterpreterToolDefinition,
|
||||||
|
FunctionCallToolDefinition,
|
||||||
|
MemoryToolDefinition,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class StepCommon(BaseModel):
|
class StepCommon(BaseModel):
|
||||||
turn_id: str
|
turn_id: str
|
||||||
step_id: str
|
step_id: str
|
||||||
|
@ -136,27 +186,45 @@ class Session(BaseModel):
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
class MemoryBankConfigCommon(BaseModel):
|
||||||
class MemoryConfig(BaseModel):
|
bank_id: str
|
||||||
memory_bank_id: str
|
|
||||||
|
|
||||||
# this configuration can hold other information we may want to configure
|
|
||||||
# how will the agent use the memory bank API?
|
class VectorMemoryBankConfig(MemoryBankConfigCommon):
|
||||||
#
|
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
#
|
|
||||||
|
|
||||||
|
class KeyValueMemoryBankConfig(MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||||
|
keys: List[str] # what keys to focus on
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordMemoryBankConfig(MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||||
|
|
||||||
|
|
||||||
|
class GraphMemoryBankConfig(MemoryBankConfigCommon):
|
||||||
|
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||||
|
entities: List[str] # what entities to focus on
|
||||||
|
|
||||||
|
|
||||||
|
MemoryBankConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
VectorMemoryBankConfig,
|
||||||
|
KeyValueMemoryBankConfig,
|
||||||
|
KeywordMemoryBankConfig,
|
||||||
|
GraphMemoryBankConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigCommon(BaseModel):
|
class AgentConfigCommon(BaseModel):
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
memory_configs: Optional[List[MemoryConfig]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
|
||||||
|
|
||||||
# if you completely want to replace the messages prefixed by the system,
|
|
||||||
# this is debug only
|
|
||||||
debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
|
tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
|
||||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
|
@ -168,6 +236,7 @@ class AgentConfigCommon(BaseModel):
|
||||||
class AgentConfig(AgentConfigCommon):
|
class AgentConfig(AgentConfigCommon):
|
||||||
model: str
|
model: str
|
||||||
instructions: str
|
instructions: str
|
||||||
|
memory_bank_configs: Optional[List[MemoryBankConfig]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||||
|
|
|
@ -10,21 +10,9 @@ from typing import Protocol
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemCreateRequest(BaseModel):
|
|
||||||
agent_config: AgentConfig
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemCreateResponse(BaseModel):
|
class AgenticSystemCreateResponse(BaseModel):
|
||||||
# TODO: rename this to agent_id
|
agent_id: str
|
||||||
system_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class AgenticSystemSessionCreateRequest(BaseModel):
|
|
||||||
system_id: str
|
|
||||||
session_name: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -33,8 +21,8 @@ class AgenticSystemSessionCreateResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
|
class AgenticSystemTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||||
system_id: str
|
agent_id: str
|
||||||
session_id: str
|
session_id: str
|
||||||
|
|
||||||
# TODO: figure out how we can simplify this and make why
|
# TODO: figure out how we can simplify this and make why
|
||||||
|
@ -67,7 +55,7 @@ class AgenticSystem(Protocol):
|
||||||
@webmethod(route="/agentic_system/create")
|
@webmethod(route="/agentic_system/create")
|
||||||
async def create_agentic_system(
|
async def create_agentic_system(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemCreateRequest,
|
agent_config: AgentConfig,
|
||||||
) -> AgenticSystemCreateResponse: ...
|
) -> AgenticSystemCreateResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/turn/create")
|
@webmethod(route="/agentic_system/turn/create")
|
||||||
|
@ -91,7 +79,8 @@ class AgenticSystem(Protocol):
|
||||||
@webmethod(route="/agentic_system/session/create")
|
@webmethod(route="/agentic_system/session/create")
|
||||||
async def create_agentic_system_session(
|
async def create_agentic_system_session(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemSessionCreateRequest,
|
agent_id: str,
|
||||||
|
session_name: str,
|
||||||
) -> AgenticSystemSessionCreateResponse: ...
|
) -> AgenticSystemSessionCreateResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/agentic_system/session/get")
|
@webmethod(route="/agentic_system/session/get")
|
||||||
|
|
|
@ -26,9 +26,7 @@ from llama_toolchain.agentic_system.event_logger import EventLogger
|
||||||
from .api import (
|
from .api import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
AgenticSystem,
|
AgenticSystem,
|
||||||
AgenticSystemCreateRequest,
|
|
||||||
AgenticSystemCreateResponse,
|
AgenticSystemCreateResponse,
|
||||||
AgenticSystemSessionCreateRequest,
|
|
||||||
AgenticSystemSessionCreateResponse,
|
AgenticSystemSessionCreateResponse,
|
||||||
AgenticSystemToolDefinition,
|
AgenticSystemToolDefinition,
|
||||||
AgenticSystemTurnCreateRequest,
|
AgenticSystemTurnCreateRequest,
|
||||||
|
@ -127,27 +125,23 @@ async def run_main(host: str, port: int):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
create_request = AgenticSystemCreateRequest(
|
agent_config = AgentConfig(
|
||||||
model="Meta-Llama3.1-8B-Instruct",
|
model="Meta-Llama3.1-8B-Instruct",
|
||||||
agent_config=AgentConfig(
|
instructions="You are a helpful assistant",
|
||||||
instructions="You are a helpful assistant",
|
sampling_params=SamplingParams(),
|
||||||
sampling_params=SamplingParams(),
|
tools=tool_definitions,
|
||||||
available_tools=tool_definitions,
|
input_shields=[],
|
||||||
input_shields=[],
|
output_shields=[],
|
||||||
output_shields=[],
|
debug_prefix_messages=[],
|
||||||
debug_prefix_messages=[],
|
tool_prompt_format=ToolPromptFormat.json,
|
||||||
tool_prompt_format=ToolPromptFormat.json,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
create_response = await api.create_agentic_system(create_request)
|
create_response = await api.create_agentic_system(agent_config)
|
||||||
print(create_response)
|
print(create_response)
|
||||||
|
|
||||||
session_response = await api.create_agentic_system_session(
|
session_response = await api.create_agentic_system_session(
|
||||||
AgenticSystemSessionCreateRequest(
|
agent_id=create_response.agent_id,
|
||||||
system_id=create_response.system_id,
|
session_name="test_session",
|
||||||
session_name="test_session",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
print(session_response)
|
print(session_response)
|
||||||
|
|
||||||
|
@ -162,7 +156,7 @@ async def run_main(host: str, port: int):
|
||||||
cprint(f"User> {content}", color="blue")
|
cprint(f"User> {content}", color="blue")
|
||||||
iterator = api.create_agentic_system_turn(
|
iterator = api.create_agentic_system_turn(
|
||||||
AgenticSystemTurnCreateRequest(
|
AgenticSystemTurnCreateRequest(
|
||||||
system_id=create_response.system_id,
|
agent_id=create_response.agent_id,
|
||||||
session_id=session_response.session_id,
|
session_id=session_response.session_id,
|
||||||
messages=[
|
messages=[
|
||||||
UserMessage(content=content),
|
UserMessage(content=content),
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
import copy
|
import copy
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
from llama_models.llama3.api.datatypes import ToolPromptFormat
|
||||||
|
|
||||||
|
@ -326,7 +326,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
req = ChatCompletionRequest(
|
req = ChatCompletionRequest(
|
||||||
model=self.agent_config.model,
|
model=self.agent_config.model,
|
||||||
messages=input_messages,
|
messages=input_messages,
|
||||||
tools=self.agent_config.available_tools,
|
tools=self.agent_config.tools,
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
|
|
@ -24,7 +24,7 @@ from llama_toolchain.tools.builtin import (
|
||||||
)
|
)
|
||||||
from llama_toolchain.tools.safety import with_safety
|
from llama_toolchain.tools.safety import with_safety
|
||||||
|
|
||||||
from .agent_instance import AgentInstance, ChatAgent
|
from .agent_instance import ChatAgent
|
||||||
from .config import MetaReferenceImplConfig
|
from .config import MetaReferenceImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,11 +71,11 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemCreateRequest,
|
request: AgenticSystemCreateRequest,
|
||||||
) -> AgenticSystemCreateResponse:
|
) -> AgenticSystemCreateResponse:
|
||||||
system_id = str(uuid.uuid4())
|
agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
builtin_tools = []
|
builtin_tools = []
|
||||||
cfg = request.agent_config
|
cfg = request.agent_config
|
||||||
for dfn in cfg.available_tools:
|
for dfn in cfg.tools:
|
||||||
if isinstance(dfn.tool_name, BuiltinTool):
|
if isinstance(dfn.tool_name, BuiltinTool):
|
||||||
if dfn.tool_name == BuiltinTool.wolfram_alpha:
|
if dfn.tool_name == BuiltinTool.wolfram_alpha:
|
||||||
key = self.config.wolfram_api_key
|
key = self.config.wolfram_api_key
|
||||||
|
@ -102,7 +102,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
|
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
|
||||||
agent_config=cfg,
|
agent_config=cfg,
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
safety_api=self.safety_api,
|
safety_api=self.safety_api,
|
||||||
|
@ -111,16 +111,16 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgenticSystemCreateResponse(
|
return AgenticSystemCreateResponse(
|
||||||
system_id=system_id,
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agentic_system_session(
|
async def create_agentic_system_session(
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemSessionCreateRequest,
|
request: AgenticSystemSessionCreateRequest,
|
||||||
) -> AgenticSystemSessionCreateResponse:
|
) -> AgenticSystemSessionCreateResponse:
|
||||||
system_id = request.system_id
|
agent_id = request.agent_id
|
||||||
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
|
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||||
agent = AGENT_INSTANCES_BY_ID[system_id]
|
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||||
|
|
||||||
session = agent.create_session(request.session_name)
|
session = agent.create_session(request.session_name)
|
||||||
return AgenticSystemSessionCreateResponse(
|
return AgenticSystemSessionCreateResponse(
|
||||||
|
@ -131,9 +131,9 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
|
||||||
self,
|
self,
|
||||||
request: AgenticSystemTurnCreateRequest,
|
request: AgenticSystemTurnCreateRequest,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
system_id = request.system_id
|
agent_id = request.agent_id
|
||||||
assert system_id in AGENT_INSTANCES_BY_ID, f"System {system_id} not found"
|
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
||||||
agent = AGENT_INSTANCES_BY_ID[system_id]
|
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
request.session_id in agent.sessions
|
request.session_id in agent.sessions
|
||||||
|
|
|
@ -19,7 +19,7 @@ from llama_toolchain.inference.api import Message
|
||||||
|
|
||||||
async def execute_with_custom_tools(
|
async def execute_with_custom_tools(
|
||||||
system: AgenticSystem,
|
system: AgenticSystem,
|
||||||
system_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
custom_tools: List[Any],
|
custom_tools: List[Any],
|
||||||
|
@ -35,7 +35,7 @@ async def execute_with_custom_tools(
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
|
|
||||||
request = AgenticSystemTurnCreateRequest(
|
request = AgenticSystemTurnCreateRequest(
|
||||||
system_id=system_id,
|
agent_id=agent_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
messages=current_messages,
|
messages=current_messages,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
|
|
@ -14,12 +14,7 @@ from llama_models.llama3.api.datatypes import (
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.api import (
|
from llama_toolchain.agentic_system.api import AgentConfig, AgenticSystemToolDefinition
|
||||||
AgentConfig,
|
|
||||||
AgenticSystemCreateRequest,
|
|
||||||
AgenticSystemSessionCreateRequest,
|
|
||||||
AgenticSystemToolDefinition,
|
|
||||||
)
|
|
||||||
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
from llama_toolchain.agentic_system.client import AgenticSystemClient
|
||||||
|
|
||||||
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
|
from llama_toolchain.agentic_system.meta_reference.execute_with_custom_tools import (
|
||||||
|
@ -32,9 +27,9 @@ from llama_toolchain.safety.api.datatypes import BuiltinShield, ShieldDefinition
|
||||||
|
|
||||||
|
|
||||||
class AgenticSystemClientWrapper:
|
class AgenticSystemClientWrapper:
|
||||||
def __init__(self, api, system_id, custom_tools):
|
def __init__(self, api, agent_id, custom_tools):
|
||||||
self.api = api
|
self.api = api
|
||||||
self.system_id = system_id
|
self.agent_id = agent_id
|
||||||
self.custom_tools = custom_tools
|
self.custom_tools = custom_tools
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
|
|
||||||
|
@ -43,10 +38,8 @@ class AgenticSystemClientWrapper:
|
||||||
name = f"Session-{uuid.uuid4()}"
|
name = f"Session-{uuid.uuid4()}"
|
||||||
|
|
||||||
response = await self.api.create_agentic_system_session(
|
response = await self.api.create_agentic_system_session(
|
||||||
AgenticSystemSessionCreateRequest(
|
agent_id=self.agent_id,
|
||||||
system_id=self.system_id,
|
session_name=name,
|
||||||
session_name=name,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.session_id = response.session_id
|
self.session_id = response.session_id
|
||||||
return self.session_id
|
return self.session_id
|
||||||
|
@ -54,7 +47,7 @@ class AgenticSystemClientWrapper:
|
||||||
async def run(self, messages: List[Message], stream: bool = True):
|
async def run(self, messages: List[Message], stream: bool = True):
|
||||||
async for chunk in execute_with_custom_tools(
|
async for chunk in execute_with_custom_tools(
|
||||||
self.api,
|
self.api,
|
||||||
self.system_id,
|
self.agent_id,
|
||||||
self.session_id,
|
self.session_id,
|
||||||
messages,
|
messages,
|
||||||
self.custom_tools,
|
self.custom_tools,
|
||||||
|
@ -98,29 +91,27 @@ async def get_agent_system_instance(
|
||||||
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
|
ShieldDefinition(shield_type=BuiltinShield.injection_shield),
|
||||||
]
|
]
|
||||||
|
|
||||||
create_request = AgenticSystemCreateRequest(
|
agent_config = AgentConfig(
|
||||||
model=model,
|
model=model,
|
||||||
agent_config=AgentConfig(
|
instructions="You are a helpful assistant",
|
||||||
instructions="You are a helpful assistant",
|
tools=tool_definitions,
|
||||||
available_tools=tool_definitions,
|
input_shields=(
|
||||||
input_shields=(
|
[]
|
||||||
[]
|
if disable_safety
|
||||||
if disable_safety
|
else [
|
||||||
else [
|
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
|
||||||
ShieldDefinition(shield_type=BuiltinShield.jailbreak_shield),
|
]
|
||||||
]
|
|
||||||
),
|
|
||||||
output_shields=(
|
|
||||||
[]
|
|
||||||
if disable_safety
|
|
||||||
else [
|
|
||||||
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
sampling_params=SamplingParams(),
|
|
||||||
tool_prompt_format=tool_prompt_format,
|
|
||||||
),
|
),
|
||||||
|
output_shields=(
|
||||||
|
[]
|
||||||
|
if disable_safety
|
||||||
|
else [
|
||||||
|
ShieldDefinition(shield_type=BuiltinShield.llama_guard),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
sampling_params=SamplingParams(),
|
||||||
|
tool_prompt_format=tool_prompt_format,
|
||||||
)
|
)
|
||||||
create_response = await api.create_agentic_system(create_request)
|
create_response = await api.create_agentic_system(agent_config)
|
||||||
return AgenticSystemClientWrapper(api, create_response.system_id, custom_tools)
|
return AgenticSystemClientWrapper(api, create_response.agent_id, custom_tools)
|
||||||
|
|
|
@ -3,23 +3,3 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class MemoryBank(BaseModel):
|
|
||||||
memory_bank_id: str
|
|
||||||
memory_bank_name: str
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class MemoryBankDocument(BaseModel):
|
|
||||||
document_id: str
|
|
||||||
content: bytes
|
|
||||||
metadata: Dict[str, Any]
|
|
||||||
mime_type: str
|
|
||||||
|
|
|
@ -6,76 +6,132 @@
|
||||||
|
|
||||||
from typing import List, Protocol
|
from typing import List, Protocol
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import InterleavedTextMedia
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_models.schema_utils import webmethod
|
from llama_models.schema_utils import webmethod
|
||||||
from .datatypes import * # noqa: F403
|
from .datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RetrieveMemoryDocumentsRequest(BaseModel):
|
class MemoryBankDocument(BaseModel):
|
||||||
query: InterleavedTextMedia
|
document_id: str
|
||||||
bank_ids: str
|
content: InterleavedTextMedia | URL
|
||||||
|
mime_type: str
|
||||||
|
metadata: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(BaseModel):
|
||||||
|
content: InterleavedTextMedia
|
||||||
|
token_count: int
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RetrieveMemoryDocumentsResponse(BaseModel):
|
class QueryDocumentsResponse(BaseModel):
|
||||||
documents: List[MemoryBankDocument]
|
chunks: List[Chunk]
|
||||||
scores: List[float]
|
scores: List[float]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryBankType(Enum):
|
||||||
|
vector = "vector"
|
||||||
|
keyvalue = "keyvalue"
|
||||||
|
keyword = "keyword"
|
||||||
|
graph = "graph"
|
||||||
|
|
||||||
|
|
||||||
|
class VectorMemoryBankConfig(BaseModel):
|
||||||
|
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
|
embedding_model: str
|
||||||
|
|
||||||
|
|
||||||
|
class KeyValueMemoryBankConfig(BaseModel):
|
||||||
|
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordMemoryBankConfig(BaseModel):
|
||||||
|
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value
|
||||||
|
|
||||||
|
|
||||||
|
class GraphMemoryBankConfig(BaseModel):
|
||||||
|
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
|
||||||
|
|
||||||
|
|
||||||
|
MemoryBankConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
VectorMemoryBankConfig,
|
||||||
|
KeyValueMemoryBankConfig,
|
||||||
|
KeywordMemoryBankConfig,
|
||||||
|
GraphMemoryBankConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class MemoryBank(BaseModel):
|
||||||
|
bank_id: str
|
||||||
|
name: str
|
||||||
|
config: MemoryBankConfig
|
||||||
|
# if there's a pre-existing store which obeys the MemoryBank REST interface
|
||||||
|
url: Optional[URL] = None
|
||||||
|
|
||||||
|
|
||||||
class Memory(Protocol):
|
class Memory(Protocol):
|
||||||
@webmethod(route="/memory_banks/create")
|
@webmethod(route="/memory_banks/create")
|
||||||
def create_memory_bank(
|
def create_memory_bank(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
name: str,
|
||||||
bank_name: str,
|
config: MemoryBankConfig,
|
||||||
embedding_model: str,
|
url: Optional[URL] = None,
|
||||||
documents: List[MemoryBankDocument],
|
) -> MemoryBank: ...
|
||||||
) -> None: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/list")
|
@webmethod(route="/memory_banks/list", method="GET")
|
||||||
def get_memory_banks(self) -> List[MemoryBank]: ...
|
def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/get")
|
@webmethod(route="/memory_banks/get")
|
||||||
def get_memory_bank(self, bank_id: str) -> List[MemoryBank]: ...
|
def get_memory_bank(self, bank_id: str) -> MemoryBank: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_banks/drop")
|
@webmethod(route="/memory_banks/drop", method="DELETE")
|
||||||
def delete_memory_bank(
|
def drop_memory_bank(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
) -> str: ...
|
) -> str: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/insert")
|
@webmethod(route="/memory_bank/insert")
|
||||||
def insert_memory_documents(
|
def insert_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/update")
|
@webmethod(route="/memory_bank/update")
|
||||||
def update_memory_documents(
|
def update_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
documents: List[MemoryBankDocument],
|
documents: List[MemoryBankDocument],
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/get")
|
@webmethod(route="/memory_bank/query")
|
||||||
def retrieve_memory_documents(
|
def query_documents(
|
||||||
self,
|
|
||||||
request: RetrieveMemoryDocumentsRequest,
|
|
||||||
) -> List[MemoryBankDocument]: ...
|
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/get")
|
|
||||||
def get_memory_documents(
|
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
document_uuids: List[str],
|
query: InterleavedTextMedia,
|
||||||
) -> List[MemoryBankDocument]: ...
|
params: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> QueryDocumentsResponse: ...
|
||||||
|
|
||||||
@webmethod(route="/memory_bank/delete")
|
@webmethod(route="/memory_bank/documents/get")
|
||||||
def delete_memory_documents(
|
def get_documents(
|
||||||
self,
|
self,
|
||||||
bank_id: str,
|
bank_id: str,
|
||||||
document_uuids: List[str],
|
document_ids: List[str],
|
||||||
) -> List[str]: ...
|
) -> List[MemoryBankDocument]: ...
|
||||||
|
|
||||||
|
@webmethod(route="/memory_bank/documents/delete")
|
||||||
|
def delete_documents(
|
||||||
|
self,
|
||||||
|
bank_id: str,
|
||||||
|
document_ids: List[str],
|
||||||
|
) -> None: ...
|
||||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue