mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-23 00:12:24 +00:00
address feedback
This commit is contained in:
parent
ee542a7373
commit
16d1f66f55
9 changed files with 286 additions and 149 deletions
|
|
@ -13,7 +13,7 @@ import secrets
|
|||
import string
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
|
@ -21,6 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
|
|||
|
||||
from llama_stack.apis.agents import (
|
||||
AgentConfig,
|
||||
AgentTool,
|
||||
AgentToolWithArgs,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResponseEvent,
|
||||
AgentTurnResponseEventType,
|
||||
|
|
@ -188,6 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
tools_for_turn=request.tools,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
|
|
@ -237,6 +240,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
) -> AsyncGenerator:
|
||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||
|
|
@ -253,7 +257,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session_id, turn_id, input_messages, sampling_params, stream
|
||||
session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -348,82 +352,90 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
input_messages: List[Message],
|
||||
sampling_params: SamplingParams,
|
||||
stream: bool = False,
|
||||
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||
) -> AsyncGenerator:
|
||||
if self.agent_config.preprocessing_tools:
|
||||
with tracing.span("preprocessing_tools") as span:
|
||||
for tool_name in self.agent_config.preprocessing_tools:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
args = dict(
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=input_messages,
|
||||
)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call_delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
content=ToolCall(
|
||||
call_id="", tool_name=tool_name, arguments={}
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
)
|
||||
tool_args = {}
|
||||
if tools_for_turn:
|
||||
for tool in tools_for_turn:
|
||||
if isinstance(tool, AgentToolWithArgs):
|
||||
tool_args[tool.name] = tool.args
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="",
|
||||
tool_name=tool_name,
|
||||
arguments={},
|
||||
)
|
||||
],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id="",
|
||||
tool_name=tool_name,
|
||||
content=result.content,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
tool_defs = await self._get_tool_defs(tools_for_turn)
|
||||
if "memory" in tool_defs and len(input_messages) > 0:
|
||||
with tracing.span("memory_tool") as span:
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
extra_args = tool_args.get("memory", {})
|
||||
args = {
|
||||
# Query memory with the last message's content
|
||||
"query": input_messages[-1],
|
||||
**extra_args,
|
||||
}
|
||||
serialized_args = tracing.serialize_value(args)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
tool_call_delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.success,
|
||||
content=ToolCall(
|
||||
call_id="",
|
||||
tool_name="memory",
|
||||
arguments=serialized_args,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", result.content)
|
||||
span.set_attribute("error_code", result.error_code)
|
||||
span.set_attribute("error_message", result.error_message)
|
||||
if isinstance(tool_name, BuiltinTool):
|
||||
span.set_attribute("tool_name", tool_name.value)
|
||||
else:
|
||||
span.set_attribute("tool_name", tool_name)
|
||||
if result.error_code == 0:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = result.content
|
||||
)
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
tool_name="memory",
|
||||
args=args,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_id=step_id,
|
||||
step_details=ToolExecutionStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="",
|
||||
tool_name="memory",
|
||||
arguments={},
|
||||
)
|
||||
],
|
||||
tool_responses=[
|
||||
ToolResponse(
|
||||
call_id="",
|
||||
tool_name="memory",
|
||||
content=result.content,
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute("output", result.content)
|
||||
span.set_attribute("error_code", result.error_code)
|
||||
span.set_attribute("error_message", result.error_message)
|
||||
span.set_attribute("tool_name", "memory")
|
||||
if result.error_code == 0:
|
||||
last_message = input_messages[-1]
|
||||
last_message.context = result.content
|
||||
|
||||
output_attachments = []
|
||||
|
||||
|
|
@ -451,7 +463,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
tools=await self._get_tools(),
|
||||
tools=[
|
||||
tool
|
||||
for tool in tool_defs.values()
|
||||
if tool.tool_name != "memory"
|
||||
],
|
||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||
stream=True,
|
||||
sampling_params=sampling_params,
|
||||
|
|
@ -654,44 +670,66 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
n_iter += 1
|
||||
|
||||
async def _get_tools(self) -> List[ToolDefinition]:
|
||||
ret = []
|
||||
for tool in self.agent_config.client_tools:
|
||||
params = {}
|
||||
for param in tool.parameters:
|
||||
params[param.name] = ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
ret.append(
|
||||
ToolDefinition(
|
||||
tool_name=tool.name,
|
||||
description=tool.description,
|
||||
parameters=params,
|
||||
)
|
||||
async def _get_tool_defs(
|
||||
self, tools_for_turn: Optional[List[AgentTool]]
|
||||
) -> Dict[str, ToolDefinition]:
|
||||
# Determine which tools to include
|
||||
agent_config_tools = set(
|
||||
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||
for tool in self.agent_config.tools
|
||||
)
|
||||
tools_for_turn_set = (
|
||||
agent_config_tools
|
||||
if tools_for_turn is None
|
||||
else {
|
||||
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||
for tool in tools_for_turn
|
||||
}
|
||||
)
|
||||
|
||||
ret = {}
|
||||
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
ret[tool_def.name] = ToolDefinition(
|
||||
tool_name=tool_def.name,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
for tool_name in self.agent_config.tool_names:
|
||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
||||
if tool.built_in_type:
|
||||
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
||||
|
||||
for tool_name in agent_config_tools:
|
||||
if tool_name not in tools_for_turn_set:
|
||||
continue
|
||||
params = {}
|
||||
for param in tool.parameters:
|
||||
params[param.name] = ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
ret.append(
|
||||
ToolDefinition(
|
||||
tool_name=tool.identifier,
|
||||
description=tool.description,
|
||||
parameters=params,
|
||||
|
||||
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
||||
|
||||
if tool_def.built_in_type:
|
||||
ret[tool_def.built_in_type] = ToolDefinition(
|
||||
tool_name=tool_def.built_in_type
|
||||
)
|
||||
continue
|
||||
|
||||
ret[tool_def.identifier] = ToolDefinition(
|
||||
tool_name=tool_def.identifier,
|
||||
description=tool_def.description,
|
||||
parameters={
|
||||
param.name: ToolParamDefinition(
|
||||
param_type=param.parameter_type,
|
||||
description=param.description,
|
||||
required=param.required,
|
||||
default=param.default,
|
||||
)
|
||||
for param in tool_def.parameters
|
||||
},
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.agents import (
|
|||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentTool,
|
||||
AgentTurnCreateRequest,
|
||||
Session,
|
||||
Turn,
|
||||
|
|
@ -145,6 +146,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
ToolResponseMessage,
|
||||
]
|
||||
],
|
||||
tools: Optional[List[AgentTool]] = None,
|
||||
stream: Optional[bool] = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
|
|
@ -152,6 +154,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
session_id=session_id,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
tools=tools,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
|
|
|
|||
|
|
@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
return []
|
||||
|
||||
async def _retrieve_context(
|
||||
self, messages: List[Message], bank_ids: List[str]
|
||||
self, message: Message, bank_ids: List[str]
|
||||
) -> Optional[List[InterleavedContent]]:
|
||||
if not bank_ids:
|
||||
return None
|
||||
if len(messages) == 0:
|
||||
return None
|
||||
|
||||
message = messages[-1] # only use the last message as input to the query
|
||||
query = await generate_rag_query(
|
||||
self.config.query_generator_config,
|
||||
message,
|
||||
|
|
@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
|||
config = MemoryToolConfig()
|
||||
if tool.metadata.get("config") is not None:
|
||||
config = MemoryToolConfig(**tool.metadata["config"])
|
||||
|
||||
if "memory_bank_id" in args:
|
||||
bank_ids = [args["memory_bank_id"]]
|
||||
else:
|
||||
bank_ids = [
|
||||
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||
]
|
||||
context = await self._retrieve_context(
|
||||
args["input_messages"],
|
||||
[bank_config.bank_id for bank_config in config.memory_bank_configs],
|
||||
args["query"],
|
||||
bank_ids,
|
||||
)
|
||||
if context is None:
|
||||
context = []
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue