mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
address feedback
This commit is contained in:
parent
ee542a7373
commit
16d1f66f55
9 changed files with 286 additions and 149 deletions
|
@ -3705,10 +3705,10 @@
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tool_names": {
|
"tools": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string"
|
"$ref": "#/components/schemas/AgentTool"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"client_tools": {
|
"client_tools": {
|
||||||
|
@ -3717,12 +3717,6 @@
|
||||||
"$ref": "#/components/schemas/UserDefinedToolDef"
|
"$ref": "#/components/schemas/UserDefinedToolDef"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"preprocessing_tools": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"type": "string"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"tool_choice": {
|
"tool_choice": {
|
||||||
"$ref": "#/components/schemas/ToolChoice",
|
"$ref": "#/components/schemas/ToolChoice",
|
||||||
"default": "auto"
|
"default": "auto"
|
||||||
|
@ -3753,6 +3747,51 @@
|
||||||
"enable_session_persistence"
|
"enable_session_persistence"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"AgentTool": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"args": {
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "array"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"name",
|
||||||
|
"args"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"ToolParameter": {
|
"ToolParameter": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
@ -3934,6 +3973,12 @@
|
||||||
},
|
},
|
||||||
"stream": {
|
"stream": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
|
},
|
||||||
|
"tools": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/AgentTool"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -7944,6 +7989,10 @@
|
||||||
"name": "AgentStepResponse",
|
"name": "AgentStepResponse",
|
||||||
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentStepResponse\" />"
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentStepResponse\" />"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "AgentTool",
|
||||||
|
"description": "<SchemaDefinition schemaRef=\"#/components/schemas/AgentTool\" />"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "AgentTurnResponseEvent",
|
"name": "AgentTurnResponseEvent",
|
||||||
"description": "Streamed agent execution response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseEvent\" />"
|
"description": "Streamed agent execution response.\n\n<SchemaDefinition schemaRef=\"#/components/schemas/AgentTurnResponseEvent\" />"
|
||||||
|
@ -8691,6 +8740,7 @@
|
||||||
"AgentCreateResponse",
|
"AgentCreateResponse",
|
||||||
"AgentSessionCreateResponse",
|
"AgentSessionCreateResponse",
|
||||||
"AgentStepResponse",
|
"AgentStepResponse",
|
||||||
|
"AgentTool",
|
||||||
"AgentTurnResponseEvent",
|
"AgentTurnResponseEvent",
|
||||||
"AgentTurnResponseStepCompletePayload",
|
"AgentTurnResponseStepCompletePayload",
|
||||||
"AgentTurnResponseStepProgressPayload",
|
"AgentTurnResponseStepProgressPayload",
|
||||||
|
|
|
@ -38,22 +38,18 @@ components:
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
type: array
|
type: array
|
||||||
preprocessing_tools:
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
type: array
|
|
||||||
sampling_params:
|
sampling_params:
|
||||||
$ref: '#/components/schemas/SamplingParams'
|
$ref: '#/components/schemas/SamplingParams'
|
||||||
tool_choice:
|
tool_choice:
|
||||||
$ref: '#/components/schemas/ToolChoice'
|
$ref: '#/components/schemas/ToolChoice'
|
||||||
default: auto
|
default: auto
|
||||||
tool_names:
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
type: array
|
|
||||||
tool_prompt_format:
|
tool_prompt_format:
|
||||||
$ref: '#/components/schemas/ToolPromptFormat'
|
$ref: '#/components/schemas/ToolPromptFormat'
|
||||||
default: json
|
default: json
|
||||||
|
tools:
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/AgentTool'
|
||||||
|
type: array
|
||||||
required:
|
required:
|
||||||
- max_infer_iters
|
- max_infer_iters
|
||||||
- model
|
- model
|
||||||
|
@ -88,6 +84,27 @@ components:
|
||||||
required:
|
required:
|
||||||
- step
|
- step
|
||||||
type: object
|
type: object
|
||||||
|
AgentTool:
|
||||||
|
oneOf:
|
||||||
|
- type: string
|
||||||
|
- additionalProperties: false
|
||||||
|
properties:
|
||||||
|
args:
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
type: object
|
||||||
|
name:
|
||||||
|
type: string
|
||||||
|
required:
|
||||||
|
- name
|
||||||
|
- args
|
||||||
|
type: object
|
||||||
AgentTurnResponseEvent:
|
AgentTurnResponseEvent:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
properties:
|
properties:
|
||||||
|
@ -611,6 +628,10 @@ components:
|
||||||
type: string
|
type: string
|
||||||
stream:
|
stream:
|
||||||
type: boolean
|
type: boolean
|
||||||
|
tools:
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/AgentTool'
|
||||||
|
type: array
|
||||||
required:
|
required:
|
||||||
- agent_id
|
- agent_id
|
||||||
- session_id
|
- session_id
|
||||||
|
@ -4726,6 +4747,8 @@ tags:
|
||||||
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
|
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentStepResponse"
|
||||||
/>
|
/>
|
||||||
name: AgentStepResponse
|
name: AgentStepResponse
|
||||||
|
- description: <SchemaDefinition schemaRef="#/components/schemas/AgentTool" />
|
||||||
|
name: AgentTool
|
||||||
- description: 'Streamed agent execution response.
|
- description: 'Streamed agent execution response.
|
||||||
|
|
||||||
|
|
||||||
|
@ -5257,6 +5280,7 @@ x-tagGroups:
|
||||||
- AgentCreateResponse
|
- AgentCreateResponse
|
||||||
- AgentSessionCreateResponse
|
- AgentSessionCreateResponse
|
||||||
- AgentStepResponse
|
- AgentStepResponse
|
||||||
|
- AgentTool
|
||||||
- AgentTurnResponseEvent
|
- AgentTurnResponseEvent
|
||||||
- AgentTurnResponseStepCompletePayload
|
- AgentTurnResponseStepCompletePayload
|
||||||
- AgentTurnResponseStepProgressPayload
|
- AgentTurnResponseStepProgressPayload
|
||||||
|
|
|
@ -18,7 +18,7 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
@ -132,14 +132,27 @@ class Session(BaseModel):
|
||||||
memory_bank: Optional[MemoryBank] = None
|
memory_bank: Optional[MemoryBank] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentToolWithArgs(BaseModel):
|
||||||
|
name: str
|
||||||
|
args: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
AgentTool = register_schema(
|
||||||
|
Union[
|
||||||
|
str,
|
||||||
|
AgentToolWithArgs,
|
||||||
|
],
|
||||||
|
name="AgentTool",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigCommon(BaseModel):
|
class AgentConfigCommon(BaseModel):
|
||||||
sampling_params: Optional[SamplingParams] = SamplingParams()
|
sampling_params: Optional[SamplingParams] = SamplingParams()
|
||||||
|
|
||||||
input_shields: Optional[List[str]] = Field(default_factory=list)
|
input_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
output_shields: Optional[List[str]] = Field(default_factory=list)
|
output_shields: Optional[List[str]] = Field(default_factory=list)
|
||||||
tool_names: Optional[List[str]] = Field(default_factory=list)
|
tools: Optional[List[AgentTool]] = Field(default_factory=list)
|
||||||
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
|
client_tools: Optional[List[UserDefinedToolDef]] = Field(default_factory=list)
|
||||||
preprocessing_tools: Optional[List[str]] = Field(default_factory=list)
|
|
||||||
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto)
|
||||||
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
tool_prompt_format: Optional[ToolPromptFormat] = Field(
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
|
@ -295,6 +308,7 @@ class Agents(Protocol):
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
|
tools: Optional[List[AgentTool]] = None,
|
||||||
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
|
||||||
|
|
||||||
@webmethod(route="/agents/turn/get")
|
@webmethod(route="/agents/turn/get")
|
||||||
|
|
|
@ -13,7 +13,7 @@ import secrets
|
||||||
import string
|
import string
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, Dict, List, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -21,6 +21,8 @@ from llama_models.llama3.api.datatypes import BuiltinTool, ToolCall, ToolParamDe
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
AgentConfig,
|
AgentConfig,
|
||||||
|
AgentTool,
|
||||||
|
AgentToolWithArgs,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
AgentTurnResponseEvent,
|
AgentTurnResponseEvent,
|
||||||
AgentTurnResponseEventType,
|
AgentTurnResponseEventType,
|
||||||
|
@ -188,6 +190,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages=messages,
|
input_messages=messages,
|
||||||
sampling_params=self.agent_config.sampling_params,
|
sampling_params=self.agent_config.sampling_params,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
|
tools_for_turn=request.tools,
|
||||||
):
|
):
|
||||||
if isinstance(chunk, CompletionMessage):
|
if isinstance(chunk, CompletionMessage):
|
||||||
log.info(
|
log.info(
|
||||||
|
@ -237,6 +240,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
# Doing async generators makes downstream code much simpler and everything amenable to
|
# Doing async generators makes downstream code much simpler and everything amenable to
|
||||||
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
|
||||||
|
@ -253,7 +257,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
yield res
|
yield res
|
||||||
|
|
||||||
async for res in self._run(
|
async for res in self._run(
|
||||||
session_id, turn_id, input_messages, sampling_params, stream
|
session_id, turn_id, input_messages, sampling_params, stream, tools_for_turn
|
||||||
):
|
):
|
||||||
if isinstance(res, bool):
|
if isinstance(res, bool):
|
||||||
return
|
return
|
||||||
|
@ -348,82 +352,90 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
input_messages: List[Message],
|
input_messages: List[Message],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
tools_for_turn: Optional[List[AgentTool]] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
if self.agent_config.preprocessing_tools:
|
tool_args = {}
|
||||||
with tracing.span("preprocessing_tools") as span:
|
if tools_for_turn:
|
||||||
for tool_name in self.agent_config.preprocessing_tools:
|
for tool in tools_for_turn:
|
||||||
step_id = str(uuid.uuid4())
|
if isinstance(tool, AgentToolWithArgs):
|
||||||
yield AgentTurnResponseStreamChunk(
|
tool_args[tool.name] = tool.args
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseStepStartPayload(
|
|
||||||
step_type=StepType.tool_execution.value,
|
|
||||||
step_id=step_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
args = dict(
|
|
||||||
session_id=session_id,
|
|
||||||
turn_id=turn_id,
|
|
||||||
input_messages=input_messages,
|
|
||||||
)
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
|
||||||
event=AgentTurnResponseEvent(
|
|
||||||
payload=AgentTurnResponseStepProgressPayload(
|
|
||||||
step_type=StepType.tool_execution.value,
|
|
||||||
step_id=step_id,
|
|
||||||
tool_call_delta=ToolCallDelta(
|
|
||||||
parse_status=ToolCallParseStatus.success,
|
|
||||||
content=ToolCall(
|
|
||||||
call_id="", tool_name=tool_name, arguments={}
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await self.tool_runtime_api.invoke_tool(
|
|
||||||
tool_name=tool_name,
|
|
||||||
args=args,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield AgentTurnResponseStreamChunk(
|
tool_defs = await self._get_tool_defs(tools_for_turn)
|
||||||
event=AgentTurnResponseEvent(
|
if "memory" in tool_defs and len(input_messages) > 0:
|
||||||
payload=AgentTurnResponseStepCompletePayload(
|
with tracing.span("memory_tool") as span:
|
||||||
step_type=StepType.tool_execution.value,
|
step_id = str(uuid.uuid4())
|
||||||
step_id=step_id,
|
yield AgentTurnResponseStreamChunk(
|
||||||
step_details=ToolExecutionStep(
|
event=AgentTurnResponseEvent(
|
||||||
step_id=step_id,
|
payload=AgentTurnResponseStepStartPayload(
|
||||||
turn_id=turn_id,
|
step_type=StepType.tool_execution.value,
|
||||||
tool_calls=[
|
step_id=step_id,
|
||||||
ToolCall(
|
|
||||||
call_id="",
|
|
||||||
tool_name=tool_name,
|
|
||||||
arguments={},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
tool_responses=[
|
|
||||||
ToolResponse(
|
|
||||||
call_id="",
|
|
||||||
tool_name=tool_name,
|
|
||||||
content=result.content,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
span.set_attribute(
|
)
|
||||||
"input", [m.model_dump_json() for m in input_messages]
|
extra_args = tool_args.get("memory", {})
|
||||||
|
args = {
|
||||||
|
# Query memory with the last message's content
|
||||||
|
"query": input_messages[-1],
|
||||||
|
**extra_args,
|
||||||
|
}
|
||||||
|
serialized_args = tracing.serialize_value(args)
|
||||||
|
yield AgentTurnResponseStreamChunk(
|
||||||
|
event=AgentTurnResponseEvent(
|
||||||
|
payload=AgentTurnResponseStepProgressPayload(
|
||||||
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=step_id,
|
||||||
|
tool_call_delta=ToolCallDelta(
|
||||||
|
parse_status=ToolCallParseStatus.success,
|
||||||
|
content=ToolCall(
|
||||||
|
call_id="",
|
||||||
|
tool_name="memory",
|
||||||
|
arguments=serialized_args,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
span.set_attribute("output", result.content)
|
)
|
||||||
span.set_attribute("error_code", result.error_code)
|
result = await self.tool_runtime_api.invoke_tool(
|
||||||
span.set_attribute("error_message", result.error_message)
|
tool_name="memory",
|
||||||
if isinstance(tool_name, BuiltinTool):
|
args=args,
|
||||||
span.set_attribute("tool_name", tool_name.value)
|
)
|
||||||
else:
|
|
||||||
span.set_attribute("tool_name", tool_name)
|
yield AgentTurnResponseStreamChunk(
|
||||||
if result.error_code == 0:
|
event=AgentTurnResponseEvent(
|
||||||
last_message = input_messages[-1]
|
payload=AgentTurnResponseStepCompletePayload(
|
||||||
last_message.context = result.content
|
step_type=StepType.tool_execution.value,
|
||||||
|
step_id=step_id,
|
||||||
|
step_details=ToolExecutionStep(
|
||||||
|
step_id=step_id,
|
||||||
|
turn_id=turn_id,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="",
|
||||||
|
tool_name="memory",
|
||||||
|
arguments={},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
tool_responses=[
|
||||||
|
ToolResponse(
|
||||||
|
call_id="",
|
||||||
|
tool_name="memory",
|
||||||
|
content=result.content,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
span.set_attribute(
|
||||||
|
"input", [m.model_dump_json() for m in input_messages]
|
||||||
|
)
|
||||||
|
span.set_attribute("output", result.content)
|
||||||
|
span.set_attribute("error_code", result.error_code)
|
||||||
|
span.set_attribute("error_message", result.error_message)
|
||||||
|
span.set_attribute("tool_name", "memory")
|
||||||
|
if result.error_code == 0:
|
||||||
|
last_message = input_messages[-1]
|
||||||
|
last_message.context = result.content
|
||||||
|
|
||||||
output_attachments = []
|
output_attachments = []
|
||||||
|
|
||||||
|
@ -451,7 +463,11 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
async for chunk in await self.inference_api.chat_completion(
|
async for chunk in await self.inference_api.chat_completion(
|
||||||
self.agent_config.model,
|
self.agent_config.model,
|
||||||
input_messages,
|
input_messages,
|
||||||
tools=await self._get_tools(),
|
tools=[
|
||||||
|
tool
|
||||||
|
for tool in tool_defs.values()
|
||||||
|
if tool.tool_name != "memory"
|
||||||
|
],
|
||||||
tool_prompt_format=self.agent_config.tool_prompt_format,
|
tool_prompt_format=self.agent_config.tool_prompt_format,
|
||||||
stream=True,
|
stream=True,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
@ -654,44 +670,66 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
||||||
n_iter += 1
|
n_iter += 1
|
||||||
|
|
||||||
async def _get_tools(self) -> List[ToolDefinition]:
|
async def _get_tool_defs(
|
||||||
ret = []
|
self, tools_for_turn: Optional[List[AgentTool]]
|
||||||
for tool in self.agent_config.client_tools:
|
) -> Dict[str, ToolDefinition]:
|
||||||
params = {}
|
# Determine which tools to include
|
||||||
for param in tool.parameters:
|
agent_config_tools = set(
|
||||||
params[param.name] = ToolParamDefinition(
|
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||||
param_type=param.parameter_type,
|
for tool in self.agent_config.tools
|
||||||
description=param.description,
|
)
|
||||||
required=param.required,
|
tools_for_turn_set = (
|
||||||
default=param.default,
|
agent_config_tools
|
||||||
)
|
if tools_for_turn is None
|
||||||
ret.append(
|
else {
|
||||||
ToolDefinition(
|
tool.name if isinstance(tool, AgentToolWithArgs) else tool
|
||||||
tool_name=tool.name,
|
for tool in tools_for_turn
|
||||||
description=tool.description,
|
}
|
||||||
parameters=params,
|
)
|
||||||
)
|
|
||||||
|
ret = {}
|
||||||
|
|
||||||
|
for tool_def in self.agent_config.client_tools:
|
||||||
|
ret[tool_def.name] = ToolDefinition(
|
||||||
|
tool_name=tool_def.name,
|
||||||
|
description=tool_def.description,
|
||||||
|
parameters={
|
||||||
|
param.name: ToolParamDefinition(
|
||||||
|
param_type=param.parameter_type,
|
||||||
|
description=param.description,
|
||||||
|
required=param.required,
|
||||||
|
default=param.default,
|
||||||
|
)
|
||||||
|
for param in tool_def.parameters
|
||||||
|
},
|
||||||
)
|
)
|
||||||
for tool_name in self.agent_config.tool_names:
|
|
||||||
tool = await self.tool_groups_api.get_tool(tool_name)
|
for tool_name in agent_config_tools:
|
||||||
if tool.built_in_type:
|
if tool_name not in tools_for_turn_set:
|
||||||
ret.append(ToolDefinition(tool_name=tool.built_in_type))
|
|
||||||
continue
|
continue
|
||||||
params = {}
|
|
||||||
for param in tool.parameters:
|
tool_def = await self.tool_groups_api.get_tool(tool_name)
|
||||||
params[param.name] = ToolParamDefinition(
|
|
||||||
param_type=param.parameter_type,
|
if tool_def.built_in_type:
|
||||||
description=param.description,
|
ret[tool_def.built_in_type] = ToolDefinition(
|
||||||
required=param.required,
|
tool_name=tool_def.built_in_type
|
||||||
default=param.default,
|
|
||||||
)
|
|
||||||
ret.append(
|
|
||||||
ToolDefinition(
|
|
||||||
tool_name=tool.identifier,
|
|
||||||
description=tool.description,
|
|
||||||
parameters=params,
|
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
ret[tool_def.identifier] = ToolDefinition(
|
||||||
|
tool_name=tool_def.identifier,
|
||||||
|
description=tool_def.description,
|
||||||
|
parameters={
|
||||||
|
param.name: ToolParamDefinition(
|
||||||
|
param_type=param.parameter_type,
|
||||||
|
description=param.description,
|
||||||
|
required=param.required,
|
||||||
|
default=param.default,
|
||||||
|
)
|
||||||
|
for param in tool_def.parameters
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.agents import (
|
||||||
Agents,
|
Agents,
|
||||||
AgentSessionCreateResponse,
|
AgentSessionCreateResponse,
|
||||||
AgentStepResponse,
|
AgentStepResponse,
|
||||||
|
AgentTool,
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
Session,
|
Session,
|
||||||
Turn,
|
Turn,
|
||||||
|
@ -145,6 +146,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
]
|
]
|
||||||
],
|
],
|
||||||
|
tools: Optional[List[AgentTool]] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
|
@ -152,6 +154,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
tools=tools,
|
||||||
)
|
)
|
||||||
if stream:
|
if stream:
|
||||||
return self._create_agent_turn_streaming(request)
|
return self._create_agent_turn_streaming(request)
|
||||||
|
|
|
@ -54,14 +54,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def _retrieve_context(
|
async def _retrieve_context(
|
||||||
self, messages: List[Message], bank_ids: List[str]
|
self, message: Message, bank_ids: List[str]
|
||||||
) -> Optional[List[InterleavedContent]]:
|
) -> Optional[List[InterleavedContent]]:
|
||||||
if not bank_ids:
|
if not bank_ids:
|
||||||
return None
|
return None
|
||||||
if len(messages) == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
message = messages[-1] # only use the last message as input to the query
|
|
||||||
query = await generate_rag_query(
|
query = await generate_rag_query(
|
||||||
self.config.query_generator_config,
|
self.config.query_generator_config,
|
||||||
message,
|
message,
|
||||||
|
@ -113,10 +109,15 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
|
||||||
config = MemoryToolConfig()
|
config = MemoryToolConfig()
|
||||||
if tool.metadata.get("config") is not None:
|
if tool.metadata.get("config") is not None:
|
||||||
config = MemoryToolConfig(**tool.metadata["config"])
|
config = MemoryToolConfig(**tool.metadata["config"])
|
||||||
|
if "memory_bank_id" in args:
|
||||||
|
bank_ids = [args["memory_bank_id"]]
|
||||||
|
else:
|
||||||
|
bank_ids = [
|
||||||
|
bank_config.bank_id for bank_config in config.memory_bank_configs
|
||||||
|
]
|
||||||
context = await self._retrieve_context(
|
context = await self._retrieve_context(
|
||||||
args["input_messages"],
|
args["query"],
|
||||||
[bank_config.bank_id for bank_config in config.memory_bank_configs],
|
bank_ids,
|
||||||
)
|
)
|
||||||
if context is None:
|
if context is None:
|
||||||
context = []
|
context = []
|
||||||
|
|
|
@ -7,11 +7,8 @@
|
||||||
from typing import AsyncGenerator, List, Optional, Union
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
from llama_models.datatypes import CoreModelId
|
from llama_models.datatypes import CoreModelId
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from together import Together
|
from together import Together
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import InterleavedContent
|
from llama_stack.apis.common.content_types import InterleavedContent
|
||||||
|
@ -53,7 +50,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import TogetherImplConfig
|
from .config import TogetherImplConfig
|
||||||
|
|
||||||
|
|
||||||
MODEL_ALIASES = [
|
MODEL_ALIASES = [
|
||||||
build_model_alias(
|
build_model_alias(
|
||||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||||
|
|
|
@ -104,7 +104,7 @@ async def create_agent_turn_with_search_tool(
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
**{
|
**{
|
||||||
**common_params,
|
**common_params,
|
||||||
"tool_names": [tool_name],
|
"tools": [tool_name],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -268,7 +268,7 @@ class TestAgents:
|
||||||
agent_config = AgentConfig(
|
agent_config = AgentConfig(
|
||||||
**{
|
**{
|
||||||
**common_params,
|
**common_params,
|
||||||
"preprocessing_tools": ["memory"],
|
"tools": ["memory"],
|
||||||
"tool_choice": ToolChoice.auto,
|
"tool_choice": ToolChoice.auto,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from typing import Dict, List
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack_client.lib.agents.agent import Agent
|
from llama_stack_client.lib.agents.agent import Agent, AugmentConfigWithMemoryTool
|
||||||
from llama_stack_client.lib.agents.client_tool import ClientTool
|
from llama_stack_client.lib.agents.client_tool import ClientTool
|
||||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||||
from llama_stack_client.types import ToolResponseMessage
|
from llama_stack_client.types import ToolResponseMessage
|
||||||
|
@ -151,11 +151,10 @@ def test_agent_simple(llama_stack_client, agent_config):
|
||||||
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"tool_names": [
|
"tools": [
|
||||||
"brave_search",
|
"brave_search",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
print(f"Agent Config: {agent_config}")
|
|
||||||
agent = Agent(llama_stack_client, agent_config)
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
|
@ -181,7 +180,7 @@ def test_builtin_tool_brave_search(llama_stack_client, agent_config):
|
||||||
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
def test_builtin_tool_code_execution(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"tool_names": [
|
"tools": [
|
||||||
"code_interpreter",
|
"code_interpreter",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
@ -209,7 +208,7 @@ def test_custom_tool(llama_stack_client, agent_config):
|
||||||
agent_config = {
|
agent_config = {
|
||||||
**agent_config,
|
**agent_config,
|
||||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
"tool_names": ["brave_search"],
|
"tools": ["brave_search"],
|
||||||
"client_tools": [client_tool.get_tool_definition()],
|
"client_tools": [client_tool.get_tool_definition()],
|
||||||
"tool_prompt_format": "python_list",
|
"tool_prompt_format": "python_list",
|
||||||
}
|
}
|
||||||
|
@ -252,8 +251,12 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
for i, url in enumerate(urls)
|
for i, url in enumerate(urls)
|
||||||
]
|
]
|
||||||
|
|
||||||
agent = Agent.with_memory(llama_stack_client, agent_config)
|
memory_bank_id = AugmentConfigWithMemoryTool(agent_config, llama_stack_client)
|
||||||
[agent.add_document(document) for document in documents]
|
agent = Agent(llama_stack_client, agent_config)
|
||||||
|
llama_stack_client.memory.insert(
|
||||||
|
bank_id=memory_bank_id,
|
||||||
|
documents=documents,
|
||||||
|
)
|
||||||
session_id = agent.create_session(f"test-session-{uuid4()}")
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
user_prompts = [
|
user_prompts = [
|
||||||
|
@ -271,8 +274,16 @@ def test_rag_agent(llama_stack_client, agent_config):
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
|
tools=[
|
||||||
|
{
|
||||||
|
"name": "memory",
|
||||||
|
"args": {
|
||||||
|
"memory_bank_id": memory_bank_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
logs = [str(log) for log in EventLogger().log(response) if log is not None]
|
||||||
logs_str = "".join(logs)
|
logs_str = "".join(logs)
|
||||||
assert "Tool:memory-tool" in logs_str
|
assert "Tool:memory" in logs_str
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue