diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 84d4cf646..163d05087 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -307,11 +307,11 @@
"get": {
"responses": {
"200": {
- "description": "A ListAgentsResponse.",
+ "description": "A PaginatedResponse.",
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/ListAgentsResponse"
+ "$ref": "#/components/schemas/PaginatedResponse"
}
}
}
@@ -333,7 +333,26 @@
"Agents"
],
"description": "List all agents.",
- "parameters": []
+ "parameters": [
+ {
+ "name": "start_index",
+ "in": "query",
+ "description": "The index to start the pagination from.",
+ "required": false,
+ "schema": {
+ "type": "integer"
+ }
+ },
+ {
+ "name": "limit",
+ "in": "query",
+ "description": "The number of agents to return.",
+ "required": false,
+ "schema": {
+ "type": "integer"
+ }
+ }
+ ]
},
"post": {
"responses": {
@@ -691,7 +710,7 @@
"tags": [
"Agents"
],
- "description": "Delete an agent by its ID.",
+ "description": "Delete an agent by its ID and its associated sessions and turns.",
"parameters": [
{
"name": "agent_id",
@@ -789,7 +808,7 @@
"tags": [
"Agents"
],
- "description": "Delete an agent session by its ID.",
+ "description": "Delete an agent session by its ID and its associated turns.",
"parameters": [
{
"name": "session_id",
@@ -2410,11 +2429,11 @@
"get": {
"responses": {
"200": {
- "description": "A ListAgentSessionsResponse.",
+ "description": "A PaginatedResponse.",
"content": {
"application/json": {
"schema": {
- "$ref": "#/components/schemas/ListAgentSessionsResponse"
+ "$ref": "#/components/schemas/PaginatedResponse"
}
}
}
@@ -2445,6 +2464,24 @@
"schema": {
"type": "string"
}
+ },
+ {
+ "name": "start_index",
+ "in": "query",
+ "description": "The index to start the pagination from.",
+ "required": false,
+ "schema": {
+ "type": "integer"
+ }
+ },
+ {
+ "name": "limit",
+ "in": "query",
+ "description": "The number of sessions to return.",
+ "required": false,
+ "schema": {
+ "type": "integer"
+ }
}
]
}
@@ -8996,38 +9033,6 @@
],
"title": "Job"
},
- "ListAgentSessionsResponse": {
- "type": "object",
- "properties": {
- "data": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/Session"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "data"
- ],
- "title": "ListAgentSessionsResponse"
- },
- "ListAgentsResponse": {
- "type": "object",
- "properties": {
- "data": {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/Agent"
- }
- }
- },
- "additionalProperties": false,
- "required": [
- "data"
- ],
- "title": "ListAgentsResponse"
- },
"BucketResponse": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 259cb3007..67c1d3dbd 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -197,11 +197,11 @@ paths:
get:
responses:
'200':
- description: A ListAgentsResponse.
+ description: A PaginatedResponse.
content:
application/json:
schema:
- $ref: '#/components/schemas/ListAgentsResponse'
+ $ref: '#/components/schemas/PaginatedResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -215,7 +215,19 @@ paths:
tags:
- Agents
description: List all agents.
- parameters: []
+ parameters:
+ - name: start_index
+ in: query
+ description: The index to start the pagination from.
+ required: false
+ schema:
+ type: integer
+ - name: limit
+ in: query
+ description: The number of agents to return.
+ required: false
+ schema:
+ type: integer
post:
responses:
'200':
@@ -465,7 +477,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
- description: Delete an agent by its ID.
+ description: >-
+ Delete an agent by its ID and its associated sessions and turns.
parameters:
- name: agent_id
in: path
@@ -534,7 +547,8 @@ paths:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
- description: Delete an agent session by its ID.
+ description: >-
+ Delete an agent session by its ID and its associated turns.
parameters:
- name: session_id
in: path
@@ -1662,11 +1676,11 @@ paths:
get:
responses:
'200':
- description: A ListAgentSessionsResponse.
+ description: A PaginatedResponse.
content:
application/json:
schema:
- $ref: '#/components/schemas/ListAgentSessionsResponse'
+ $ref: '#/components/schemas/PaginatedResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
@@ -1688,6 +1702,18 @@ paths:
required: true
schema:
type: string
+ - name: start_index
+ in: query
+ description: The index to start the pagination from.
+ required: false
+ schema:
+ type: integer
+ - name: limit
+ in: query
+ description: The number of sessions to return.
+ required: false
+ schema:
+ type: integer
/v1/eval/benchmarks:
get:
responses:
@@ -6217,28 +6243,6 @@ components:
- job_id
- status
title: Job
- ListAgentSessionsResponse:
- type: object
- properties:
- data:
- type: array
- items:
- $ref: '#/components/schemas/Session'
- additionalProperties: false
- required:
- - data
- title: ListAgentSessionsResponse
- ListAgentsResponse:
- type: object
- properties:
- data:
- type: array
- items:
- $ref: '#/components/schemas/Agent'
- additionalProperties: false
- required:
- - data
- title: ListAgentsResponse
BucketResponse:
type: object
properties:
diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 84e3a9057..91e57dbbe 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
+from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
@@ -241,16 +242,6 @@ class Agent(BaseModel):
created_at: datetime
-@json_schema_type
-class ListAgentsResponse(BaseModel):
- data: list[Agent]
-
-
-@json_schema_type
-class ListAgentSessionsResponse(BaseModel):
- data: list[Session]
-
-
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
@@ -526,7 +517,7 @@ class Agents(Protocol):
session_id: str,
agent_id: str,
) -> None:
- """Delete an agent session by its ID.
+ """Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
@@ -538,17 +529,19 @@ class Agents(Protocol):
self,
agent_id: str,
) -> None:
- """Delete an agent by its ID.
+ """Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET")
- async def list_agents(self) -> ListAgentsResponse:
+ async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
- :returns: A ListAgentsResponse.
+ :param start_index: The index to start the pagination from.
+ :param limit: The number of agents to return.
+ :returns: A PaginatedResponse.
"""
...
@@ -565,11 +558,15 @@ class Agents(Protocol):
async def list_agent_sessions(
self,
agent_id: str,
- ) -> ListAgentSessionsResponse:
+ start_index: int | None = None,
+ limit: int | None = None,
+ ) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
- :returns: A ListAgentSessionsResponse.
+ :param start_index: The index to start the pagination from.
+ :param limit: The number of sessions to return.
+ :returns: A PaginatedResponse.
"""
...
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index ae97f600c..a6b400136 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range()
- values = await self.kvstore.range(start_key, end_key)
+ values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
@@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
return
start_key, end_key = _get_registry_key_range()
- values = await self.kvstore.range(start_key, end_key)
+ values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values)
async with self._locked_cache() as cache:
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 3807ddb86..2e387e7e8 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_groups_api: ToolGroups,
vector_io_api: VectorIO,
persistence_store: KVStore,
+ created_at: str,
):
self.agent_id = agent_id
self.agent_config = agent_config
@@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api
+ self.created_at = created_at
ShieldRunnerMixin.__init__(
self,
diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index a87739925..19d60c816 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import json
import logging
import uuid
from collections.abc import AsyncGenerator
+from datetime import datetime, timezone
from llama_stack.apis.agents import (
Agent,
@@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
- ListAgentSessionsResponse,
- ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInputTool,
OpenAIResponseObject,
Session,
Turn,
)
+from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
@@ -38,14 +37,15 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
+from .persistence import AgentInfo
logger = logging.getLogger()
-logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents):
@@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
+ created_at = datetime.now(timezone.utc)
+ agent_info = AgentInfo(
+ **agent_config.model_dump(),
+ created_at=created_at,
+ )
+
+ # Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
- value=agent_config.model_dump_json(),
+ value=agent_info.model_dump_json(),
)
+
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
- agent_config = await self.persistence_store.get(
+ agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
- if not agent_config:
- raise ValueError(f"Could not find agent config for {agent_id}")
+ if not agent_info_json:
+ raise ValueError(f"Could not find agent info for {agent_id}")
try:
- agent_config = json.loads(agent_config)
- except json.JSONDecodeError as e:
- raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
-
- try:
- agent_config = AgentConfig(**agent_config)
+ agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
- raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
+ raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
- agent_config=agent_config,
+ agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
- self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
+ self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
+ created_at=agent_info.created_at,
)
async def create_agent_session(
@@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
+
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
@@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
- await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
+ agent = await self._get_agent_impl(agent_id)
+ session_info = await agent.storage.get_session_info(session_id)
+ if session_info is None:
+ raise ValueError(f"Session {session_id} not found")
+
+ # Delete turns first, then the session
+ await agent.storage.delete_session_turns(session_id)
+ await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
+ # First get all sessions for this agent
+ agent = await self._get_agent_impl(agent_id)
+ sessions = await agent.storage.list_sessions()
+
+ # Delete all sessions
+ for session in sessions:
+ await self.delete_agents_session(agent_id, session.session_id)
+
+ # Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}")
- async def shutdown(self) -> None:
- pass
+ async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
+ agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
+ agent_list: list[Agent] = []
+ for agent_key in agent_keys:
+ agent_id = agent_key.split(":")[1]
- async def list_agents(self) -> ListAgentsResponse:
- pass
+ # Get the agent info using the key
+ agent_info_json = await self.persistence_store.get(agent_key)
+ if not agent_info_json:
+ logger.error(f"Could not find agent info for key {agent_key}")
+ continue
+
+ try:
+ agent_info = AgentInfo.model_validate_json(agent_info_json)
+ agent_list.append(
+ Agent(
+ agent_id=agent_id,
+ agent_config=agent_info,
+ created_at=agent_info.created_at,
+ )
+ )
+ except Exception as e:
+ logger.error(f"Error parsing agent info for {agent_id}: {e}")
+ continue
+
+ # Convert Agent objects to dictionaries
+ agent_dicts = [agent.model_dump() for agent in agent_list]
+ return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent:
- pass
+ chat_agent = await self._get_agent_impl(agent_id)
+ agent = Agent(
+ agent_id=agent_id,
+ agent_config=chat_agent.agent_config,
+ created_at=chat_agent.created_at,
+ )
+ return agent
async def list_agent_sessions(
- self,
- agent_id: str,
- ) -> ListAgentSessionsResponse:
+ self, agent_id: str, start_index: int | None = None, limit: int | None = None
+ ) -> PaginatedResponse:
+ agent = await self._get_agent_impl(agent_id)
+ sessions = await agent.storage.list_sessions()
+ # Convert Session objects to dictionaries
+ session_dicts = [session.model_dump() for session in sessions]
+ return paginate_records(session_dicts, start_index, limit)
+
+ async def shutdown(self) -> None:
pass
# OpenAI responses
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index 60ce91395..5031a4a90 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -9,9 +9,7 @@ import logging
import uuid
from datetime import datetime, timezone
-from pydantic import BaseModel
-
-from llama_stack.apis.agents import ToolExecutionStep, Turn
+from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes
@@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)
-class AgentSessionInfo(BaseModel):
- session_id: str
- session_name: str
+class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_db_id: str | None = None
started_at: datetime
access_attributes: AccessAttributes | None = None
+class AgentInfo(AgentConfig):
+    """Agent configuration extended with the time the agent was created."""
+
+    # Creation timestamp persisted together with the configuration fields.
+    created_at: datetime
+
+
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id
@@ -46,6 +46,7 @@ class AgentPersistence:
session_name=name,
started_at=datetime.now(timezone.utc),
access_attributes=access_attributes,
+ turns=[],
)
await self.kvstore.set(
@@ -109,7 +110,7 @@ class AgentPersistence:
if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied")
- values = await self.kvstore.range(
+ values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
@@ -121,7 +122,6 @@ class AgentPersistence:
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
- turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@@ -170,3 +170,43 @@ class AgentPersistence:
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
+
+    async def list_sessions(self) -> list[Session]:
+        """Return every session stored for this agent.
+
+        The store is scanned over the ``session:<agent_id>:`` prefix. Turn
+        records are stored under ``session:<agent_id>:<session_id>:<turn_id>``
+        and therefore fall inside the same range; they are skipped instead of
+        being reported as malformed sessions.
+
+        Returns:
+            The sessions that could be parsed. Entries that fail to parse are
+            logged and omitted.
+        """
+        values = await self.kvstore.values_in_range(
+            start_key=f"session:{self.agent_id}:",
+            end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
+        )
+        sessions: list[Session] = []
+        for value in values:
+            try:
+                record = json.loads(value)
+                # Assumes a serialized turn carries a top-level "turn_id" field
+                # (session records do not) - TODO confirm against the Turn model.
+                if isinstance(record, dict) and "turn_id" in record:
+                    continue
+                sessions.append(Session(**record))
+            except Exception as e:
+                log.error(f"Error parsing session info: {e}")
+                continue
+        return sessions
+
+    async def delete_session_turns(self, session_id: str) -> None:
+        """Delete the stored turn records of a session.
+
+        Only the ``session:<agent_id>:<session_id>:<turn_id>`` entries are
+        removed, one delete per turn returned by ``get_session_turns``.
+
+        NOTE(review): the per-turn ``num_infer_iters_in_turn:...`` entries
+        written elsewhere in this class are not removed here - confirm whether
+        they should be.
+
+        Args:
+            session_id: The ID of the session whose turns should be deleted.
+
+        Raises:
+            ValueError: Propagated from ``get_session_turns`` when the session
+                is not found or access is denied.
+        """
+        turns = await self.get_session_turns(session_id)
+        for turn in turns:
+            await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
+
+    async def delete_session(self, session_id: str) -> None:
+        """Delete the session record itself.
+
+        Only the ``session:<agent_id>:<session_id>`` entry is removed. Turn
+        records are not touched by this method; use ``delete_session_turns``
+        for those.
+
+        Args:
+            session_id: The ID of the session to delete.
+
+        Raises:
+            ValueError: If the session does not exist.
+        """
+        # Look the session up first so a missing session is reported, not silently ignored.
+        session_info = await self.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
index 260a640bd..5640df6ab 100644
--- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py
+++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py
@@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
- stored_datasets = await self.kvstore.range(start_key, end_key)
+ stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py
index 12dde66b5..bc0898dc5 100644
--- a/llama_stack/providers/inline/eval/meta_reference/eval.py
+++ b/llama_stack/providers/inline/eval/meta_reference/eval.py
@@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
# Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff"
- stored_benchmarks = await self.kvstore.range(start_key, end_key)
+ stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark)
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 5d5b1da35..d3dc7e694 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
- stored_vector_dbs = await self.kvstore.range(start_key, end_key)
+ stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index baecf45e9..e329c88a7 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore
start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff"
- stored_datasets = await self.kvstore.range(start_key, end_key)
+ stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py
index 8efde4ea9..d17dc66e1 100644
--- a/llama_stack/providers/utils/kvstore/api.py
+++ b/llama_stack/providers/utils/kvstore/api.py
@@ -16,4 +16,6 @@ class KVStore(Protocol):
async def delete(self, key: str) -> None: ...
- async def range(self, start_key: str, end_key: str) -> list[str]: ...
+ async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
+
+    # Return the keys that fall between start_key and end_key.
+    # NOTE(review): implementations differ on whether end_key is inclusive
+    # (in-memory uses `<`, SQLite uses `<=`) - confirm the intended contract.
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py
index 0eb969b65..3a1ee8a26 100644
--- a/llama_stack/providers/utils/kvstore/kvstore.py
+++ b/llama_stack/providers/utils/kvstore/kvstore.py
@@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
async def set(self, key: str, value: str) -> None:
self._store[key] = value
- async def range(self, start_key: str, end_key: str) -> list[str]:
+ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Return each stored key ``k`` with ``start_key <= k < end_key``, in the store's iteration order."""
+        return [key for key in self._store if start_key <= key < end_key]
+
+    async def delete(self, key: str) -> None:
+        """Remove ``key`` from the in-memory store.
+
+        Deleting a key that is not present is a no-op rather than a
+        ``KeyError``. This matches the SQL ``DELETE`` and MongoDB
+        ``delete_one`` based backends, which do not fail on a missing key.
+        """
+        self._store.pop(key, None)
+
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value:
diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
index 330a079bf..3842773d9 100644
--- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
+++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
@@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.collection.delete_one({"key": key})
- async def range(self, start_key: str, end_key: str) -> list[str]:
+ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
@@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
async for doc in cursor:
result.append(doc["value"])
return result
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Return the keys ``k`` with ``start_key <= k < end_key``, sorted ascending.
+
+        Both bounds are namespaced before querying, and the keys are returned
+        exactly as stored in the ``key`` field of the collection.
+        """
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        query = {"key": {"$gte": start_key, "$lt": end_key}}
+        # Project only the key field; the values are not needed here.
+        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
+        # The cursor is asynchronous (values_in_range consumes it with
+        # ``async for``), so it must not be iterated with a plain ``for``.
+        return [doc["key"] async for doc in cursor]
diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py
index 6bfbc4f81..bd35decfc 100644
--- a/llama_stack/providers/utils/kvstore/postgres/postgres.py
+++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py
@@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
(key,),
)
- async def range(self, start_key: str, end_key: str) -> list[str]:
+ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
@@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Return the ``key`` column of every row with ``start_key <= key < end_key``.
+
+        Both bounds are namespaced before the query is issued.
+        """
+        lower = self._namespaced_key(start_key)
+        upper = self._namespaced_key(end_key)
+        query = f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s"
+
+        self.cursor.execute(query, (lower, upper))
+        rows = self.cursor.fetchall()
+        return [row[0] for row in rows]
diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py
index d95de0291..3d2d956c3 100644
--- a/llama_stack/providers/utils/kvstore/redis/redis.py
+++ b/llama_stack/providers/utils/kvstore/redis/redis.py
@@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key)
await self.redis.delete(key)
- async def range(self, start_key: str, end_key: str) -> list[str]:
+ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
cursor = 0
@@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
]
return []
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Return the keys ``k`` with ``start_key <= k < end_key``, sorted ascending.
+
+        Both bounds are namespaced first, as in the other methods of this
+        store, and the keys are returned in their stored (namespaced) form.
+        The keyspace is walked with SCAN instead of ZRANGEBYLEX, because
+        entries are plain keys rather than members of a sorted set.
+        NOTE(review): confirm no sorted-set index is maintained elsewhere.
+        """
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+
+        # Any key inside the range must begin with the prefix shared by both
+        # bounds, so that prefix safely narrows the SCAN.
+        shared = 0
+        for left, right in zip(start_key, end_key):
+            if left != right:
+                break
+            shared += 1
+        # Escape glob metacharacters so the prefix is matched literally.
+        pattern = "".join(f"\\{ch}" if ch in "*?[]\\" else ch for ch in start_key[:shared]) + "*"
+
+        # SCAN may return a key more than once, hence the set.
+        matching: set[str] = set()
+        cursor = 0
+        while True:
+            cursor, batch = await self.redis.scan(cursor, match=pattern, count=1000)
+            for raw_key in batch:
+                key = raw_key.decode("utf-8") if isinstance(raw_key, bytes) else raw_key
+                if start_key <= key < end_key:
+                    matching.add(key)
+            if cursor == 0:
+                break
+        return sorted(matching)
diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py
index faca77887..4e49e4d8c 100644
--- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py
+++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py
@@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit()
- async def range(self, start_key: str, end_key: str) -> list[str]:
+ async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
_, value, _ = row
result.append(value)
return result
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Return every key ``k`` with ``start_key <= k <= end_key`` (both bounds inclusive)."""
+        query = f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?"
+        bounds = (start_key, end_key)
+        async with aiosqlite.connect(self.db_path) as db:
+            cursor = await db.execute(query, bounds)
+            rows = await cursor.fetchall()
+            return [row[0] for row in rows]
diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py
new file mode 100644
index 000000000..bef24e123
--- /dev/null
+++ b/tests/unit/providers/agent/test_meta_reference_agent.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from datetime import datetime
+from unittest.mock import AsyncMock
+
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.agents import (
+ Agent,
+ AgentConfig,
+ AgentCreateResponse,
+)
+from llama_stack.apis.common.responses import PaginatedResponse
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.safety import Safety
+from llama_stack.apis.tools import ToolGroups, ToolRuntime
+from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
+from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
+from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
+
+
+@pytest.fixture
+def mock_apis():
+    """Provide ``AsyncMock`` stand-ins, spec'd to the API classes, for the agent implementation's dependencies."""
+    return {
+        "inference_api": AsyncMock(spec=Inference),
+        "vector_io_api": AsyncMock(spec=VectorIO),
+        "safety_api": AsyncMock(spec=Safety),
+        "tool_runtime_api": AsyncMock(spec=ToolRuntime),
+        "tool_groups_api": AsyncMock(spec=ToolGroups),
+    }
+
+
+@pytest.fixture
+def config(tmp_path):
+    """Build an agents config whose persistence store is a SQLite file under pytest's ``tmp_path``."""
+    return MetaReferenceAgentsImplConfig(
+        persistence_store={
+            "type": "sqlite",
+            "db_path": str(tmp_path / "test.db"),
+        },
+    )
+
+
+@pytest_asyncio.fixture
+async def agents_impl(config, mock_apis):
+    """Yield an initialized ``MetaReferenceAgentsImpl`` wired to the mocked APIs and shut it down afterwards."""
+    impl = MetaReferenceAgentsImpl(
+        config,
+        mock_apis["inference_api"],
+        mock_apis["vector_io_api"],
+        mock_apis["safety_api"],
+        mock_apis["tool_runtime_api"],
+        mock_apis["tool_groups_api"],
+    )
+    await impl.initialize()
+    yield impl
+    # Teardown runs after the test that requested the fixture has finished.
+    await impl.shutdown()
+
+
+@pytest.fixture
+def sample_agent_config():
+    """Return an ``AgentConfig`` with every field populated by placeholder values."""
+    return AgentConfig(
+        sampling_params={
+            "strategy": {"type": "greedy"},
+            "max_tokens": 0,
+            "repetition_penalty": 1.0,
+        },
+        input_shields=["string"],
+        output_shields=["string"],
+        toolgroups=["string"],
+        client_tools=[
+            {
+                "name": "string",
+                "description": "string",
+                "parameters": [
+                    {
+                        "name": "string",
+                        "parameter_type": "string",
+                        "description": "string",
+                        "required": True,
+                        "default": None,
+                    }
+                ],
+                "metadata": {
+                    "property1": None,
+                    "property2": None,
+                },
+            }
+        ],
+        tool_choice="auto",
+        tool_prompt_format="json",
+        tool_config={
+            "tool_choice": "auto",
+            "tool_prompt_format": "json",
+            "system_message_behavior": "append",
+        },
+        max_infer_iters=10,
+        model="string",
+        instructions="string",
+        # Tests that need persisted sessions override this on a copy of the config.
+        enable_session_persistence=False,
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "property1": None,
+                "property2": None,
+            },
+        },
+    )
+
+
+@pytest.mark.asyncio
+async def test_create_agent(agents_impl, sample_agent_config):
+    """Creating an agent returns an id and persists an ``AgentInfo`` record under ``agent:<id>``."""
+    created = await agents_impl.create_agent(sample_agent_config)
+    assert isinstance(created, AgentCreateResponse)
+    assert created.agent_id is not None
+
+    # The raw record must exist in the store and parse back into an AgentInfo.
+    raw_record = await agents_impl.persistence_store.get(f"agent:{created.agent_id}")
+    assert raw_record is not None
+
+    persisted = AgentInfo.model_validate_json(raw_record)
+    assert persisted.model == sample_agent_config.model
+    assert persisted.created_at is not None
+    assert isinstance(persisted.created_at, datetime)
+
+
+@pytest.mark.asyncio
+async def test_get_agent(agents_impl, sample_agent_config):
+    """Fetching a freshly created agent returns its id, config and creation time."""
+    created = await agents_impl.create_agent(sample_agent_config)
+    expected_id = created.agent_id
+
+    fetched = await agents_impl.get_agent(expected_id)
+
+    assert isinstance(fetched, Agent)
+    assert fetched.agent_id == expected_id
+    assert fetched.agent_config.model == sample_agent_config.model
+    assert fetched.created_at is not None
+    assert isinstance(fetched.created_at, datetime)
+
+
+@pytest.mark.asyncio
+async def test_list_agents(agents_impl, sample_agent_config):
+    """Listing agents returns a PaginatedResponse containing both created agents."""
+    first = await agents_impl.create_agent(sample_agent_config)
+    second = await agents_impl.create_agent(sample_agent_config)
+    page = await agents_impl.list_agents()
+
+    assert isinstance(page, PaginatedResponse)
+    assert len(page.data) == 2
+    listed_ids = {entry["agent_id"] for entry in page.data}
+    assert first.agent_id in listed_ids
+    assert second.agent_id in listed_ids
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("enable_session_persistence", [True, False])
+async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
+    """Exercise session create, fetch and delete with persistence both on and off."""
+    # Work on a copy so the parametrized flag is not written onto the fixture's object.
+    agent_config = sample_agent_config.model_copy()
+    agent_config.enable_session_persistence = enable_session_persistence
+    agent_id = (await agents_impl.create_agent(agent_config)).agent_id
+
+    # Open a session under the new agent.
+    created = await agents_impl.create_agent_session(agent_id, "test_session")
+    session_id = created.session_id
+    assert session_id is not None
+
+    # A freshly created session is retrievable and has no turns yet.
+    stored = await agents_impl.get_agents_session(agent_id, session_id)
+    assert stored.session_name == "test_session"
+    assert stored.session_id == session_id
+    assert stored.started_at is not None
+    assert stored.turns == []
+
+    # After deletion, looking the session up again is expected to raise ValueError.
+    await agents_impl.delete_agents_session(agent_id, session_id)
+
+    with pytest.raises(ValueError):
+        await agents_impl.get_agents_session(agent_id, session_id)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("enable_session_persistence", [True, False])
+async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
+    """Listing sessions reflects creations and deletions, with persistence on and off."""
+    # Work on a copy so the parametrized flag is not written onto the fixture's object.
+    agent_config = sample_agent_config.model_copy()
+    agent_config.enable_session_persistence = enable_session_persistence
+    agent_id = (await agents_impl.create_agent(agent_config)).agent_id
+
+    # Create two sessions; the first is deleted later, the second must survive.
+    doomed = await agents_impl.create_agent_session(agent_id, "session1")
+    survivor = await agents_impl.create_agent_session(agent_id, "session2")
+
+    # Both sessions show up in the listing.
+    listing = await agents_impl.list_agent_sessions(agent_id)
+    assert len(listing.data) == 2
+    listed_ids = {entry["session_id"] for entry in listing.data}
+    assert doomed.session_id in listed_ids
+    assert survivor.session_id in listed_ids
+
+    # Delete the first session and confirm it can no longer be fetched.
+    await agents_impl.delete_agents_session(agent_id, doomed.session_id)
+
+    with pytest.raises(ValueError):
+        await agents_impl.get_agents_session(agent_id, doomed.session_id)
+
+    # Only the second session remains in the listing.
+    listing = await agents_impl.list_agent_sessions(agent_id)
+    assert len(listing.data) == 1
+    remaining_ids = {entry["session_id"] for entry in listing.data}
+    assert survivor.session_id in remaining_ids
+
+
+@pytest.mark.asyncio
+async def test_delete_agent(agents_impl, sample_agent_config):
+    """Deleting an agent makes a subsequent get_agent call raise ValueError."""
+    # Set up: register an agent whose only purpose is to be removed.
+    agent_id = (await agents_impl.create_agent(sample_agent_config)).agent_id
+
+    # Act.
+    await agents_impl.delete_agent(agent_id)
+
+    # Check: the ID no longer resolves to an agent.
+    with pytest.raises(ValueError):
+        await agents_impl.get_agent(agent_id)
diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py
index 06e1a778a..48fa647a8 100644
--- a/tests/unit/providers/agents/test_persistence_access_control.py
+++ b/tests/unit/providers/agents/test_persistence_access_control.py
@@ -54,6 +54,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
+ turns=[],
)
await agent_persistence.kvstore.set(
@@ -85,6 +86,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
+ turns=[],
)
await agent_persistence.kvstore.set(
@@ -137,6 +139,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
+ turns=[],
)
await agent_persistence.kvstore.set(