Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

<WIP> memory changes

- inlined AgenticSystemInstanceConfig so the API feels more ergonomic; renamed it to AgentConfig, and AgentInstance -> Agent
- added a MemoryConfig and a `memory` parameter
- added `attachments` to the input and `output_attachments` to the response
- some naming changes

parent 5655266d58
commit 48c6a32edd

12 changed files with 149 additions and 163 deletions
```diff
@@ -96,6 +96,8 @@ class Turn(BaseModel):
     ]
     steps: List[Step]
     output_message: CompletionMessage
+    output_attachments: List[Attachment] = Field(default_factory=list)
+
     started_at: datetime
     completed_at: Optional[datetime] = None
@@ -111,13 +113,22 @@ class Session(BaseModel):


 @json_schema_type
-class AgenticSystemInstanceConfig(BaseModel):
-    instructions: str
+class MemoryConfig(BaseModel):
+    memory_bank_id: str
+
+    # this configuration can hold other information we may want to configure
+    # how will the agent use the memory bank API?
+    #
+    #
+
+
+class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
     # zero-shot or built-in tool configurations as input to the model
     available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
         default_factory=list
     )
+    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)

     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -130,6 +141,16 @@ class AgenticSystemInstanceConfig(BaseModel):
     )


+@json_schema_type
+class AgentConfig(AgentConfigCommon):
+    model: str
+    instructions: str
+
+
+class AgentConfigOverridablePerTurn(AgentConfigCommon):
+    instructions: Optional[str] = None
+
+
 class AgenticSystemTurnResponseEventType(Enum):
     step_start = "step_start"
     step_complete = "step_complete"
```
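The net effect of these datatype changes is easier to see in use. Below is a minimal, self-contained sketch of the reshaped config objects, assuming only the Pydantic fields visible in the hunks above (sampling params, tools, and shields omitted); the bank id and instruction strings are placeholder values.

```python
from typing import List, Optional

from pydantic import BaseModel, Field


class MemoryConfig(BaseModel):
    memory_bank_id: str


class AgentConfigCommon(BaseModel):
    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)


class AgentConfig(AgentConfigCommon):
    model: str
    instructions: str


# Previously these settings lived in a separate AgenticSystemInstanceConfig;
# now a single AgentConfig carries the model, instructions, and memory banks.
config = AgentConfig(
    model="Meta-Llama3.1-8B-Instruct",
    instructions="You are a helpful assistant",
    memory=[MemoryConfig(memory_bank_id="bank-123")],  # placeholder bank id
)
print(config.memory)
```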
```diff
@@ -7,18 +7,17 @@
 from .datatypes import *  # noqa: F403
 from typing import Protocol

-# this dependency is annoying and we need a forked up version anyway
 from llama_models.schema_utils import json_schema_type, webmethod


 @json_schema_type
 class AgenticSystemCreateRequest(BaseModel):
-    model: str
-    instance_config: AgenticSystemInstanceConfig
+    agent_config: AgentConfig


 @json_schema_type
 class AgenticSystemCreateResponse(BaseModel):
+    # TODO: rename this to agent_id
     system_id: str

@@ -34,20 +33,22 @@ class AgenticSystemSessionCreateResponse(BaseModel):


 @json_schema_type
-# what's the URI?
-class AgenticSystemTurnCreateRequest(BaseModel):
+class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
     system_id: str
     session_id: str

+    # TODO: figure out how we can simplify this and make why
+    # ToolResponseMessage needs to be here (it is function call
+    # execution from outside the system)
     messages: List[
         Union[
             UserMessage,
             ToolResponseMessage,
         ]
     ]
+    attachments: List[Attachment]

     stream: Optional[bool] = False
-    override_config: Optional[AgenticSystemInstanceConfig] = None


 @json_schema_type(
@@ -93,22 +94,6 @@ class AgenticSystem(Protocol):
         request: AgenticSystemSessionCreateRequest,
     ) -> AgenticSystemSessionCreateResponse: ...

-    @webmethod(route="/agentic_system/memory_bank/attach")
-    async def attach_memory_bank_to_agentic_system(
-        self,
-        agent_id: str,
-        session_id: str,
-        memory_bank_ids: List[str],
-    ) -> None: ...
-
-    @webmethod(route="/agentic_system/memory_bank/detach")
-    async def detach_memory_bank_from_agentic_system(
-        self,
-        agent_id: str,
-        session_id: str,
-        memory_bank_ids: List[str],
-    ) -> None: ...
-
     @webmethod(route="/agentic_system/session/get")
     async def get_agentic_system_session(
         self,
```
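Dropping `override_config` works because the turn request now inherits the optional fields of AgentConfigOverridablePerTurn directly, so per-turn overrides are just regular request fields. A rough sketch of that pattern, with stand-in fields rather than the repo's actual classes:

```python
from typing import Optional

from pydantic import BaseModel


class AgentConfigCommon(BaseModel):
    temperature: Optional[float] = None  # stand-in for SamplingParams


class AgentConfigOverridablePerTurn(AgentConfigCommon):
    instructions: Optional[str] = None


class TurnCreateRequest(AgentConfigOverridablePerTurn):
    system_id: str
    session_id: str


# Any override field left as None falls back to the agent-level AgentConfig;
# setting it overrides the agent's default for this turn only.
req = TurnCreateRequest(system_id="sys-1", session_id="sess-1", instructions="Be terse")
print(req.instructions)
```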
```diff
@@ -24,10 +24,10 @@ from termcolor import cprint

 from llama_toolchain.agentic_system.event_logger import EventLogger
 from .api import (
+    AgentConfig,
     AgenticSystem,
     AgenticSystemCreateRequest,
     AgenticSystemCreateResponse,
-    AgenticSystemInstanceConfig,
     AgenticSystemSessionCreateRequest,
     AgenticSystemSessionCreateResponse,
     AgenticSystemToolDefinition,
@@ -129,7 +129,7 @@ async def run_main(host: str, port: int):

     create_request = AgenticSystemCreateRequest(
         model="Meta-Llama3.1-8B-Instruct",
-        instance_config=AgenticSystemInstanceConfig(
+        agent_config=AgentConfig(
             instructions="You are a helpful assistant",
             sampling_params=SamplingParams(),
             available_tools=tool_definitions,
```
```diff
@@ -5,4 +5,4 @@
 # the root directory of this source tree.

 from .agentic_system import get_provider_impl  # noqa
-from .config import AgenticSystemConfig  # noqa
+from .config import MetaReferenceImplConfig  # noqa
```
```diff
@@ -14,47 +14,10 @@ from llama_models.llama3.api.datatypes import ToolPromptFormat

 from termcolor import cprint

-from llama_toolchain.agentic_system.api.datatypes import (
-    AgenticSystemInstanceConfig,
-    AgenticSystemTurnResponseEvent,
-    AgenticSystemTurnResponseEventType,
-    AgenticSystemTurnResponseStepCompletePayload,
-    AgenticSystemTurnResponseStepProgressPayload,
-    AgenticSystemTurnResponseStepStartPayload,
-    AgenticSystemTurnResponseTurnCompletePayload,
-    AgenticSystemTurnResponseTurnStartPayload,
-    InferenceStep,
-    Session,
-    ShieldCallStep,
-    StepType,
-    ToolExecutionStep,
-    Turn,
-)
-
-from llama_toolchain.inference.api import ChatCompletionRequest, Inference
-from llama_toolchain.inference.api.datatypes import (
-    Attachment,
-    BuiltinTool,
-    ChatCompletionResponseEventType,
-    CompletionMessage,
-    Message,
-    Role,
-    SamplingParams,
-    StopReason,
-    ToolCallDelta,
-    ToolCallParseStatus,
-    ToolDefinition,
-    ToolResponse,
-    ToolResponseMessage,
-    URL,
-)
-from llama_toolchain.safety.api import Safety
-from llama_toolchain.safety.api.datatypes import (
-    BuiltinShield,
-    ShieldDefinition,
-    ShieldResponse,
-)
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
-
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from llama_toolchain.inference.api import *  # noqa: F403
+from llama_toolchain.memory.api import *  # noqa: F403
+from llama_toolchain.safety.api import *  # noqa: F403
 from llama_toolchain.tools.base import BaseTool
 from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
```
```diff
@@ -62,27 +25,20 @@ from llama_toolchain.tools.builtin import SingleMessageBuiltinTool
 from .safety import SafetyException, ShieldRunnerMixin


-class AgentInstance(ShieldRunnerMixin):
+class ChatAgent(ShieldRunnerMixin):
     def __init__(
         self,
-        system_id: int,
-        instance_config: AgenticSystemInstanceConfig,
-        model: str,
+        agent_config: AgentConfig,
         inference_api: Inference,
+        memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
         custom_tool_definitions: List[ToolDefinition],
-        input_shields: List[ShieldDefinition],
-        output_shields: List[ShieldDefinition],
         max_infer_iters: int = 10,
-        prefix_messages: Optional[List[Message]] = None,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
     ):
-        self.system_id = system_id
-        self.instance_config = instance_config
-
-        self.model = model
+        self.agent_config = agent_config
         self.inference_api = inference_api
+        self.memory_api = memory_api
         self.safety_api = safety_api

         self.max_infer_iters = max_infer_iters
@@ -93,8 +49,8 @@ class AgentInstance(ShieldRunnerMixin):
         ShieldRunnerMixin.__init__(
             self,
             safety_api,
-            input_shields=input_shields,
-            output_shields=output_shields,
+            input_shields=agent_config.input_shields,
+            output_shields=agent_config.output_shields,
         )

     def create_session(self, name: str) -> Session:
```
```diff
@@ -152,7 +108,7 @@ class AgentInstance(ShieldRunnerMixin):
         # print_dialog(messages)

         turn_id = str(uuid.uuid4())
-        params = self.instance_config.sampling_params
+        params = self.agent_config.sampling_params
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
```
```diff
@@ -215,6 +171,53 @@ class AgentInstance(ShieldRunnerMixin):
                 )
                 yield chunk

+    async def run(
+        self,
+        turn_id: str,
+        input_messages: List[Message],
+        temperature: float,
+        top_p: float,
+        stream: bool = False,
+        max_gen_len: Optional[int] = None,
+    ) -> AsyncGenerator:
+        # Doing async generators makes downstream code much simpler and everything amenable to
+        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
+        # return a "final value" for the `yield from` statement. we simulate that by yielding a
+        # final boolean (to see whether an exception happened) and then explicitly testing for it.
+
+        async for res in self.run_shields_wrapper(
+            turn_id, input_messages, self.input_shields, "user-input"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        async for res in self._run(
+            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+        ):
+            if isinstance(res, bool):
+                return
+            elif isinstance(res, CompletionMessage):
+                final_response = res
+                break
+            else:
+                yield res
+
+        assert final_response is not None
+        # for output shields run on the full input and output combination
+        messages = input_messages + [final_response]
+
+        async for res in self.run_shields_wrapper(
+            turn_id, messages, self.output_shields, "assistant-output"
+        ):
+            if isinstance(res, bool):
+                return
+            else:
+                yield res
+
+        yield final_response
+
     async def run_shields_wrapper(
         self,
         turn_id: str,
```
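The comment in the relocated `run` describes the control-flow trick this file relies on: an async generator cannot return a value, so the last yielded item doubles as the "return value" and a bare boolean signals an early abort. A standalone, runnable illustration of that pattern (names here are illustrative, not from the repo):

```python
import asyncio


async def inner(fail: bool):
    yield "progress event"
    if fail:
        yield False  # sentinel: a shield violation or exception happened
        return
    yield "final result"  # the last yield stands in for a return value


async def outer(fail: bool):
    async for res in inner(fail):
        if isinstance(res, bool):
            return  # abort early without forwarding the sentinel
        yield res


async def main():
    async for event in outer(fail=False):
        print(event)


asyncio.run(main())
```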
```diff
@@ -276,52 +279,7 @@ class AgentInstance(ShieldRunnerMixin):
                 )
             )

-    async def run(
-        self,
-        turn_id: str,
-        input_messages: List[Message],
-        temperature: float,
-        top_p: float,
-        stream: bool = False,
-        max_gen_len: Optional[int] = None,
-    ) -> AsyncGenerator:
-        # Doing async generators makes downstream code much simpler and everything amenable to
-        # stremaing. However, it also makes things complicated here because AsyncGenerators cannot
-        # return a "final value" for the `yield from` statement. we simulate that by yielding a
-        # final boolean (to see whether an exception happened) and then explicitly testing for it.
-
-        async for res in self.run_shields_wrapper(
-            turn_id, input_messages, self.input_shields, "user-input"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
-        ):
-            if isinstance(res, bool):
-                return
-            elif isinstance(res, CompletionMessage):
-                final_response = res
-                break
-            else:
-                yield res
-
-        assert final_response is not None
-        # for output shields run on the full input and output combination
-        messages = input_messages + [final_response]
-
-        async for res in self.run_shields_wrapper(
-            turn_id, messages, self.output_shields, "assistant-output"
-        ):
-            if isinstance(res, bool):
-                return
-            else:
-                yield res
-
-        yield final_response
+    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...

     async def _run(
         self,
```
```diff
@@ -332,6 +290,11 @@ class AgentInstance(ShieldRunnerMixin):
         stream: bool = False,
         max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
+        need_context = await self._should_retrieve_context(input_messages)
+        if need_context:
+            context = await self._retrieve_context(input_messages)
+        # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
+        # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)

         attachments = []
```
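`_should_retrieve_context` is still a stub in this WIP commit (its body is `...` above), and `_retrieve_context` does not exist yet. One plausible reading, offered purely as an assumption about where this is headed, is a cheap heuristic gate before the Memory API round-trip:

```python
# Hypothetical sketch only: neither method is implemented in this commit.
from typing import List


class Message:  # simplified stand-in for the repo's Message datatype
    def __init__(self, content: str) -> None:
        self.content = content


async def _should_retrieve_context(messages: List[Message]) -> bool:
    # e.g. skip retrieval when no memory banks are configured,
    # or when the latest message is trivially short
    return bool(messages) and len(messages[-1].content) > 3


async def _retrieve_context(messages: List[Message]) -> List[str]:
    # would call the Memory API's retrieval endpoint with the latest
    # user message as the query and the configured bank ids
    return ["<retrieved document chunk>"]
```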
```diff
@@ -359,10 +322,10 @@ class AgentInstance(ShieldRunnerMixin):

         # where are the available tools?
         req = ChatCompletionRequest(
-            model=self.model,
+            model=self.agent_config.model,
             messages=input_messages,
-            tools=self.instance_config.available_tools,
-            tool_prompt_format=self.instance_config.tool_prompt_format,
+            tools=self.agent_config.available_tools,
+            tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
             sampling_params=SamplingParams(
                 temperature=temperature,
```
```diff
@@ -13,16 +13,11 @@ from typing import AsyncGenerator, Dict
 from llama_toolchain.distribution.datatypes import Api, ProviderSpec
 from llama_toolchain.inference.api import Inference
 from llama_toolchain.inference.api.datatypes import BuiltinTool
+from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
-from llama_toolchain.agentic_system.api.endpoints import *  # noqa
-from llama_toolchain.agentic_system.api import (
-    AgenticSystem,
-    AgenticSystemCreateRequest,
-    AgenticSystemCreateResponse,
-    AgenticSystemSessionCreateRequest,
-    AgenticSystemSessionCreateResponse,
-    AgenticSystemTurnCreateRequest,
-)
+from llama_toolchain.agentic_system.api import *  # noqa: F403
+from .agent_instance import ChatAgent
+from .config import MetaReferenceImplConfig

 from llama_toolchain.tools.builtin import (
     BraveSearchTool,
@@ -34,16 +29,16 @@ from llama_toolchain.tools.safety import with_safety

-from .agent_instance import AgentInstance
-
-from .config import AgenticSystemConfig
-

 logger = logging.getLogger()
 logger.setLevel(logging.INFO)


-async def get_provider_impl(config: AgenticSystemConfig, deps: Dict[Api, ProviderSpec]):
+async def get_provider_impl(
+    config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
+):
     assert isinstance(
-        config, AgenticSystemConfig
+        config, MetaReferenceImplConfig
     ), f"Unexpected config type: {type(config)}"

     impl = MetaReferenceAgenticSystemImpl(
@@ -60,11 +55,16 @@ AGENT_INSTANCES_BY_ID = {}

 class MetaReferenceAgenticSystemImpl(AgenticSystem):
     def __init__(
-        self, config: AgenticSystemConfig, inference_api: Inference, safety_api: Safety
+        self,
+        config: MetaReferenceImplConfig,
+        inference_api: Inference,
+        safety_api: Safety,
+        memory_api: Memory,
     ):
         self.config = config
         self.inference_api = inference_api
         self.safety_api = safety_api
+        self.memory_api = memory_api

     async def initialize(self) -> None:
         pass
@@ -77,7 +77,7 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):

         builtin_tools = []
         custom_tool_definitions = []
-        cfg = request.instance_config
+        cfg = request.agent_config
         for dfn in cfg.available_tools:
             if isinstance(dfn.tool_name, BuiltinTool):
                 if dfn.tool_name == BuiltinTool.wolfram_alpha:
@@ -107,18 +107,13 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             else:
                 custom_tool_definitions.append(dfn)

-        AGENT_INSTANCES_BY_ID[system_id] = AgentInstance(
-            system_id=system_id,
-            instance_config=request.instance_config,
-            model=request.model,
+        AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
+            agent_config=cfg,
             inference_api=self.inference_api,
+            safety_api=self.safety_api,
+            memory_api=self.memory_api,
             builtin_tools=builtin_tools,
             custom_tool_definitions=custom_tool_definitions,
-            safety_api=self.safety_api,
-            input_shields=cfg.input_shields,
-            output_shields=cfg.output_shields,
-            prefix_messages=cfg.debug_prefix_messages,
-            tool_prompt_format=cfg.tool_prompt_format,
         )

         return AgenticSystemCreateResponse(
```
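The provider wiring above follows the distribution layer's dependency-injection convention: the `api_dependencies` a provider declares (see the providers hunk further down) arrive as a `Dict[Api, ...]` in `get_provider_impl`. A simplified, runnable sketch of that flow, with placeholder types standing in for the real implementations:

```python
import asyncio
from enum import Enum
from typing import Any, Dict


class Api(Enum):
    inference = "inference"
    safety = "safety"
    memory = "memory"


class MetaReferenceAgenticSystemImpl:
    def __init__(self, config: Any, inference_api: Any, safety_api: Any, memory_api: Any):
        self.config = config
        self.inference_api = inference_api
        self.safety_api = safety_api
        self.memory_api = memory_api


async def get_provider_impl(config: Any, deps: Dict[Api, Any]) -> MetaReferenceAgenticSystemImpl:
    # each Api listed in the provider's api_dependencies must be resolvable here
    return MetaReferenceAgenticSystemImpl(
        config, deps[Api.inference], deps[Api.safety], deps[Api.memory]
    )


impl = asyncio.run(get_provider_impl({}, {api: object() for api in Api}))
print(type(impl).__name__)
```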
```diff
@@ -9,6 +9,6 @@ from typing import Optional
 from pydantic import BaseModel


-class AgenticSystemConfig(BaseModel):
+class MetaReferenceImplConfig(BaseModel):
     brave_search_api_key: Optional[str] = None
     wolfram_api_key: Optional[str] = None
```
```diff
@@ -21,10 +21,11 @@ def available_agentic_system_providers() -> List[ProviderSpec]:
                 "transformers",
             ],
             module="llama_toolchain.agentic_system.meta_reference",
-            config_class="llama_toolchain.agentic_system.meta_reference.AgenticSystemConfig",
+            config_class="llama_toolchain.agentic_system.meta_reference.MetaReferenceImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
+                Api.memory,
             ],
         ),
     ]
```
```diff
@@ -15,8 +15,8 @@ from llama_models.llama3.api.datatypes import (
 )

 from llama_toolchain.agentic_system.api import (
+    AgentConfig,
     AgenticSystemCreateRequest,
-    AgenticSystemInstanceConfig,
     AgenticSystemSessionCreateRequest,
     AgenticSystemToolDefinition,
 )
@@ -100,7 +100,7 @@ async def get_agent_system_instance(

     create_request = AgenticSystemCreateRequest(
         model=model,
-        instance_config=AgenticSystemInstanceConfig(
+        agent_config=AgentConfig(
             instructions="You are a helpful assistant",
             available_tools=tool_definitions,
             input_shields=(
```
```diff
@@ -17,6 +17,7 @@ class Api(Enum):
     inference = "inference"
     safety = "safety"
     agentic_system = "agentic_system"
+    memory = "memory"


 @json_schema_type
```
```diff
@@ -6,17 +6,31 @@

 from typing import List, Protocol

-from llama_models.schema_utils import webmethod
+from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
+
+from llama_models.schema_utils import webmethod
 from .datatypes import *  # noqa: F403


-class MemoryBanks(Protocol):
+@json_schema_type
+class RetrieveMemoryDocumentsRequest(BaseModel):
+    query: InterleavedTextAttachment
+    bank_ids: str
+
+
+@json_schema_type
+class RetrieveMemoryDocumentsResponse(BaseModel):
+    documents: List[MemoryBankDocument]
+    scores: List[float]
+
+
+class Memory(Protocol):
     @webmethod(route="/memory_banks/create")
     def create_memory_bank(
         self,
         bank_id: str,
         bank_name: str,
+        embedding_model: str,
         documents: List[MemoryBankDocument],
     ) -> None: ...

@@ -46,6 +60,12 @@ class MemoryBanks(Protocol):
         documents: List[MemoryBankDocument],
     ) -> None: ...

+    @webmethod(route="/memory_bank/get")
+    def retrieve_memory_documents(
+        self,
+        request: RetrieveMemoryDocumentsRequest,
+    ) -> List[MemoryBankDocument]: ...
+
     @webmethod(route="/memory_bank/get")
     def get_memory_documents(
         self,
```
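Note that in this WIP state both `retrieve_memory_documents` and `get_memory_documents` are registered at the same `/memory_bank/get` route. To make the new retrieval surface concrete, here is a toy in-memory implementation, assuming a simplified MemoryBankDocument with a text payload (the real datatype lives in the memory API's datatypes module, which this diff does not show):

```python
from typing import Dict, List

from pydantic import BaseModel


class MemoryBankDocument(BaseModel):  # simplified stand-in
    document_id: str
    content: str


class InMemoryBank:
    def __init__(self) -> None:
        self.banks: Dict[str, List[MemoryBankDocument]] = {}

    def create_memory_bank(self, bank_id: str, documents: List[MemoryBankDocument]) -> None:
        self.banks[bank_id] = list(documents)

    def retrieve_memory_documents(self, bank_id: str, query: str) -> List[MemoryBankDocument]:
        # naive scoring: keep documents sharing any token with the query
        terms = set(query.lower().split())
        return [
            d for d in self.banks.get(bank_id, [])
            if terms & set(d.content.lower().split())
        ]


bank = InMemoryBank()
bank.create_memory_bank(
    "bank-123", [MemoryBankDocument(document_id="d1", content="llama stack memory")]
)
print(bank.retrieve_memory_documents("bank-123", "memory"))
```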
```diff
@@ -24,7 +24,7 @@ class LlamaStack(
     Datasets,
     Observability,
     PostTraining,
-    MemoryBanks,
+    Memory,
     Evaluations,
 ):
     pass
```