forked from phoenix-oss/llama-stack-mirror
feat: implementation for agent/session list and describe (#1606)
Create a new agent: ``` curl --request POST \ --url http://localhost:8321/v1/agents \ --header 'Accept: application/json' \ --header 'Content-Type: application/json' \ --data '{ "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } } }' ``` Get agent: ``` curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f {"agent_id":"9abad4ab-2c77-45f9-9d16-46b79d2bea1f","agent_config":{"sampling_params":{"strategy":{"type":"greedy"},"max_tokens":0,"repetition_penalty":1.0},"input_shields":["string"],"output_shields":["string"],"toolgroups":["string"],"client_tools":[{"name":"string","description":"string","parameters":[{"name":"string","parameter_type":"string","description":"string","required":true,"default":null}],"metadata":{"property1":null,"property2":null}}],"tool_choice":"auto","tool_prompt_format":"json","tool_config":{"tool_choice":"auto","tool_prompt_format":"json","system_message_behavior":"append"},"max_infer_iters":10,"model":"string","instructions":"string","enable_session_persistence":false,"response_format":{"type":"json_schema","json_schema":{"property1":null,"property2":null}}},"created_at":"2025-03-12T16:18:28.369144Z"}% ``` List agents: ``` curl http://127.0.0.1:8321/v1/agents|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 1680 100 1680 0 0 498k 0 --:--:-- --:--:-- --:--:-- 546k { "data": [ { "agent_id": "9abad4ab-2c77-45f9-9d16-46b79d2bea1f", "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } }, "created_at": "2025-03-12T16:18:28.369144Z" }, { "agent_id": "a6643aaa-96dd-46db-a405-333dc504b168", "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } }, "created_at": "2025-03-12T16:17:12.811273Z" } ] } ``` Create sessions: ``` curl --request POST \ --url http://localhost:8321/v1/agents/{agent_id}/session \ --header 'Accept: application/json' \ --header 'Content-Type: application/json' \ --data '{ "session_name": "string" }' ``` List sessions: ``` curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f/sessions|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 263 100 263 0 0 90099 0 --:--:-- --:--:-- --:--:-- 128k [ { "session_id": "2b15c4fc-e348-46c1-ae32-f6d424441ac1", "session_name": "string", "turns": [], "started_at": "2025-03-12T17:19:17.784328" }, { "session_id": "9432472d-d483-4b73-b682-7b1d35d64111", "session_name": "string", "turns": [], "started_at": "2025-03-12T17:19:19.885834" } ] ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
40e71758d9
commit
c91e3552a3
19 changed files with 510 additions and 130 deletions
83
docs/_static/llama-stack-spec.html
vendored
83
docs/_static/llama-stack-spec.html
vendored
|
@ -307,11 +307,11 @@
|
||||||
"get": {
|
"get": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "A ListAgentsResponse.",
|
"description": "A PaginatedResponse.",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/components/schemas/ListAgentsResponse"
|
"$ref": "#/components/schemas/PaginatedResponse"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -333,7 +333,26 @@
|
||||||
"Agents"
|
"Agents"
|
||||||
],
|
],
|
||||||
"description": "List all agents.",
|
"description": "List all agents.",
|
||||||
"parameters": []
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "start_index",
|
||||||
|
"in": "query",
|
||||||
|
"description": "The index to start the pagination from.",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "integer"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "limit",
|
||||||
|
"in": "query",
|
||||||
|
"description": "The number of agents to return.",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "integer"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
@ -691,7 +710,7 @@
|
||||||
"tags": [
|
"tags": [
|
||||||
"Agents"
|
"Agents"
|
||||||
],
|
],
|
||||||
"description": "Delete an agent by its ID.",
|
"description": "Delete an agent by its ID and its associated sessions and turns.",
|
||||||
"parameters": [
|
"parameters": [
|
||||||
{
|
{
|
||||||
"name": "agent_id",
|
"name": "agent_id",
|
||||||
|
@ -789,7 +808,7 @@
|
||||||
"tags": [
|
"tags": [
|
||||||
"Agents"
|
"Agents"
|
||||||
],
|
],
|
||||||
"description": "Delete an agent session by its ID.",
|
"description": "Delete an agent session by its ID and its associated turns.",
|
||||||
"parameters": [
|
"parameters": [
|
||||||
{
|
{
|
||||||
"name": "session_id",
|
"name": "session_id",
|
||||||
|
@ -2410,11 +2429,11 @@
|
||||||
"get": {
|
"get": {
|
||||||
"responses": {
|
"responses": {
|
||||||
"200": {
|
"200": {
|
||||||
"description": "A ListAgentSessionsResponse.",
|
"description": "A PaginatedResponse.",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"schema": {
|
"schema": {
|
||||||
"$ref": "#/components/schemas/ListAgentSessionsResponse"
|
"$ref": "#/components/schemas/PaginatedResponse"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2445,6 +2464,24 @@
|
||||||
"schema": {
|
"schema": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "start_index",
|
||||||
|
"in": "query",
|
||||||
|
"description": "The index to start the pagination from.",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "integer"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "limit",
|
||||||
|
"in": "query",
|
||||||
|
"description": "The number of sessions to return.",
|
||||||
|
"required": false,
|
||||||
|
"schema": {
|
||||||
|
"type": "integer"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
@ -8996,38 +9033,6 @@
|
||||||
],
|
],
|
||||||
"title": "Job"
|
"title": "Job"
|
||||||
},
|
},
|
||||||
"ListAgentSessionsResponse": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"data": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"$ref": "#/components/schemas/Session"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"data"
|
|
||||||
],
|
|
||||||
"title": "ListAgentSessionsResponse"
|
|
||||||
},
|
|
||||||
"ListAgentsResponse": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"data": {
|
|
||||||
"type": "array",
|
|
||||||
"items": {
|
|
||||||
"$ref": "#/components/schemas/Agent"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"additionalProperties": false,
|
|
||||||
"required": [
|
|
||||||
"data"
|
|
||||||
],
|
|
||||||
"title": "ListAgentsResponse"
|
|
||||||
},
|
|
||||||
"BucketResponse": {
|
"BucketResponse": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
62
docs/_static/llama-stack-spec.yaml
vendored
62
docs/_static/llama-stack-spec.yaml
vendored
|
@ -197,11 +197,11 @@ paths:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A ListAgentsResponse.
|
description: A PaginatedResponse.
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/ListAgentsResponse'
|
$ref: '#/components/schemas/PaginatedResponse'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
'429':
|
'429':
|
||||||
|
@ -215,7 +215,19 @@ paths:
|
||||||
tags:
|
tags:
|
||||||
- Agents
|
- Agents
|
||||||
description: List all agents.
|
description: List all agents.
|
||||||
parameters: []
|
parameters:
|
||||||
|
- name: start_index
|
||||||
|
in: query
|
||||||
|
description: The index to start the pagination from.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
- name: limit
|
||||||
|
in: query
|
||||||
|
description: The number of agents to return.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
|
@ -465,7 +477,8 @@ paths:
|
||||||
$ref: '#/components/responses/DefaultError'
|
$ref: '#/components/responses/DefaultError'
|
||||||
tags:
|
tags:
|
||||||
- Agents
|
- Agents
|
||||||
description: Delete an agent by its ID.
|
description: >-
|
||||||
|
Delete an agent by its ID and its associated sessions and turns.
|
||||||
parameters:
|
parameters:
|
||||||
- name: agent_id
|
- name: agent_id
|
||||||
in: path
|
in: path
|
||||||
|
@ -534,7 +547,8 @@ paths:
|
||||||
$ref: '#/components/responses/DefaultError'
|
$ref: '#/components/responses/DefaultError'
|
||||||
tags:
|
tags:
|
||||||
- Agents
|
- Agents
|
||||||
description: Delete an agent session by its ID.
|
description: >-
|
||||||
|
Delete an agent session by its ID and its associated turns.
|
||||||
parameters:
|
parameters:
|
||||||
- name: session_id
|
- name: session_id
|
||||||
in: path
|
in: path
|
||||||
|
@ -1662,11 +1676,11 @@ paths:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A ListAgentSessionsResponse.
|
description: A PaginatedResponse.
|
||||||
content:
|
content:
|
||||||
application/json:
|
application/json:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/ListAgentSessionsResponse'
|
$ref: '#/components/schemas/PaginatedResponse'
|
||||||
'400':
|
'400':
|
||||||
$ref: '#/components/responses/BadRequest400'
|
$ref: '#/components/responses/BadRequest400'
|
||||||
'429':
|
'429':
|
||||||
|
@ -1688,6 +1702,18 @@ paths:
|
||||||
required: true
|
required: true
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
|
- name: start_index
|
||||||
|
in: query
|
||||||
|
description: The index to start the pagination from.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
- name: limit
|
||||||
|
in: query
|
||||||
|
description: The number of sessions to return.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
/v1/eval/benchmarks:
|
/v1/eval/benchmarks:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
@ -6217,28 +6243,6 @@ components:
|
||||||
- job_id
|
- job_id
|
||||||
- status
|
- status
|
||||||
title: Job
|
title: Job
|
||||||
ListAgentSessionsResponse:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
data:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/Session'
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- data
|
|
||||||
title: ListAgentSessionsResponse
|
|
||||||
ListAgentsResponse:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
data:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/Agent'
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- data
|
|
||||||
title: ListAgentsResponse
|
|
||||||
BucketResponse:
|
BucketResponse:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||||
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
CompletionMessage,
|
CompletionMessage,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
|
@ -241,16 +242,6 @@ class Agent(BaseModel):
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ListAgentsResponse(BaseModel):
|
|
||||||
data: list[Agent]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ListAgentSessionsResponse(BaseModel):
|
|
||||||
data: list[Session]
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
|
|
||||||
|
@ -526,7 +517,7 @@ class Agents(Protocol):
|
||||||
session_id: str,
|
session_id: str,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Delete an agent session by its ID.
|
"""Delete an agent session by its ID and its associated turns.
|
||||||
|
|
||||||
:param session_id: The ID of the session to delete.
|
:param session_id: The ID of the session to delete.
|
||||||
:param agent_id: The ID of the agent to delete the session for.
|
:param agent_id: The ID of the agent to delete the session for.
|
||||||
|
@ -538,17 +529,19 @@ class Agents(Protocol):
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Delete an agent by its ID.
|
"""Delete an agent by its ID and its associated sessions and turns.
|
||||||
|
|
||||||
:param agent_id: The ID of the agent to delete.
|
:param agent_id: The ID of the agent to delete.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
@webmethod(route="/agents", method="GET")
|
@webmethod(route="/agents", method="GET")
|
||||||
async def list_agents(self) -> ListAgentsResponse:
|
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||||
"""List all agents.
|
"""List all agents.
|
||||||
|
|
||||||
:returns: A ListAgentsResponse.
|
:param start_index: The index to start the pagination from.
|
||||||
|
:param limit: The number of agents to return.
|
||||||
|
:returns: A PaginatedResponse.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@ -565,11 +558,15 @@ class Agents(Protocol):
|
||||||
async def list_agent_sessions(
|
async def list_agent_sessions(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
) -> ListAgentSessionsResponse:
|
start_index: int | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> PaginatedResponse:
|
||||||
"""List all session(s) of a given agent.
|
"""List all session(s) of a given agent.
|
||||||
|
|
||||||
:param agent_id: The ID of the agent to list sessions for.
|
:param agent_id: The ID of the agent to list sessions for.
|
||||||
:returns: A ListAgentSessionsResponse.
|
:param start_index: The index to start the pagination from.
|
||||||
|
:param limit: The number of sessions to return.
|
||||||
|
:returns: A PaginatedResponse.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
|
@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
|
||||||
|
|
||||||
async def get_all(self) -> list[RoutableObjectWithProvider]:
|
async def get_all(self) -> list[RoutableObjectWithProvider]:
|
||||||
start_key, end_key = _get_registry_key_range()
|
start_key, end_key = _get_registry_key_range()
|
||||||
values = await self.kvstore.range(start_key, end_key)
|
values = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
return _parse_registry_values(values)
|
return _parse_registry_values(values)
|
||||||
|
|
||||||
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
|
||||||
|
@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||||
return
|
return
|
||||||
|
|
||||||
start_key, end_key = _get_registry_key_range()
|
start_key, end_key = _get_registry_key_range()
|
||||||
values = await self.kvstore.range(start_key, end_key)
|
values = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
objects = _parse_registry_values(values)
|
objects = _parse_registry_values(values)
|
||||||
|
|
||||||
async with self._locked_cache() as cache:
|
async with self._locked_cache() as cache:
|
||||||
|
|
|
@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
tool_groups_api: ToolGroups,
|
tool_groups_api: ToolGroups,
|
||||||
vector_io_api: VectorIO,
|
vector_io_api: VectorIO,
|
||||||
persistence_store: KVStore,
|
persistence_store: KVStore,
|
||||||
|
created_at: str,
|
||||||
):
|
):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
|
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||||
self.tool_runtime_api = tool_runtime_api
|
self.tool_runtime_api = tool_runtime_api
|
||||||
self.tool_groups_api = tool_groups_api
|
self.tool_groups_api = tool_groups_api
|
||||||
|
self.created_at = created_at
|
||||||
|
|
||||||
ShieldRunnerMixin.__init__(
|
ShieldRunnerMixin.__init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -4,10 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from llama_stack.apis.agents import (
|
from llama_stack.apis.agents import (
|
||||||
Agent,
|
Agent,
|
||||||
|
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
|
||||||
AgentTurnCreateRequest,
|
AgentTurnCreateRequest,
|
||||||
AgentTurnResumeRequest,
|
AgentTurnResumeRequest,
|
||||||
Document,
|
Document,
|
||||||
ListAgentSessionsResponse,
|
|
||||||
ListAgentsResponse,
|
|
||||||
OpenAIResponseInputMessage,
|
OpenAIResponseInputMessage,
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
Session,
|
Session,
|
||||||
Turn,
|
Turn,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
Inference,
|
Inference,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
@ -38,14 +37,15 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
from llama_stack.apis.vector_io import VectorIO
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.providers.utils.datasetio.pagination import paginate_records
|
||||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||||
|
|
||||||
from .agent_instance import ChatAgent
|
from .agent_instance import ChatAgent
|
||||||
from .config import MetaReferenceAgentsImplConfig
|
from .config import MetaReferenceAgentsImplConfig
|
||||||
from .openai_responses import OpenAIResponsesImpl
|
from .openai_responses import OpenAIResponsesImpl
|
||||||
|
from .persistence import AgentInfo
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgentsImpl(Agents):
|
class MetaReferenceAgentsImpl(Agents):
|
||||||
|
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
) -> AgentCreateResponse:
|
) -> AgentCreateResponse:
|
||||||
agent_id = str(uuid.uuid4())
|
agent_id = str(uuid.uuid4())
|
||||||
|
created_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
agent_info = AgentInfo(
|
||||||
|
**agent_config.model_dump(),
|
||||||
|
created_at=created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store the agent info
|
||||||
await self.persistence_store.set(
|
await self.persistence_store.set(
|
||||||
key=f"agent:{agent_id}",
|
key=f"agent:{agent_id}",
|
||||||
value=agent_config.model_dump_json(),
|
value=agent_info.model_dump_json(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgentCreateResponse(
|
return AgentCreateResponse(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||||
agent_config = await self.persistence_store.get(
|
agent_info_json = await self.persistence_store.get(
|
||||||
key=f"agent:{agent_id}",
|
key=f"agent:{agent_id}",
|
||||||
)
|
)
|
||||||
if not agent_config:
|
if not agent_info_json:
|
||||||
raise ValueError(f"Could not find agent config for {agent_id}")
|
raise ValueError(f"Could not find agent info for {agent_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
agent_config = json.loads(agent_config)
|
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent_config = AgentConfig(**agent_config)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
|
raise ValueError(f"Could not validate agent info for {agent_id}") from e
|
||||||
|
|
||||||
return ChatAgent(
|
return ChatAgent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_config=agent_config,
|
agent_config=agent_info,
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
safety_api=self.safety_api,
|
safety_api=self.safety_api,
|
||||||
vector_io_api=self.vector_io_api,
|
vector_io_api=self.vector_io_api,
|
||||||
tool_runtime_api=self.tool_runtime_api,
|
tool_runtime_api=self.tool_runtime_api,
|
||||||
tool_groups_api=self.tool_groups_api,
|
tool_groups_api=self.tool_groups_api,
|
||||||
persistence_store=(
|
persistence_store=(
|
||||||
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
|
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||||
),
|
),
|
||||||
|
created_at=agent_info.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agent_session(
|
async def create_agent_session(
|
||||||
|
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
turn_ids: list[str] | None = None,
|
turn_ids: list[str] | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
|
||||||
session_info = await agent.storage.get_session_info(session_id)
|
session_info = await agent.storage.get_session_info(session_id)
|
||||||
if session_info is None:
|
if session_info is None:
|
||||||
raise ValueError(f"Session {session_id} not found")
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||||
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
session_info = await agent.storage.get_session_info(session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
# Delete turns first, then the session
|
||||||
|
await agent.storage.delete_session_turns(session_id)
|
||||||
|
await agent.storage.delete_session(session_id)
|
||||||
|
|
||||||
async def delete_agent(self, agent_id: str) -> None:
|
async def delete_agent(self, agent_id: str) -> None:
|
||||||
|
# First get all sessions for this agent
|
||||||
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
sessions = await agent.storage.list_sessions()
|
||||||
|
|
||||||
|
# Delete all sessions
|
||||||
|
for session in sessions:
|
||||||
|
await self.delete_agents_session(agent_id, session.session_id)
|
||||||
|
|
||||||
|
# Finally delete the agent itself
|
||||||
await self.persistence_store.delete(f"agent:{agent_id}")
|
await self.persistence_store.delete(f"agent:{agent_id}")
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||||
pass
|
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
|
||||||
|
agent_list: list[Agent] = []
|
||||||
|
for agent_key in agent_keys:
|
||||||
|
agent_id = agent_key.split(":")[1]
|
||||||
|
|
||||||
async def list_agents(self) -> ListAgentsResponse:
|
# Get the agent info using the key
|
||||||
pass
|
agent_info_json = await self.persistence_store.get(agent_key)
|
||||||
|
if not agent_info_json:
|
||||||
|
logger.error(f"Could not find agent info for key {agent_key}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||||
|
agent_list.append(
|
||||||
|
Agent(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_config=agent_info,
|
||||||
|
created_at=agent_info.created_at,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert Agent objects to dictionaries
|
||||||
|
agent_dicts = [agent.model_dump() for agent in agent_list]
|
||||||
|
return paginate_records(agent_dicts, start_index, limit)
|
||||||
|
|
||||||
async def get_agent(self, agent_id: str) -> Agent:
|
async def get_agent(self, agent_id: str) -> Agent:
|
||||||
pass
|
chat_agent = await self._get_agent_impl(agent_id)
|
||||||
|
agent = Agent(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_config=chat_agent.agent_config,
|
||||||
|
created_at=chat_agent.created_at,
|
||||||
|
)
|
||||||
|
return agent
|
||||||
|
|
||||||
async def list_agent_sessions(
|
async def list_agent_sessions(
|
||||||
self,
|
self, agent_id: str, start_index: int | None = None, limit: int | None = None
|
||||||
agent_id: str,
|
) -> PaginatedResponse:
|
||||||
) -> ListAgentSessionsResponse:
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
sessions = await agent.storage.list_sessions()
|
||||||
|
# Convert Session objects to dictionaries
|
||||||
|
session_dicts = [session.model_dump() for session in sessions]
|
||||||
|
return paginate_records(session_dicts, start_index, limit)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# OpenAI responses
|
# OpenAI responses
|
||||||
|
|
|
@ -9,9 +9,7 @@ import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||||
|
|
||||||
from llama_stack.apis.agents import ToolExecutionStep, Turn
|
|
||||||
from llama_stack.distribution.access_control import check_access
|
from llama_stack.distribution.access_control import check_access
|
||||||
from llama_stack.distribution.datatypes import AccessAttributes
|
from llama_stack.distribution.datatypes import AccessAttributes
|
||||||
from llama_stack.distribution.request_headers import get_auth_attributes
|
from llama_stack.distribution.request_headers import get_auth_attributes
|
||||||
|
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AgentSessionInfo(BaseModel):
|
class AgentSessionInfo(Session):
|
||||||
session_id: str
|
|
||||||
session_name: str
|
|
||||||
# TODO: is this used anywhere?
|
# TODO: is this used anywhere?
|
||||||
vector_db_id: str | None = None
|
vector_db_id: str | None = None
|
||||||
started_at: datetime
|
started_at: datetime
|
||||||
access_attributes: AccessAttributes | None = None
|
access_attributes: AccessAttributes | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentInfo(AgentConfig):
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
class AgentPersistence:
|
class AgentPersistence:
|
||||||
def __init__(self, agent_id: str, kvstore: KVStore):
|
def __init__(self, agent_id: str, kvstore: KVStore):
|
||||||
self.agent_id = agent_id
|
self.agent_id = agent_id
|
||||||
|
@ -46,6 +46,7 @@ class AgentPersistence:
|
||||||
session_name=name,
|
session_name=name,
|
||||||
started_at=datetime.now(timezone.utc),
|
started_at=datetime.now(timezone.utc),
|
||||||
access_attributes=access_attributes,
|
access_attributes=access_attributes,
|
||||||
|
turns=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.kvstore.set(
|
await self.kvstore.set(
|
||||||
|
@ -109,7 +110,7 @@ class AgentPersistence:
|
||||||
if not await self.get_session_if_accessible(session_id):
|
if not await self.get_session_if_accessible(session_id):
|
||||||
raise ValueError(f"Session {session_id} not found or access denied")
|
raise ValueError(f"Session {session_id} not found or access denied")
|
||||||
|
|
||||||
values = await self.kvstore.range(
|
values = await self.kvstore.values_in_range(
|
||||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||||
)
|
)
|
||||||
|
@ -121,7 +122,6 @@ class AgentPersistence:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error parsing turn: {e}")
|
log.error(f"Error parsing turn: {e}")
|
||||||
continue
|
continue
|
||||||
turns.sort(key=lambda x: (x.completed_at or datetime.min))
|
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||||
|
@ -170,3 +170,43 @@ class AgentPersistence:
|
||||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||||
)
|
)
|
||||||
return int(value) if value else None
|
return int(value) if value else None
|
||||||
|
|
||||||
|
async def list_sessions(self) -> list[Session]:
|
||||||
|
values = await self.kvstore.values_in_range(
|
||||||
|
start_key=f"session:{self.agent_id}:",
|
||||||
|
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
|
||||||
|
)
|
||||||
|
sessions = []
|
||||||
|
for value in values:
|
||||||
|
try:
|
||||||
|
session_info = Session(**json.loads(value))
|
||||||
|
sessions.append(session_info)
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Error parsing session info: {e}")
|
||||||
|
continue
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
async def delete_session_turns(self, session_id: str) -> None:
|
||||||
|
"""Delete all turns and their associated data for a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the session whose turns should be deleted.
|
||||||
|
"""
|
||||||
|
turns = await self.get_session_turns(session_id)
|
||||||
|
for turn in turns:
|
||||||
|
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
|
||||||
|
|
||||||
|
async def delete_session(self, session_id: str) -> None:
|
||||||
|
"""Delete a session and all its associated turns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The ID of the session to delete.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the session does not exist.
|
||||||
|
"""
|
||||||
|
session_info = await self.get_session_info(session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
|
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
|
||||||
|
|
|
@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
# Load existing datasets from kvstore
|
# Load existing datasets from kvstore
|
||||||
start_key = DATASETS_PREFIX
|
start_key = DATASETS_PREFIX
|
||||||
end_key = f"{DATASETS_PREFIX}\xff"
|
end_key = f"{DATASETS_PREFIX}\xff"
|
||||||
stored_datasets = await self.kvstore.range(start_key, end_key)
|
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
|
||||||
for dataset in stored_datasets:
|
for dataset in stored_datasets:
|
||||||
dataset = Dataset.model_validate_json(dataset)
|
dataset = Dataset.model_validate_json(dataset)
|
||||||
|
|
|
@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
|
||||||
# Load existing benchmarks from kvstore
|
# Load existing benchmarks from kvstore
|
||||||
start_key = EVAL_TASKS_PREFIX
|
start_key = EVAL_TASKS_PREFIX
|
||||||
end_key = f"{EVAL_TASKS_PREFIX}\xff"
|
end_key = f"{EVAL_TASKS_PREFIX}\xff"
|
||||||
stored_benchmarks = await self.kvstore.range(start_key, end_key)
|
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
|
||||||
for benchmark in stored_benchmarks:
|
for benchmark in stored_benchmarks:
|
||||||
benchmark = Benchmark.model_validate_json(benchmark)
|
benchmark = Benchmark.model_validate_json(benchmark)
|
||||||
|
|
|
@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
# Load existing banks from kvstore
|
# Load existing banks from kvstore
|
||||||
start_key = VECTOR_DBS_PREFIX
|
start_key = VECTOR_DBS_PREFIX
|
||||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||||
stored_vector_dbs = await self.kvstore.range(start_key, end_key)
|
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
|
||||||
for vector_db_data in stored_vector_dbs:
|
for vector_db_data in stored_vector_dbs:
|
||||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||||
# Load existing datasets from kvstore
|
# Load existing datasets from kvstore
|
||||||
start_key = DATASETS_PREFIX
|
start_key = DATASETS_PREFIX
|
||||||
end_key = f"{DATASETS_PREFIX}\xff"
|
end_key = f"{DATASETS_PREFIX}\xff"
|
||||||
stored_datasets = await self.kvstore.range(start_key, end_key)
|
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
|
||||||
|
|
||||||
for dataset in stored_datasets:
|
for dataset in stored_datasets:
|
||||||
dataset = Dataset.model_validate_json(dataset)
|
dataset = Dataset.model_validate_json(dataset)
|
||||||
|
|
|
@ -16,4 +16,6 @@ class KVStore(Protocol):
|
||||||
|
|
||||||
async def delete(self, key: str) -> None: ...
|
async def delete(self, key: str) -> None: ...
|
||||||
|
|
||||||
async def range(self, start_key: str, end_key: str) -> list[str]: ...
|
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
|
||||||
|
|
||||||
|
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
|
||||||
|
|
|
@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
|
||||||
async def set(self, key: str, value: str) -> None:
|
async def set(self, key: str, value: str) -> None:
|
||||||
self._store[key] = value
|
self._store[key] = value
|
||||||
|
|
||||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
|
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
|
||||||
|
|
||||||
|
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
|
"""Get all keys in the given range."""
|
||||||
|
return [key for key in self._store.keys() if key >= start_key and key < end_key]
|
||||||
|
|
||||||
|
async def delete(self, key: str) -> None:
|
||||||
|
del self._store[key]
|
||||||
|
|
||||||
|
|
||||||
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
||||||
if config.type == KVStoreType.redis.value:
|
if config.type == KVStoreType.redis.value:
|
||||||
|
|
|
@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
|
||||||
key = self._namespaced_key(key)
|
key = self._namespaced_key(key)
|
||||||
await self.collection.delete_one({"key": key})
|
await self.collection.delete_one({"key": key})
|
||||||
|
|
||||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
start_key = self._namespaced_key(start_key)
|
start_key = self._namespaced_key(start_key)
|
||||||
end_key = self._namespaced_key(end_key)
|
end_key = self._namespaced_key(end_key)
|
||||||
query = {
|
query = {
|
||||||
|
@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
result.append(doc["value"])
|
result.append(doc["value"])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
|
start_key = self._namespaced_key(start_key)
|
||||||
|
end_key = self._namespaced_key(end_key)
|
||||||
|
query = {"key": {"$gte": start_key, "$lt": end_key}}
|
||||||
|
cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
|
||||||
|
return [doc["key"] for doc in cursor]
|
||||||
|
|
|
@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
|
||||||
(key,),
|
(key,),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
start_key = self._namespaced_key(start_key)
|
start_key = self._namespaced_key(start_key)
|
||||||
end_key = self._namespaced_key(end_key)
|
end_key = self._namespaced_key(end_key)
|
||||||
|
|
||||||
|
@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
|
||||||
(start_key, end_key),
|
(start_key, end_key),
|
||||||
)
|
)
|
||||||
return [row[0] for row in self.cursor.fetchall()]
|
return [row[0] for row in self.cursor.fetchall()]
|
||||||
|
|
||||||
|
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
|
start_key = self._namespaced_key(start_key)
|
||||||
|
end_key = self._namespaced_key(end_key)
|
||||||
|
|
||||||
|
self.cursor.execute(
|
||||||
|
f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
|
||||||
|
(start_key, end_key),
|
||||||
|
)
|
||||||
|
return [row[0] for row in self.cursor.fetchall()]
|
||||||
|
|
|
@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
|
||||||
key = self._namespaced_key(key)
|
key = self._namespaced_key(key)
|
||||||
await self.redis.delete(key)
|
await self.redis.delete(key)
|
||||||
|
|
||||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
start_key = self._namespaced_key(start_key)
|
start_key = self._namespaced_key(start_key)
|
||||||
end_key = self._namespaced_key(end_key)
|
end_key = self._namespaced_key(end_key)
|
||||||
cursor = 0
|
cursor = 0
|
||||||
|
@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
|
||||||
]
|
]
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
|
"""Get all keys in the given range."""
|
||||||
|
matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
|
||||||
|
if not matching_keys:
|
||||||
|
return []
|
||||||
|
return [k.decode("utf-8") for k in matching_keys]
|
||||||
|
|
|
@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
|
||||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def range(self, start_key: str, end_key: str) -> list[str]:
|
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
async with db.execute(
|
async with db.execute(
|
||||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||||
|
@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
|
||||||
_, value, _ = row
|
_, value, _ = row
|
||||||
result.append(value)
|
result.append(value)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
|
||||||
|
"""Get all keys in the given range."""
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
cursor = await db.execute(
|
||||||
|
f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||||
|
(start_key, end_key),
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return [row[0] for row in rows]
|
||||||
|
|
230
tests/unit/providers/agent/test_meta_reference_agent.py
Normal file
230
tests/unit/providers/agent/test_meta_reference_agent.py
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import (
|
||||||
|
Agent,
|
||||||
|
AgentConfig,
|
||||||
|
AgentCreateResponse,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.safety import Safety
|
||||||
|
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||||
|
from llama_stack.apis.vector_io import VectorIO
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
|
||||||
|
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_apis():
|
||||||
|
return {
|
||||||
|
"inference_api": AsyncMock(spec=Inference),
|
||||||
|
"vector_io_api": AsyncMock(spec=VectorIO),
|
||||||
|
"safety_api": AsyncMock(spec=Safety),
|
||||||
|
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
|
||||||
|
"tool_groups_api": AsyncMock(spec=ToolGroups),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config(tmp_path):
|
||||||
|
return MetaReferenceAgentsImplConfig(
|
||||||
|
persistence_store={
|
||||||
|
"type": "sqlite",
|
||||||
|
"db_path": str(tmp_path / "test.db"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def agents_impl(config, mock_apis):
|
||||||
|
impl = MetaReferenceAgentsImpl(
|
||||||
|
config,
|
||||||
|
mock_apis["inference_api"],
|
||||||
|
mock_apis["vector_io_api"],
|
||||||
|
mock_apis["safety_api"],
|
||||||
|
mock_apis["tool_runtime_api"],
|
||||||
|
mock_apis["tool_groups_api"],
|
||||||
|
)
|
||||||
|
await impl.initialize()
|
||||||
|
yield impl
|
||||||
|
await impl.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_agent_config():
|
||||||
|
return AgentConfig(
|
||||||
|
sampling_params={
|
||||||
|
"strategy": {"type": "greedy"},
|
||||||
|
"max_tokens": 0,
|
||||||
|
"repetition_penalty": 1.0,
|
||||||
|
},
|
||||||
|
input_shields=["string"],
|
||||||
|
output_shields=["string"],
|
||||||
|
toolgroups=["string"],
|
||||||
|
client_tools=[
|
||||||
|
{
|
||||||
|
"name": "string",
|
||||||
|
"description": "string",
|
||||||
|
"parameters": [
|
||||||
|
{
|
||||||
|
"name": "string",
|
||||||
|
"parameter_type": "string",
|
||||||
|
"description": "string",
|
||||||
|
"required": True,
|
||||||
|
"default": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"property1": None,
|
||||||
|
"property2": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
tool_choice="auto",
|
||||||
|
tool_prompt_format="json",
|
||||||
|
tool_config={
|
||||||
|
"tool_choice": "auto",
|
||||||
|
"tool_prompt_format": "json",
|
||||||
|
"system_message_behavior": "append",
|
||||||
|
},
|
||||||
|
max_infer_iters=10,
|
||||||
|
model="string",
|
||||||
|
instructions="string",
|
||||||
|
enable_session_persistence=False,
|
||||||
|
response_format={
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": {
|
||||||
|
"property1": None,
|
||||||
|
"property2": None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_agent(agents_impl, sample_agent_config):
|
||||||
|
response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
|
||||||
|
assert isinstance(response, AgentCreateResponse)
|
||||||
|
assert response.agent_id is not None
|
||||||
|
|
||||||
|
stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}")
|
||||||
|
assert stored_agent is not None
|
||||||
|
agent_info = AgentInfo.model_validate_json(stored_agent)
|
||||||
|
assert agent_info.model == sample_agent_config.model
|
||||||
|
assert agent_info.created_at is not None
|
||||||
|
assert isinstance(agent_info.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_agent(agents_impl, sample_agent_config):
|
||||||
|
create_response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
agent_id = create_response.agent_id
|
||||||
|
|
||||||
|
agent = await agents_impl.get_agent(agent_id)
|
||||||
|
|
||||||
|
assert isinstance(agent, Agent)
|
||||||
|
assert agent.agent_id == agent_id
|
||||||
|
assert agent.agent_config.model == sample_agent_config.model
|
||||||
|
assert agent.created_at is not None
|
||||||
|
assert isinstance(agent.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_agents(agents_impl, sample_agent_config):
|
||||||
|
agent1_response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
agent2_response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
|
||||||
|
response = await agents_impl.list_agents()
|
||||||
|
|
||||||
|
assert isinstance(response, PaginatedResponse)
|
||||||
|
assert len(response.data) == 2
|
||||||
|
agent_ids = {agent["agent_id"] for agent in response.data}
|
||||||
|
assert agent1_response.agent_id in agent_ids
|
||||||
|
assert agent2_response.agent_id in agent_ids
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||||
|
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||||
|
# Create an agent with specified persistence setting
|
||||||
|
config = sample_agent_config.model_copy()
|
||||||
|
config.enable_session_persistence = enable_session_persistence
|
||||||
|
response = await agents_impl.create_agent(config)
|
||||||
|
agent_id = response.agent_id
|
||||||
|
|
||||||
|
# Create a session
|
||||||
|
session_response = await agents_impl.create_agent_session(agent_id, "test_session")
|
||||||
|
assert session_response.session_id is not None
|
||||||
|
|
||||||
|
# Verify the session was stored
|
||||||
|
session = await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
||||||
|
assert session.session_name == "test_session"
|
||||||
|
assert session.session_id == session_response.session_id
|
||||||
|
assert session.started_at is not None
|
||||||
|
assert session.turns == []
|
||||||
|
|
||||||
|
# Delete the session
|
||||||
|
await agents_impl.delete_agents_session(agent_id, session_response.session_id)
|
||||||
|
|
||||||
|
# Verify the session was deleted
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||||
|
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
|
||||||
|
# Create an agent with specified persistence setting
|
||||||
|
config = sample_agent_config.model_copy()
|
||||||
|
config.enable_session_persistence = enable_session_persistence
|
||||||
|
response = await agents_impl.create_agent(config)
|
||||||
|
agent_id = response.agent_id
|
||||||
|
|
||||||
|
# Create multiple sessions
|
||||||
|
session1 = await agents_impl.create_agent_session(agent_id, "session1")
|
||||||
|
session2 = await agents_impl.create_agent_session(agent_id, "session2")
|
||||||
|
|
||||||
|
# List sessions
|
||||||
|
sessions = await agents_impl.list_agent_sessions(agent_id)
|
||||||
|
assert len(sessions.data) == 2
|
||||||
|
session_ids = {s["session_id"] for s in sessions.data}
|
||||||
|
assert session1.session_id in session_ids
|
||||||
|
assert session2.session_id in session_ids
|
||||||
|
|
||||||
|
# Delete one session
|
||||||
|
await agents_impl.delete_agents_session(agent_id, session1.session_id)
|
||||||
|
|
||||||
|
# Verify the session was deleted
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await agents_impl.get_agents_session(agent_id, session1.session_id)
|
||||||
|
|
||||||
|
# List sessions again
|
||||||
|
sessions = await agents_impl.list_agent_sessions(agent_id)
|
||||||
|
assert len(sessions.data) == 1
|
||||||
|
assert session2.session_id in {s["session_id"] for s in sessions.data}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_agent(agents_impl, sample_agent_config):
|
||||||
|
# Create an agent
|
||||||
|
response = await agents_impl.create_agent(sample_agent_config)
|
||||||
|
agent_id = response.agent_id
|
||||||
|
|
||||||
|
# Delete the agent
|
||||||
|
await agents_impl.delete_agent(agent_id)
|
||||||
|
|
||||||
|
# Verify the agent was deleted
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await agents_impl.get_agent(agent_id)
|
|
@ -54,6 +54,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
|
||||||
|
turns=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent_persistence.kvstore.set(
|
await agent_persistence.kvstore.set(
|
||||||
|
@ -85,6 +86,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
access_attributes=AccessAttributes(roles=["admin"]),
|
||||||
|
turns=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent_persistence.kvstore.set(
|
await agent_persistence.kvstore.set(
|
||||||
|
@ -137,6 +139,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
|
||||||
session_name="Restricted Session",
|
session_name="Restricted Session",
|
||||||
started_at=datetime.now(),
|
started_at=datetime.now(),
|
||||||
access_attributes=AccessAttributes(roles=["admin"]),
|
access_attributes=AccessAttributes(roles=["admin"]),
|
||||||
|
turns=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
await agent_persistence.kvstore.set(
|
await agent_persistence.kvstore.set(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue