feat: implementation for agent/session list and describe (#1606)

Create a new agent:

```
curl --request POST \
  --url http://localhost:8321/v1/agents \
  --header 'Accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
  "agent_config": {
    "sampling_params": {
      "strategy": {
        "type": "greedy"
      },
      "max_tokens": 0,
      "repetition_penalty": 1
    },
    "input_shields": [
      "string"
    ],
    "output_shields": [
      "string"
    ],
    "toolgroups": [
      "string"
    ],
    "client_tools": [
      {
        "name": "string",
        "description": "string",
        "parameters": [
          {
            "name": "string",
            "parameter_type": "string",
            "description": "string",
            "required": true,
            "default": null
          }
        ],
        "metadata": {
          "property1": null,
          "property2": null
        }
      }
    ],
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tool_config": {
      "tool_choice": "auto",
      "tool_prompt_format": "json",
      "system_message_behavior": "append"
    },
    "max_infer_iters": 10,
    "model": "string",
    "instructions": "string",
    "enable_session_persistence": false,
    "response_format": {
      "type": "json_schema",
      "json_schema": {
        "property1": null,
        "property2": null
      }
    }
  }
}'
```

Get agent:

```
curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f
{"agent_id":"9abad4ab-2c77-45f9-9d16-46b79d2bea1f","agent_config":{"sampling_params":{"strategy":{"type":"greedy"},"max_tokens":0,"repetition_penalty":1.0},"input_shields":["string"],"output_shields":["string"],"toolgroups":["string"],"client_tools":[{"name":"string","description":"string","parameters":[{"name":"string","parameter_type":"string","description":"string","required":true,"default":null}],"metadata":{"property1":null,"property2":null}}],"tool_choice":"auto","tool_prompt_format":"json","tool_config":{"tool_choice":"auto","tool_prompt_format":"json","system_message_behavior":"append"},"max_infer_iters":10,"model":"string","instructions":"string","enable_session_persistence":false,"response_format":{"type":"json_schema","json_schema":{"property1":null,"property2":null}}},"created_at":"2025-03-12T16:18:28.369144Z"}%
```

List agents:

```
curl http://127.0.0.1:8321/v1/agents|jq
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1680  100  1680    0     0   498k      0 --:--:-- --:--:-- --:--:--  546k
{
  "data": [
    {
      "agent_id": "9abad4ab-2c77-45f9-9d16-46b79d2bea1f",
      "agent_config": {
        "sampling_params": {
          "strategy": {
            "type": "greedy"
          },
          "max_tokens": 0,
          "repetition_penalty": 1.0
        },
        "input_shields": [
          "string"
        ],
        "output_shields": [
          "string"
        ],
        "toolgroups": [
          "string"
        ],
        "client_tools": [
          {
            "name": "string",
            "description": "string",
            "parameters": [
              {
                "name": "string",
                "parameter_type": "string",
                "description": "string",
                "required": true,
                "default": null
              }
            ],
            "metadata": {
              "property1": null,
              "property2": null
            }
          }
        ],
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "tool_config": {
          "tool_choice": "auto",
          "tool_prompt_format": "json",
          "system_message_behavior": "append"
        },
        "max_infer_iters": 10,
        "model": "string",
        "instructions": "string",
        "enable_session_persistence": false,
        "response_format": {
          "type": "json_schema",
          "json_schema": {
            "property1": null,
            "property2": null
          }
        }
      },
      "created_at": "2025-03-12T16:18:28.369144Z"
    },
    {
      "agent_id": "a6643aaa-96dd-46db-a405-333dc504b168",
      "agent_config": {
        "sampling_params": {
          "strategy": {
            "type": "greedy"
          },
          "max_tokens": 0,
          "repetition_penalty": 1.0
        },
        "input_shields": [
          "string"
        ],
        "output_shields": [
          "string"
        ],
        "toolgroups": [
          "string"
        ],
        "client_tools": [
          {
            "name": "string",
            "description": "string",
            "parameters": [
              {
                "name": "string",
                "parameter_type": "string",
                "description": "string",
                "required": true,
                "default": null
              }
            ],
            "metadata": {
              "property1": null,
              "property2": null
            }
          }
        ],
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "tool_config": {
          "tool_choice": "auto",
          "tool_prompt_format": "json",
          "system_message_behavior": "append"
        },
        "max_infer_iters": 10,
        "model": "string",
        "instructions": "string",
        "enable_session_persistence": false,
        "response_format": {
          "type": "json_schema",
          "json_schema": {
            "property1": null,
            "property2": null
          }
        }
      },
      "created_at": "2025-03-12T16:17:12.811273Z"
    }
  ]
}
```

Create sessions:

```
curl --request POST \
  --url http://localhost:8321/v1/agents/{agent_id}/session \
  --header 'Accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
  "session_name": "string"
}'
```

List sessions:

```
 curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f/sessions|jq
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   263  100   263    0     0  90099      0 --:--:-- --:--:-- --:--:--  128k
[
  {
    "session_id": "2b15c4fc-e348-46c1-ae32-f6d424441ac1",
    "session_name": "string",
    "turns": [],
    "started_at": "2025-03-12T17:19:17.784328"
  },
  {
    "session_id": "9432472d-d483-4b73-b682-7b1d35d64111",
    "session_name": "string",
    "turns": [],
    "started_at": "2025-03-12T17:19:19.885834"
  }
]
```

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-07 14:49:23 +02:00 committed by GitHub
parent 40e71758d9
commit c91e3552a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 510 additions and 130 deletions

View file

@ -307,11 +307,11 @@
"get": {
"responses": {
"200": {
"description": "A ListAgentsResponse.",
"description": "A PaginatedResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListAgentsResponse"
"$ref": "#/components/schemas/PaginatedResponse"
}
}
}
@ -333,7 +333,26 @@
"Agents"
],
"description": "List all agents.",
"parameters": []
"parameters": [
{
"name": "start_index",
"in": "query",
"description": "The index to start the pagination from.",
"required": false,
"schema": {
"type": "integer"
}
},
{
"name": "limit",
"in": "query",
"description": "The number of agents to return.",
"required": false,
"schema": {
"type": "integer"
}
}
]
},
"post": {
"responses": {
@ -691,7 +710,7 @@
"tags": [
"Agents"
],
"description": "Delete an agent by its ID.",
"description": "Delete an agent by its ID and its associated sessions and turns.",
"parameters": [
{
"name": "agent_id",
@ -789,7 +808,7 @@
"tags": [
"Agents"
],
"description": "Delete an agent session by its ID.",
"description": "Delete an agent session by its ID and its associated turns.",
"parameters": [
{
"name": "session_id",
@ -2410,11 +2429,11 @@
"get": {
"responses": {
"200": {
"description": "A ListAgentSessionsResponse.",
"description": "A PaginatedResponse.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListAgentSessionsResponse"
"$ref": "#/components/schemas/PaginatedResponse"
}
}
}
@ -2445,6 +2464,24 @@
"schema": {
"type": "string"
}
},
{
"name": "start_index",
"in": "query",
"description": "The index to start the pagination from.",
"required": false,
"schema": {
"type": "integer"
}
},
{
"name": "limit",
"in": "query",
"description": "The number of sessions to return.",
"required": false,
"schema": {
"type": "integer"
}
}
]
}
@ -8996,38 +9033,6 @@
],
"title": "Job"
},
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Session"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentSessionsResponse"
},
"ListAgentsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Agent"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentsResponse"
},
"BucketResponse": {
"type": "object",
"properties": {

View file

@ -197,11 +197,11 @@ paths:
get:
responses:
'200':
description: A ListAgentsResponse.
description: A PaginatedResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentsResponse'
$ref: '#/components/schemas/PaginatedResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -215,7 +215,19 @@ paths:
tags:
- Agents
description: List all agents.
parameters: []
parameters:
- name: start_index
in: query
description: The index to start the pagination from.
required: false
schema:
type: integer
- name: limit
in: query
description: The number of agents to return.
required: false
schema:
type: integer
post:
responses:
'200':
@ -465,7 +477,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: Delete an agent by its ID.
description: >-
Delete an agent by its ID and its associated sessions and turns.
parameters:
- name: agent_id
in: path
@ -534,7 +547,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: Delete an agent session by its ID.
description: >-
Delete an agent session by its ID and its associated turns.
parameters:
- name: session_id
in: path
@ -1662,11 +1676,11 @@ paths:
get:
responses:
'200':
description: A ListAgentSessionsResponse.
description: A PaginatedResponse.
content:
application/json:
schema:
$ref: '#/components/schemas/ListAgentSessionsResponse'
$ref: '#/components/schemas/PaginatedResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@ -1688,6 +1702,18 @@ paths:
required: true
schema:
type: string
- name: start_index
in: query
description: The index to start the pagination from.
required: false
schema:
type: integer
- name: limit
in: query
description: The number of sessions to return.
required: false
schema:
type: integer
/v1/eval/benchmarks:
get:
responses:
@ -6217,28 +6243,6 @@ components:
- job_id
- status
title: Job
ListAgentSessionsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Session'
additionalProperties: false
required:
- data
title: ListAgentSessionsResponse
ListAgentsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Agent'
additionalProperties: false
required:
- data
title: ListAgentsResponse
BucketResponse:
type: object
properties:

View file

@ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
@ -241,16 +242,6 @@ class Agent(BaseModel):
created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: list[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: list[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
@ -526,7 +517,7 @@ class Agents(Protocol):
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID.
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
@ -538,17 +529,19 @@ class Agents(Protocol):
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID.
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse:
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:returns: A ListAgentsResponse.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
@ -565,11 +558,15 @@ class Agents(Protocol):
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...

View file

@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
return
start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key)
values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:

View file

@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
persistence_store: KVStore,
created_at: str,
):
self.agent_id = agent_id
self.agent_config = agent_config
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
self.created_at = created_at
ShieldRunnerMixin.__init__(
self,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from llama_stack.apis.agents import (
Agent,
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListAgentSessionsResponse,
ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInputTool,
OpenAIResponseObject,
Session,
Turn,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@ -38,14 +37,15 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents):
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(timezone.utc)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_config.model_dump_json(),
value=agent_info.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get(
agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_config:
raise ValueError(f"Could not find agent config for {agent_id}")
if not agent_info_json:
raise ValueError(f"Could not find agent info for {agent_id}")
try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
)
async def create_agent_session(
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
    """Remove an agent together with every session stored for it.

    :param agent_id: The ID of the agent to delete.
    :raises ValueError: Propagated from the agent lookup when no valid
        agent record can be loaded for ``agent_id``.
    """
    chat_agent = await self._get_agent_impl(agent_id)

    # Sessions are removed before the agent record so that no session
    # outlives the agent it belongs to.
    for stored_session in await chat_agent.storage.list_sessions():
        await self.delete_agents_session(agent_id, stored_session.session_id)

    await self.persistence_store.delete(f"agent:{agent_id}")
async def shutdown(self) -> None:
pass
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
async def list_agents(self) -> ListAgentsResponse:
pass
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent:
pass
chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
)
return agent
async def list_agent_sessions(
self,
agent_id: str,
) -> ListAgentSessionsResponse:
self, agent_id: str, start_index: int | None = None, limit: int | None = None
) -> PaginatedResponse:
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass
# OpenAI responses

View file

@ -9,9 +9,7 @@ import logging
import uuid
from datetime import datetime, timezone
from pydantic import BaseModel
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
# Agent record as persisted in the key-value store: every AgentConfig field
# plus the creation timestamp.
class AgentInfo(AgentConfig):
# Moment the agent was created.
created_at: datetime
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id
@ -46,6 +46,7 @@ class AgentPersistence:
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
turns=[],
)
await self.kvstore.set(
@ -109,7 +110,7 @@ class AgentPersistence:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range(
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
@ -121,7 +122,6 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@ -170,3 +170,43 @@ class AgentPersistence:
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
async def list_sessions(self) -> list[Session]:
    """Return every session stored for this agent.

    Stored values that cannot be parsed into a ``Session`` are logged as
    errors and left out of the result.
    """
    raw_records = await self.kvstore.values_in_range(
        start_key=f"session:{self.agent_id}:",
        end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
    )

    # NOTE(review): turn records are keyed "session:<agent>:<session>:<turn>"
    # (see delete_session_turns), so they presumably fall inside this range
    # as well and land in the error branch below - confirm.
    found: list[Session] = []
    for raw in raw_records:
        try:
            found.append(Session(**json.loads(raw)))
        except Exception as e:
            log.error(f"Error parsing session info: {e}")
    return found
async def delete_session_turns(self, session_id: str) -> None:
    """Delete the stored turn records of a session.

    Only the ``session:<agent_id>:<session_id>:<turn_id>`` entries are
    removed; the session entry itself is left in place.

    Args:
        session_id: The ID of the session whose turns should be deleted.
    """
    for stored_turn in await self.get_session_turns(session_id):
        turn_key = f"session:{self.agent_id}:{session_id}:{stored_turn.turn_id}"
        await self.kvstore.delete(key=turn_key)
async def delete_session(self, session_id: str) -> None:
    """Delete the stored record of a session.

    Only the session entry itself is removed by this method; turn records
    are handled by ``delete_session_turns``.

    Args:
        session_id: The ID of the session to delete.

    Raises:
        ValueError: If the session does not exist.
    """
    if await self.get_session_info(session_id) is None:
        raise ValueError(f"Session {session_id} not found")

    await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

View file

@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)

View file

@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
# Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.range(start_key, end_key)
stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark)

View file

@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.range(start_key, end_key)
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)

