mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
change context to content
This commit is contained in:
parent
d8a8a734b5
commit
e84d3966c4
1 changed file with 101 additions and 37 deletions
|
@ -38,10 +38,10 @@ from llama_stack.apis.agents import (
|
||||||
Turn,
|
Turn,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
URL,
|
|
||||||
TextContentItem,
|
TextContentItem,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
|
URL,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionResponseEventType,
|
ChatCompletionResponseEventType,
|
||||||
|
@ -57,11 +57,7 @@ from llama_stack.apis.inference import (
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
|
||||||
ToolGroups,
|
|
||||||
ToolInvocationResult,
|
|
||||||
ToolRuntime,
|
|
||||||
)
|
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.models.llama.datatypes import (
|
||||||
|
@ -77,7 +73,9 @@ from .safety import SafetyException, ShieldRunnerMixin
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8):
|
||||||
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
|
return "".join(
|
||||||
|
secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
|
||||||
|
@ -176,7 +174,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
messages.extend(self.turn_to_messages(turn))
|
messages.extend(self.turn_to_messages(turn))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator:
|
async def create_and_execute_turn(
|
||||||
|
self, request: AgentTurnCreateRequest
|
||||||
|
) -> AsyncGenerator:
|
||||||
span = tracing.get_current_span()
|
span = tracing.get_current_span()
|
||||||
if span:
|
if span:
|
||||||
span.set_attribute("session_id", request.session_id)
|
span.set_attribute("session_id", request.session_id)
|
||||||
|
@ -221,13 +221,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
messages = await self.get_messages_from_turns(turns)
|
messages = await self.get_messages_from_turns(turns)
|
||||||
if is_resume:
|
if is_resume:
|
||||||
tool_response_messages = [
|
tool_response_messages = [
|
||||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
ToolResponseMessage(call_id=x.call_id, content=x.content)
|
||||||
|
for x in request.tool_responses
|
||||||
]
|
]
|
||||||
messages.extend(tool_response_messages)
|
messages.extend(tool_response_messages)
|
||||||
last_turn = turns[-1]
|
last_turn = turns[-1]
|
||||||
last_turn_messages = self.turn_to_messages(last_turn)
|
last_turn_messages = self.turn_to_messages(last_turn)
|
||||||
last_turn_messages = [
|
last_turn_messages = [
|
||||||
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
x
|
||||||
|
for x in last_turn_messages
|
||||||
|
if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
|
||||||
]
|
]
|
||||||
last_turn_messages.extend(tool_response_messages)
|
last_turn_messages.extend(tool_response_messages)
|
||||||
|
|
||||||
|
@ -237,17 +240,31 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# mark tool execution step as complete
|
# mark tool execution step as complete
|
||||||
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
|
||||||
# we'll create a new tool execution step with current time
|
# we'll create a new tool execution step with current time
|
||||||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
in_progress_tool_call_step = (
|
||||||
request.session_id, request.turn_id
|
await self.storage.get_in_progress_tool_call_step(
|
||||||
|
request.session_id, request.turn_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
now = datetime.now(timezone.utc).isoformat()
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
tool_execution_step = ToolExecutionStep(
|
tool_execution_step = ToolExecutionStep(
|
||||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
step_id=(
|
||||||
|
in_progress_tool_call_step.step_id
|
||||||
|
if in_progress_tool_call_step
|
||||||
|
else str(uuid.uuid4())
|
||||||
|
),
|
||||||
turn_id=request.turn_id,
|
turn_id=request.turn_id,
|
||||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
tool_calls=(
|
||||||
|
in_progress_tool_call_step.tool_calls
|
||||||
|
if in_progress_tool_call_step
|
||||||
|
else []
|
||||||
|
),
|
||||||
tool_responses=request.tool_responses,
|
tool_responses=request.tool_responses,
|
||||||
completed_at=now,
|
completed_at=now,
|
||||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
started_at=(
|
||||||
|
in_progress_tool_call_step.started_at
|
||||||
|
if in_progress_tool_call_step
|
||||||
|
else now
|
||||||
|
),
|
||||||
)
|
)
|
||||||
steps.append(tool_execution_step)
|
steps.append(tool_execution_step)
|
||||||
yield AgentTurnResponseStreamChunk(
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
@ -281,9 +298,14 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
output_message = chunk
|
output_message = chunk
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
assert isinstance(
|
||||||
|
chunk, AgentTurnResponseStreamChunk
|
||||||
|
), f"Unexpected type {type(chunk)}"
|
||||||
event = chunk.event
|
event = chunk.event
|
||||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
if (
|
||||||
|
event.payload.event_type
|
||||||
|
== AgentTurnResponseEventType.step_complete.value
|
||||||
|
):
|
||||||
steps.append(event.payload.step_details)
|
steps.append(event.payload.step_details)
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
@ -459,7 +481,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
contexts.append(raw_document_text)
|
contexts.append(raw_document_text)
|
||||||
|
|
||||||
attached_context = "\n".join(contexts)
|
attached_context = "\n".join(contexts)
|
||||||
input_messages[-1].context = attached_context
|
input_messages[-1].content += attached_context
|
||||||
|
|
||||||
session_info = await self.storage.get_session_info(session_id)
|
session_info = await self.storage.get_session_info(session_id)
|
||||||
# if the session has a memory bank id, let the memory tool use it
|
# if the session has a memory bank id, let the memory tool use it
|
||||||
|
@ -467,13 +489,19 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for tool_name in self.tool_name_to_args.keys():
|
for tool_name in self.tool_name_to_args.keys():
|
||||||
if tool_name == MEMORY_QUERY_TOOL:
|
if tool_name == MEMORY_QUERY_TOOL:
|
||||||
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
if "vector_db_ids" not in self.tool_name_to_args[tool_name]:
|
||||||
self.tool_name_to_args[tool_name]["vector_db_ids"] = [session_info.vector_db_id]
|
self.tool_name_to_args[tool_name]["vector_db_ids"] = [
|
||||||
|
session_info.vector_db_id
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
self.tool_name_to_args[tool_name]["vector_db_ids"].append(session_info.vector_db_id)
|
self.tool_name_to_args[tool_name]["vector_db_ids"].append(
|
||||||
|
session_info.vector_db_id
|
||||||
|
)
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
n_iter = (
|
||||||
|
await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||||
|
)
|
||||||
|
|
||||||
# Build a map of custom tools to their definitions for faster lookup
|
# Build a map of custom tools to their definitions for faster lookup
|
||||||
client_tools = {}
|
client_tools = {}
|
||||||
|
@ -551,12 +579,16 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
span.set_attribute("stop_reason", stop_reason)
|
span.set_attribute("stop_reason", stop_reason)
|
||||||
span.set_attribute(
|
span.set_attribute(
|
||||||
"input",
|
"input",
|
||||||
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
|
json.dumps(
|
||||||
|
[json.loads(m.model_dump_json()) for m in input_messages]
|
||||||
|
),
|
||||||
)
|
)
|
||||||
output_attr = json.dumps(
|
output_attr = json.dumps(
|
||||||
{
|
{
|
||||||
"content": content,
|
"content": content,
|
||||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
"tool_calls": [
|
||||||
|
json.loads(t.model_dump_json()) for t in tool_calls
|
||||||
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
span.set_attribute("output", output_attr)
|
span.set_attribute("output", output_attr)
|
||||||
|
@ -620,7 +652,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
message.content = [message.content] + output_attachments
|
message.content = [message.content] + output_attachments
|
||||||
yield message
|
yield message
|
||||||
else:
|
else:
|
||||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
logger.debug(
|
||||||
|
f"completion message with EOM (iter: {n_iter}): {str(message)}"
|
||||||
|
)
|
||||||
input_messages = input_messages + [message]
|
input_messages = input_messages + [message]
|
||||||
else:
|
else:
|
||||||
input_messages = input_messages + [message]
|
input_messages = input_messages + [message]
|
||||||
|
@ -669,7 +703,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
"input": message.model_dump_json(),
|
"input": message.model_dump_json(),
|
||||||
},
|
},
|
||||||
) as span:
|
) as span:
|
||||||
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
|
tool_execution_start_time = datetime.now(
|
||||||
|
timezone.utc
|
||||||
|
).isoformat()
|
||||||
tool_result = await self.execute_tool_call_maybe(
|
tool_result = await self.execute_tool_call_maybe(
|
||||||
session_id,
|
session_id,
|
||||||
tool_call,
|
tool_call,
|
||||||
|
@ -718,7 +754,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# TODO: add tool-input touchpoint and a "start" event for this step also
|
# TODO: add tool-input touchpoint and a "start" event for this step also
|
||||||
# but that needs a lot more refactoring of Tool code potentially
|
# but that needs a lot more refactoring of Tool code potentially
|
||||||
if (type(result_message.content) is str) and (
|
if (type(result_message.content) is str) and (
|
||||||
out_attachment := _interpret_content_as_attachment(result_message.content)
|
out_attachment := _interpret_content_as_attachment(
|
||||||
|
result_message.content
|
||||||
|
)
|
||||||
):
|
):
|
||||||
# NOTE: when we push this message back to the model, the model may ignore the
|
# NOTE: when we push this message back to the model, the model may ignore the
|
||||||
# attached file path etc. since the model is trained to only provide a user message
|
# attached file path etc. since the model is trained to only provide a user message
|
||||||
|
@ -755,16 +793,24 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
toolgroup_to_args = {}
|
toolgroup_to_args = {}
|
||||||
for toolgroup in (self.agent_config.toolgroups or []) + (toolgroups_for_turn or []):
|
for toolgroup in (self.agent_config.toolgroups or []) + (
|
||||||
|
toolgroups_for_turn or []
|
||||||
|
):
|
||||||
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
if isinstance(toolgroup, AgentToolGroupWithArgs):
|
||||||
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
tool_group_name, _ = self._parse_toolgroup_name(toolgroup.name)
|
||||||
toolgroup_to_args[tool_group_name] = toolgroup.args
|
toolgroup_to_args[tool_group_name] = toolgroup.args
|
||||||
|
|
||||||
# Determine which tools to include
|
# Determine which tools to include
|
||||||
tool_groups_to_include = toolgroups_for_turn or self.agent_config.toolgroups or []
|
tool_groups_to_include = (
|
||||||
|
toolgroups_for_turn or self.agent_config.toolgroups or []
|
||||||
|
)
|
||||||
agent_config_toolgroups = []
|
agent_config_toolgroups = []
|
||||||
for toolgroup in tool_groups_to_include:
|
for toolgroup in tool_groups_to_include:
|
||||||
name = toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
|
name = (
|
||||||
|
toolgroup.name
|
||||||
|
if isinstance(toolgroup, AgentToolGroupWithArgs)
|
||||||
|
else toolgroup
|
||||||
|
)
|
||||||
if name not in agent_config_toolgroups:
|
if name not in agent_config_toolgroups:
|
||||||
agent_config_toolgroups.append(name)
|
agent_config_toolgroups.append(name)
|
||||||
|
|
||||||
|
@ -790,20 +836,32 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
|
||||||
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
|
toolgroup_name, input_tool_name = self._parse_toolgroup_name(
|
||||||
|
toolgroup_name_with_maybe_tool_name
|
||||||
|
)
|
||||||
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
|
||||||
if not tools.data:
|
if not tools.data:
|
||||||
available_tool_groups = ", ".join(
|
available_tool_groups = ", ".join(
|
||||||
[t.identifier for t in (await self.tool_groups_api.list_tool_groups()).data]
|
[
|
||||||
|
t.identifier
|
||||||
|
for t in (await self.tool_groups_api.list_tool_groups()).data
|
||||||
|
]
|
||||||
)
|
)
|
||||||
raise ValueError(f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}")
|
raise ValueError(
|
||||||
if input_tool_name is not None and not any(tool.identifier == input_tool_name for tool in tools.data):
|
f"Toolgroup {toolgroup_name} not found, available toolgroups: {available_tool_groups}"
|
||||||
|
)
|
||||||
|
if input_tool_name is not None and not any(
|
||||||
|
tool.identifier == input_tool_name for tool in tools.data
|
||||||
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
f"Tool {input_tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}"
|
||||||
)
|
)
|
||||||
|
|
||||||
for tool_def in tools.data:
|
for tool_def in tools.data:
|
||||||
if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP:
|
if (
|
||||||
|
toolgroup_name.startswith("builtin")
|
||||||
|
and toolgroup_name != RAG_TOOL_GROUP
|
||||||
|
):
|
||||||
identifier: str | BuiltinTool | None = tool_def.identifier
|
identifier: str | BuiltinTool | None = tool_def.identifier
|
||||||
if identifier == "web_search":
|
if identifier == "web_search":
|
||||||
identifier = BuiltinTool.brave_search
|
identifier = BuiltinTool.brave_search
|
||||||
|
@ -832,14 +890,18 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
for param in tool_def.parameters
|
for param in tool_def.parameters
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
tool_name_to_args[tool_def.identifier] = toolgroup_to_args.get(
|
||||||
|
toolgroup_name, {}
|
||||||
|
)
|
||||||
|
|
||||||
self.tool_defs, self.tool_name_to_args = (
|
self.tool_defs, self.tool_name_to_args = (
|
||||||
list(tool_name_to_def.values()),
|
list(tool_name_to_def.values()),
|
||||||
tool_name_to_args,
|
tool_name_to_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]:
|
def _parse_toolgroup_name(
|
||||||
|
self, toolgroup_name_with_maybe_tool_name: str
|
||||||
|
) -> tuple[str, Optional[str]]:
|
||||||
"""Parse a toolgroup name into its components.
|
"""Parse a toolgroup name into its components.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -875,7 +937,9 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
else:
|
else:
|
||||||
tool_name_str = tool_name
|
tool_name_str = tool_name
|
||||||
|
|
||||||
logger.info(f"executing tool call: {tool_name_str} with args: {tool_call.arguments}")
|
logger.info(
|
||||||
|
f"executing tool call: {tool_name_str} with args: {tool_call.arguments}"
|
||||||
|
)
|
||||||
result = await self.tool_runtime_api.invoke_tool(
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
tool_name=tool_name_str,
|
tool_name=tool_name_str,
|
||||||
kwargs={
|
kwargs={
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue