InterleavedTextAttachment -> InterleavedTextMedia, introduce memory tool

This commit is contained in:
parent 48c6a32edd
commit 31289e3f47

5 changed files with 56 additions and 40 deletions
@@ -10,7 +10,7 @@ from typing import Any, Dict, List, Literal, Optional, Union

 from llama_models.schema_utils import json_schema_type

-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, validator
 from typing_extensions import Annotated

 from llama_toolchain.common.deployment_types import *  # noqa: F403
@@ -20,7 +20,31 @@ from llama_toolchain.memory.api.datatypes import *  # noqa: F403


-@json_schema_type
-class AgenticSystemToolDefinition(ToolDefinition):
+class Attachment(BaseModel):
+    url: URL
+    mime_type: str
+
+
+class AgenticSystemBuiltinTool(BuiltinTool):
+    memory = "memory"
+
+
+@json_schema_type
+class AgenticSystemToolDefinition(BaseModel):
+    tool_name: Union[AgenticSystemBuiltinTool, str]
+    description: Optional[str] = None
+    parameters: Optional[Dict[str, ToolParamDefinition]] = None
+
+    @validator("tool_name", pre=True)
+    @classmethod
+    def validate_field(cls, v):
+        if isinstance(v, str):
+            try:
+                return AgenticSystemBuiltinTool(v)
+            except ValueError:
+                return v
+        return v
+
     execution_config: Optional[RestAPIExecutionConfig] = None
     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
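The validator above is the crux of the new tool definition: a plain string naming a builtin is coerced to the AgenticSystemBuiltinTool enum member, while any other string passes through untouched as a custom tool name. A standalone sketch of that coercion pattern (simplified: the builtin is modeled here as a bare Enum and only the tool_name field is kept):

    from enum import Enum
    from typing import Union

    from pydantic import BaseModel, validator


    class AgenticSystemBuiltinTool(Enum):
        memory = "memory"


    class ToolDef(BaseModel):
        tool_name: Union[AgenticSystemBuiltinTool, str]

        @validator("tool_name", pre=True)
        @classmethod
        def validate_field(cls, v):
            # builtin names become enum members; unknown strings pass
            # through unchanged as custom tool names
            if isinstance(v, str):
                try:
                    return AgenticSystemBuiltinTool(v)
                except ValueError:
                    return v
            return v


    print(ToolDef(tool_name="memory").tool_name)             # AgenticSystemBuiltinTool.memory
    print(ToolDef(tool_name="get_boiling_point").tool_name)  # 'get_boiling_point'

This is the pydantic v1 validator API, matching the `validator` import added in the first hunk.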
@@ -124,11 +148,7 @@ class MemoryConfig(BaseModel):

 class AgentConfigCommon(BaseModel):
     sampling_params: Optional[SamplingParams] = SamplingParams()
-    # zero-shot or built-in tool configurations as input to the model
-    available_tools: Optional[List[AgenticSystemToolDefinition]] = Field(
-        default_factory=list
-    )
-    memory: Optional[List[MemoryConfig]] = Field(default_factory=list)
+    memory_configs: Optional[List[MemoryConfig]] = Field(default_factory=list)

     input_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
     output_shields: Optional[List[ShieldDefinition]] = Field(default_factory=list)
@@ -136,6 +156,9 @@ class AgentConfigCommon(BaseModel):
     # if you completely want to replace the messages prefixed by the system,
     # this is debug only
     debug_prefix_messages: Optional[List[Message]] = Field(default_factory=list)
+
+    tools: Optional[List[AgenticSystemToolDefinition]] = Field(default_factory=list)
+    tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
     tool_prompt_format: Optional[ToolPromptFormat] = Field(
         default=ToolPromptFormat.json
     )
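The two hunks above rework the agent config's tool surface: `available_tools` and `memory` give way to `tools`, `tool_choice`, and `memory_configs`, with tool choice defaulting to auto and tool prompts to JSON. A hedged, self-contained sketch of the defaults in action (stand-in enums and a plain-string stand-in for AgenticSystemToolDefinition, not the toolchain's real types):

    from enum import Enum
    from typing import List, Optional

    from pydantic import BaseModel, Field


    class ToolChoice(Enum):
        auto = "auto"
        required = "required"


    class ToolPromptFormat(Enum):
        json = "json"
        function_tag = "function_tag"


    class AgentConfig(BaseModel):
        # strings stand in for AgenticSystemToolDefinition entries
        tools: Optional[List[str]] = Field(default_factory=list)
        tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
        tool_prompt_format: Optional[ToolPromptFormat] = Field(default=ToolPromptFormat.json)


    cfg = AgentConfig(tools=["memory"])
    print(cfg.tool_choice, cfg.tool_prompt_format)  # ToolChoice.auto ToolPromptFormat.json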

@@ -46,7 +46,7 @@ class AgenticSystemTurnCreateRequest(BaseModel, AgentConfigOverridablePerTurn):
             ToolResponseMessage,
         ]
     ]
-    attachments: List[Attachment]
+    attachments: Optional[List[Attachment]] = None

     stream: Optional[bool] = False
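Making attachments optional means a turn request without attachments now validates instead of erroring. A minimal sketch of the behavioral difference (Attachment simplified to plain string fields; the real model uses a URL type):

    from typing import List, Optional

    from pydantic import BaseModel, ValidationError


    class Attachment(BaseModel):
        url: str  # simplified; the real model uses a URL type
        mime_type: str


    class TurnRequestOld(BaseModel):
        attachments: List[Attachment]


    class TurnRequestNew(BaseModel):
        attachments: Optional[List[Attachment]] = None


    print(TurnRequestNew().attachments)  # None -- field may now be omitted
    try:
        TurnRequestOld()
    except ValidationError as e:
        print("old model required attachments:", e.errors()[0]["type"])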

@@ -33,7 +33,6 @@ class ChatAgent(ShieldRunnerMixin):
         memory_api: Memory,
         safety_api: Safety,
         builtin_tools: List[SingleMessageBuiltinTool],
-        custom_tool_definitions: List[ToolDefinition],
         max_infer_iters: int = 10,
     ):
         self.agent_config = agent_config
@@ -108,7 +107,6 @@ class ChatAgent(ShieldRunnerMixin):
         # print_dialog(messages)

         turn_id = str(uuid.uuid4())
-        params = self.agent_config.sampling_params
         start_time = datetime.now()
         yield AgenticSystemTurnResponseStreamChunk(
             event=AgenticSystemTurnResponseEvent(
@@ -123,10 +121,9 @@ class ChatAgent(ShieldRunnerMixin):
         async for chunk in self.run(
             turn_id=turn_id,
             input_messages=messages,
-            temperature=params.temperature,
-            top_p=params.top_p,
+            attachments=request.attachments or [],
+            sampling_params=self.agent_config.sampling_params,
             stream=request.stream,
-            max_gen_len=params.max_tokens,
         ):
             if isinstance(chunk, CompletionMessage):
                 cprint(
@@ -175,10 +172,9 @@ class ChatAgent(ShieldRunnerMixin):
         self,
         turn_id: str,
         input_messages: List[Message],
-        temperature: float,
-        top_p: float,
+        attachments: List[Attachment],
+        sampling_params: SamplingParams,
         stream: bool = False,
-        max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
         # Doing async generators makes downstream code much simpler and everything amenable to
         # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -194,7 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield res

         async for res in self._run(
-            turn_id, input_messages, temperature, top_p, stream, max_gen_len
+            turn_id, input_messages, attachments, sampling_params, stream
         ):
             if isinstance(res, bool):
                 return
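The three hunks above are one refactor: run() and _run() stop taking temperature/top_p/max_gen_len as separate arguments and thread a single SamplingParams object through, alongside the new attachments list. A sketch of the pattern with an illustrative dataclass (not the toolchain's actual SamplingParams):

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class SamplingParams:
        temperature: float = 0.0
        top_p: float = 0.95
        max_tokens: Optional[int] = None


    # before: every new sampling knob widens the signature
    def run_old(temperature: float, top_p: float, max_gen_len: Optional[int]) -> str:
        return f"T={temperature} p={top_p} n={max_gen_len}"


    # after: one object travels through the call chain unchanged
    def run_new(sampling_params: SamplingParams) -> str:
        p = sampling_params
        return f"T={p.temperature} p={p.top_p} n={p.max_tokens}"


    params = SamplingParams(temperature=0.7, max_tokens=512)
    assert run_old(params.temperature, params.top_p, params.max_tokens) == run_new(params)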
@@ -279,20 +275,27 @@ class ChatAgent(ShieldRunnerMixin):
             )
         )

-    async def _should_retrieve_context(self, messages: List[Message]) -> bool: ...
+    async def _should_retrieve_context(
+        self, messages: List[Message], attachments: List[Attachment]
+    ) -> bool:
+        return self.agent_config.memory_configs or len(attachments) > 0
+
+    async def _retrieve_context(
+        self, messages: List[Message], attachments: List[Attachment]
+    ) -> List[Message]:
+        return []

     async def _run(
         self,
         turn_id: str,
         input_messages: List[Message],
-        temperature: float,
-        top_p: float,
+        attachments: List[Attachment],
+        sampling_params: SamplingParams,
         stream: bool = False,
-        max_gen_len: Optional[int] = None,
     ) -> AsyncGenerator:
-        need_context = await self._should_retrieve_context(input_messages)
+        need_context = await self._should_retrieve_context(input_messages, attachments)
         if need_context:
-            context = await self._retrieve_context(input_messages)
+            context_messages = await self._retrieve_context(input_messages)
         # input_messages = preprocess_dialog(input_messages, self.prefix_messages)
         # input_messages = input_messages + context
         input_messages = preprocess_dialog(input_messages)
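_should_retrieve_context is the new memory gate: retrieval fires when the agent has memory banks configured or the turn carries attachments (_retrieve_context itself is still a stub returning []). The gate in isolation, with simplified types:

    from typing import List


    def should_retrieve_context(memory_configs: List[str], attachments: List[str]) -> bool:
        # retrieve iff memory is configured or the turn brought attachments
        return bool(memory_configs) or len(attachments) > 0


    print(should_retrieve_context([], []))            # False
    print(should_retrieve_context(["bank-1"], []))    # True
    print(should_retrieve_context([], ["notes.md"]))  # True

Note the actual method returns the truthy memory_configs list itself rather than a bool when memory is configured; callers only test it in an if, so the looser return value works.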
@@ -320,18 +323,13 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )

+        # where are the available tools?
         req = ChatCompletionRequest(
             model=self.agent_config.model,
             messages=input_messages,
-            tools=self.agent_config.available_tools,
             tool_prompt_format=self.agent_config.tool_prompt_format,
             stream=True,
-            sampling_params=SamplingParams(
-                temperature=temperature,
-                top_p=top_p,
-                max_tokens=max_gen_len,
-            ),
+            sampling_params=sampling_params,
         )

         tool_calls = []
@@ -554,6 +552,7 @@ def attachment_message(url: URL) -> ToolResponseMessage:


 def preprocess_dialog(messages: List[Message]) -> List[Message]:
-    # remove system message since those are
+    """
+    Preprocesses the dialog by removing the system message and
+    adding the system message to the beginning of the dialog.
@@ -565,7 +564,7 @@ def preprocess_dialog(messages: List[Message]) -> List[Message]:
             continue

         # NOTE: the ideal behavior is to use `file_path = ...` but that
-        # means we need to have stateful execution o f code which we currently
+        # means we need to have stateful execution of code which we currently
         # do not have.
         if isinstance(m.content, Attachment):
             ret.append(attachment_message(m.content.url))
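For context on the loop being touched: preprocess_dialog drops the system message and rewrites attachment content into tool-response messages via attachment_message(...). A simplified sketch of that rewriting loop (stand-in types; the real helper builds a ToolResponseMessage from a URL type):

    from dataclasses import dataclass
    from typing import List, Union


    @dataclass
    class Attachment:
        url: str


    @dataclass
    class Message:
        role: str
        content: Union[str, "Attachment"]


    def attachment_message(url: str) -> Message:
        # stand-in for the real helper, which builds a ToolResponseMessage
        return Message(role="ipython", content=f"# fetched file available at: {url}")


    def preprocess_dialog(messages: List[Message]) -> List[Message]:
        ret = []
        for m in messages:
            if m.role == "system":
                continue  # the system prompt is re-added elsewhere
            if isinstance(m.content, Attachment):
                ret.append(attachment_message(m.content.url))
            else:
                ret.append(m)
        return ret


    dialog = [
        Message("system", "be helpful"),
        Message("user", Attachment(url="file://report.csv")),
    ]
    print(preprocess_dialog(dialog))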

@@ -16,9 +16,6 @@ from llama_toolchain.inference.api.datatypes import BuiltinTool
 from llama_toolchain.memory.api import Memory
 from llama_toolchain.safety.api import Safety
 from llama_toolchain.agentic_system.api import *  # noqa: F403
-from .agent_instance import ChatAgent
-from .config import MetaReferenceImplConfig
-
 from llama_toolchain.tools.builtin import (
     BraveSearchTool,
     CodeInterpreterTool,
@@ -27,7 +24,8 @@ from llama_toolchain.tools.builtin import (
 )
 from llama_toolchain.tools.safety import with_safety

-from .agent_instance import AgentInstance
+from .agent_instance import AgentInstance, ChatAgent
+from .config import MetaReferenceImplConfig


 logger = logging.getLogger()
@@ -76,7 +74,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
         system_id = str(uuid.uuid4())

         builtin_tools = []
-        custom_tool_definitions = []
         cfg = request.agent_config
         for dfn in cfg.available_tools:
             if isinstance(dfn.tool_name, BuiltinTool):
@@ -104,8 +101,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
                     tool, self.safety_api, dfn.input_shields, dfn.output_shields
                 )
             )
-            else:
-                custom_tool_definitions.append(dfn)

         AGENT_INSTANCES_BY_ID[system_id] = ChatAgent(
             agent_config=cfg,
@@ -113,7 +108,6 @@ class MetaReferenceAgenticSystemImpl(AgenticSystem):
             safety_api=self.safety_api,
             memory_api=self.memory_api,
             builtin_tools=builtin_tools,
-            custom_tool_definitions=custom_tool_definitions,
         )

         return AgenticSystemCreateResponse(

@@ -6,7 +6,7 @@

 from typing import List, Protocol

-from llama_models.llama3_1.api.datatypes import InterleavedTextAttachment
+from llama_models.llama3.api.datatypes import InterleavedTextMedia

 from llama_models.schema_utils import webmethod
 from .datatypes import *  # noqa: F403
@@ -14,7 +14,7 @@ from .datatypes import *  # noqa: F403

 @json_schema_type
 class RetrieveMemoryDocumentsRequest(BaseModel):
-    query: InterleavedTextAttachment
+    query: InterleavedTextMedia
     bank_ids: str
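This last pair is the rename that titles the commit: memory retrieval queries are now typed as InterleavedTextMedia, imported from llama_models.llama3 (previously llama3_1). A hedged sketch of what a widened query type enables; the Union below is illustrative, not llama_models' exact definition:

    from typing import List, Union

    from pydantic import BaseModel


    class MediaItem(BaseModel):
        url: str
        mime_type: str


    # illustrative stand-in for InterleavedTextMedia: text, media, or a mix
    InterleavedContent = Union[str, MediaItem, List[Union[str, MediaItem]]]


    class RetrieveMemoryDocumentsRequest(BaseModel):
        query: InterleavedContent
        bank_ids: str


    req = RetrieveMemoryDocumentsRequest(
        query=["passages similar to:", MediaItem(url="file://notes.pdf", mime_type="application/pdf")],
        bank_ids="bank-1",
    )
    print(req.query)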