View file

@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key)
stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)

View file

@ -16,4 +16,6 @@ class KVStore(Protocol):
async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> list[str]: ...
async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...

View file

@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Return the stored keys ``k`` with ``start_key <= k < end_key``.

    Keys come back in the store's own iteration order.
    """
    matches: list[str] = []
    for candidate in self._store:
        if start_key <= candidate < end_key:
            matches.append(candidate)
    return matches
async def delete(self, key: str) -> None:
    """Remove ``key`` from the in-memory store.

    Deleting a key that is not present is a no-op.

    :param key: The key to remove.
    """
    # pop() with a default rather than `del`: an absent key must not raise
    # KeyError, in line with the delete-by-query behaviour of the
    # database-backed stores.
    self._store.pop(key, None)
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:

View file

@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
async for doc in cursor:
result.append(doc["value"])
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Return the keys in ``[start_key, end_key)``, sorted ascending.

    Both bounds are passed through ``_namespaced_key`` before the query is
    built, and the ``key`` field of each matching document is returned as
    stored in the collection.

    :param start_key: Inclusive lower bound.
    :param end_key: Exclusive upper bound.
    :returns: The matching keys in ascending order.
    """
    start_key = self._namespaced_key(start_key)
    end_key = self._namespaced_key(end_key)
    query = {"key": {"$gte": start_key, "$lt": end_key}}
    cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
    # The cursor is asynchronous (values_in_range consumes the same kind of
    # cursor with `async for`), so it has to be drained with an async
    # comprehension; a plain `for` cannot iterate it.
    # NOTE(review): the returned keys presumably still carry the namespace
    # prefix added above - confirm callers expect that.
    return [doc["key"] async for doc in cursor]

View file

@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
(key,),
)
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Return the keys ``k`` with ``start_key <= k < end_key``.

    Both bounds are passed through ``_namespaced_key`` before being bound
    as query parameters.
    """
    lower_bound = self._namespaced_key(start_key)
    upper_bound = self._namespaced_key(end_key)
    sql = f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s"
    self.cursor.execute(sql, (lower_bound, upper_bound))
    rows = self.cursor.fetchall()
    return [row[0] for row in rows]

View file

@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0
@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
]
return []
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Get all keys in the given range.

    Runs a lexicographic range query against the entry named
    ``self.namespace`` and decodes each returned member as UTF-8.
    """
    # NOTE(review): unlike values_in_range/delete in this class, the bounds
    # are not passed through _namespaced_key, and the "[" prefix makes the
    # upper bound inclusive - confirm both are intended.
    # NOTE(review): presumably a sorted set is maintained under
    # self.namespace for this lookup to find anything - verify.
    lex_min = f"[{start_key}"
    lex_max = f"[{end_key}"
    raw_keys = await self.redis.zrangebylex(self.namespace, lex_min, lex_max)
    if raw_keys:
        return [raw.decode("utf-8") for raw in raw_keys]
    return []

View file

@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
async def range(self, start_key: str, end_key: str) -> list[str]:
async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
_, value, _ = row
result.append(value)
return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Get all keys in the given range.

    The query uses ``>=`` and ``<=``, so both bounds are inclusive.
    """
    async with aiosqlite.connect(self.db_path) as db:
        sql = f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?"
        result_cursor = await db.execute(sql, (start_key, end_key))
        return [record[0] for record in await result_cursor.fetchall()]

View file

@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
import pytest_asyncio
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture
def mock_apis():
    """Provide one AsyncMock per API dependency, keyed by argument name."""
    api_specs = {
        "inference_api": Inference,
        "vector_io_api": VectorIO,
        "safety_api": Safety,
        "tool_runtime_api": ToolRuntime,
        "tool_groups_api": ToolGroups,
    }
    return {name: AsyncMock(spec=api) for name, api in api_specs.items()}
@pytest.fixture
def config(tmp_path):
    """Build an agents config backed by a SQLite store inside ``tmp_path``."""
    sqlite_store = {
        "type": "sqlite",
        "db_path": str(tmp_path / "test.db"),
    }
    return MetaReferenceAgentsImplConfig(persistence_store=sqlite_store)
@pytest_asyncio.fixture
async def agents_impl(config, mock_apis):
    """Yield an initialized MetaReferenceAgentsImpl and shut it down afterwards."""
    # Positional order expected by the constructor.
    api_order = (
        "inference_api",
        "vector_io_api",
        "safety_api",
        "tool_runtime_api",
        "tool_groups_api",
    )
    impl = MetaReferenceAgentsImpl(config, *(mock_apis[name] for name in api_order))
    await impl.initialize()
    yield impl
    await impl.shutdown()
@pytest.fixture
def sample_agent_config():
    """Build the AgentConfig shared by the tests, using placeholder values."""
    placeholder_properties = {"property1": None, "property2": None}
    tool_parameter = {
        "name": "string",
        "parameter_type": "string",
        "description": "string",
        "required": True,
        "default": None,
    }
    client_tool = {
        "name": "string",
        "description": "string",
        "parameters": [tool_parameter],
        "metadata": dict(placeholder_properties),
    }
    sampling = {
        "strategy": {"type": "greedy"},
        "max_tokens": 0,
        "repetition_penalty": 1.0,
    }
    tool_settings = {
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "system_message_behavior": "append",
    }
    structured_output = {
        "type": "json_schema",
        "json_schema": dict(placeholder_properties),
    }
    return AgentConfig(
        sampling_params=sampling,
        input_shields=["string"],
        output_shields=["string"],
        toolgroups=["string"],
        client_tools=[client_tool],
        tool_choice="auto",
        tool_prompt_format="json",
        tool_config=tool_settings,
        max_infer_iters=10,
        model="string",
        instructions="string",
        enable_session_persistence=False,
        response_format=structured_output,
    )
@pytest.mark.asyncio
async def test_create_agent(agents_impl, sample_agent_config):
    """Creating an agent returns an id and writes an ``AgentInfo`` record to the store."""
    created = await agents_impl.create_agent(sample_agent_config)
    assert isinstance(created, AgentCreateResponse)
    assert created.agent_id is not None

    # The record is looked up under the "agent:<id>" key.
    store_key = f"agent:{created.agent_id}"
    raw_record = await agents_impl.persistence_store.get(store_key)
    assert raw_record is not None

    persisted = AgentInfo.model_validate_json(raw_record)
    assert persisted.model == sample_agent_config.model
    assert persisted.created_at is not None
    assert isinstance(persisted.created_at, datetime)
@pytest.mark.asyncio
async def test_get_agent(agents_impl, sample_agent_config):
    """Fetching a freshly created agent returns its id, config and creation time."""
    created = await agents_impl.create_agent(sample_agent_config)
    expected_id = created.agent_id

    fetched = await agents_impl.get_agent(expected_id)

    assert isinstance(fetched, Agent)
    assert fetched.agent_id == expected_id
    assert fetched.agent_config.model == sample_agent_config.model
    assert fetched.created_at is not None
    assert isinstance(fetched.created_at, datetime)
@pytest.mark.asyncio
async def test_list_agents(agents_impl, sample_agent_config):
    """Listing after two creations returns exactly those two agents."""
    first = await agents_impl.create_agent(sample_agent_config)
    second = await agents_impl.create_agent(sample_agent_config)

    listing = await agents_impl.list_agents()

    assert isinstance(listing, PaginatedResponse)
    assert len(listing.data) == 2

    # Entries are indexed by key, so membership is checked on the collected ids.
    listed_ids = {entry["agent_id"] for entry in listing.data}
    assert first.agent_id in listed_ids
    assert second.agent_id in listed_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
    """A session can be created, fetched and deleted with persistence on or off."""
    # Work on a copy so the fixture's config object is left untouched.
    agent_config = sample_agent_config.model_copy()
    agent_config.enable_session_persistence = enable_session_persistence
    created_agent = await agents_impl.create_agent(agent_config)
    agent_id = created_agent.agent_id

    # Open a session for the new agent.
    created_session = await agents_impl.create_agent_session(agent_id, "test_session")
    assert created_session.session_id is not None
    session_id = created_session.session_id

    # The session can be read back and starts with no turns.
    fetched = await agents_impl.get_agents_session(agent_id, session_id)
    assert fetched.session_name == "test_session"
    assert fetched.session_id == created_session.session_id
    assert fetched.started_at is not None
    assert fetched.turns == []

    # After deletion, fetching the session is expected to raise ValueError.
    await agents_impl.delete_agents_session(agent_id, session_id)
    with pytest.raises(ValueError):
        await agents_impl.get_agents_session(agent_id, session_id)
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
    """Session listing reflects creations and deletions with persistence on or off."""
    # Work on a copy so the fixture's config object is left untouched.
    agent_config = sample_agent_config.model_copy()
    agent_config.enable_session_persistence = enable_session_persistence
    created_agent = await agents_impl.create_agent(agent_config)
    agent_id = created_agent.agent_id

    # Two sessions for the same agent.
    kept_later = None
    removed = await agents_impl.create_agent_session(agent_id, "session1")
    kept_later = await agents_impl.create_agent_session(agent_id, "session2")

    # Both sessions show up in the listing.
    listing = await agents_impl.list_agent_sessions(agent_id)
    assert len(listing.data) == 2
    listed_ids = {entry["session_id"] for entry in listing.data}
    assert removed.session_id in listed_ids
    assert kept_later.session_id in listed_ids

    # Drop the first session; fetching it afterwards is expected to raise ValueError.
    await agents_impl.delete_agents_session(agent_id, removed.session_id)
    with pytest.raises(ValueError):
        await agents_impl.get_agents_session(agent_id, removed.session_id)

    # Only the second session remains in the listing.
    listing = await agents_impl.list_agent_sessions(agent_id)
    assert len(listing.data) == 1
    assert kept_later.session_id in {entry["session_id"] for entry in listing.data}
@pytest.mark.asyncio
async def test_delete_agent(agents_impl, sample_agent_config):
    """Fetching an agent after deleting it is expected to raise ``ValueError``."""
    created = await agents_impl.create_agent(sample_agent_config)
    doomed_id = created.agent_id

    await agents_impl.delete_agent(doomed_id)

    # The lookup must fail once the agent is gone.
    with pytest.raises(ValueError):
        await agents_impl.get_agent(doomed_id)

View file

@ -54,6 +54,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
turns=[],
)
await agent_persistence.kvstore.set(
@ -85,6 +86,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
turns=[],
)
await agent_persistence.kvstore.set(
@ -137,6 +139,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
turns=[],
)
await agent_persistence.kvstore.set(