feat: implementation for agent/session list and describe (#1606)

Create a new agent:

```
curl --request POST \
  --url http://localhost:8321/v1/agents \
  --header 'Accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
  "agent_config": {
    "sampling_params": {
      "strategy": {
        "type": "greedy"
      },
      "max_tokens": 0,
      "repetition_penalty": 1
    },
    "input_shields": [
      "string"
    ],
    "output_shields": [
      "string"
    ],
    "toolgroups": [
      "string"
    ],
    "client_tools": [
      {
        "name": "string",
        "description": "string",
        "parameters": [
          {
            "name": "string",
            "parameter_type": "string",
            "description": "string",
            "required": true,
            "default": null
          }
        ],
        "metadata": {
          "property1": null,
          "property2": null
        }
      }
    ],
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "tool_config": {
      "tool_choice": "auto",
      "tool_prompt_format": "json",
      "system_message_behavior": "append"
    },
    "max_infer_iters": 10,
    "model": "string",
    "instructions": "string",
    "enable_session_persistence": false,
    "response_format": {
      "type": "json_schema",
      "json_schema": {
        "property1": null,
        "property2": null
      }
    }
  }
}'
```

Get agent:

```
curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f
{"agent_id":"9abad4ab-2c77-45f9-9d16-46b79d2bea1f","agent_config":{"sampling_params":{"strategy":{"type":"greedy"},"max_tokens":0,"repetition_penalty":1.0},"input_shields":["string"],"output_shields":["string"],"toolgroups":["string"],"client_tools":[{"name":"string","description":"string","parameters":[{"name":"string","parameter_type":"string","description":"string","required":true,"default":null}],"metadata":{"property1":null,"property2":null}}],"tool_choice":"auto","tool_prompt_format":"json","tool_config":{"tool_choice":"auto","tool_prompt_format":"json","system_message_behavior":"append"},"max_infer_iters":10,"model":"string","instructions":"string","enable_session_persistence":false,"response_format":{"type":"json_schema","json_schema":{"property1":null,"property2":null}}},"created_at":"2025-03-12T16:18:28.369144Z"}
```

List agents:

```
curl http://127.0.0.1:8321/v1/agents|jq
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  1680  100  1680    0     0   498k      0 --:--:-- --:--:-- --:--:--  546k
{
  "data": [
    {
      "agent_id": "9abad4ab-2c77-45f9-9d16-46b79d2bea1f",
      "agent_config": {
        "sampling_params": {
          "strategy": {
            "type": "greedy"
          },
          "max_tokens": 0,
          "repetition_penalty": 1.0
        },
        "input_shields": [
          "string"
        ],
        "output_shields": [
          "string"
        ],
        "toolgroups": [
          "string"
        ],
        "client_tools": [
          {
            "name": "string",
            "description": "string",
            "parameters": [
              {
                "name": "string",
                "parameter_type": "string",
                "description": "string",
                "required": true,
                "default": null
              }
            ],
            "metadata": {
              "property1": null,
              "property2": null
            }
          }
        ],
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "tool_config": {
          "tool_choice": "auto",
          "tool_prompt_format": "json",
          "system_message_behavior": "append"
        },
        "max_infer_iters": 10,
        "model": "string",
        "instructions": "string",
        "enable_session_persistence": false,
        "response_format": {
          "type": "json_schema",
          "json_schema": {
            "property1": null,
            "property2": null
          }
        }
      },
      "created_at": "2025-03-12T16:18:28.369144Z"
    },
    {
      "agent_id": "a6643aaa-96dd-46db-a405-333dc504b168",
      "agent_config": {
        "sampling_params": {
          "strategy": {
            "type": "greedy"
          },
          "max_tokens": 0,
          "repetition_penalty": 1.0
        },
        "input_shields": [
          "string"
        ],
        "output_shields": [
          "string"
        ],
        "toolgroups": [
          "string"
        ],
        "client_tools": [
          {
            "name": "string",
            "description": "string",
            "parameters": [
              {
                "name": "string",
                "parameter_type": "string",
                "description": "string",
                "required": true,
                "default": null
              }
            ],
            "metadata": {
              "property1": null,
              "property2": null
            }
          }
        ],
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "tool_config": {
          "tool_choice": "auto",
          "tool_prompt_format": "json",
          "system_message_behavior": "append"
        },
        "max_infer_iters": 10,
        "model": "string",
        "instructions": "string",
        "enable_session_persistence": false,
        "response_format": {
          "type": "json_schema",
          "json_schema": {
            "property1": null,
            "property2": null
          }
        }
      },
      "created_at": "2025-03-12T16:17:12.811273Z"
    }
  ]
}
```

Create a session:

```
curl --request POST \
  --url http://localhost:8321/v1/agents/{agent_id}/session \
  --header 'Accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
  "session_name": "string"
}'
```

List sessions:

```
curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f/sessions|jq
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100   263  100   263    0     0  90099      0 --:--:-- --:--:-- --:--:--  128k
[
  {
    "session_id": "2b15c4fc-e348-46c1-ae32-f6d424441ac1",
    "session_name": "string",
    "turns": [],
    "started_at": "2025-03-12T17:19:17.784328"
  },
  {
    "session_id": "9432472d-d483-4b73-b682-7b1d35d64111",
    "session_name": "string",
    "turns": [],
    "started_at": "2025-03-12T17:19:19.885834"
  }
]
```

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-05-07 14:49:23 +02:00 committed by GitHub
parent 40e71758d9
commit c91e3552a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 510 additions and 130 deletions

View file

@ -307,11 +307,11 @@
"get": { "get": {
"responses": { "responses": {
"200": { "200": {
"description": "A ListAgentsResponse.", "description": "A PaginatedResponse.",
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/ListAgentsResponse" "$ref": "#/components/schemas/PaginatedResponse"
} }
} }
} }
@ -333,7 +333,26 @@
"Agents" "Agents"
], ],
"description": "List all agents.", "description": "List all agents.",
"parameters": [] "parameters": [
{
"name": "start_index",
"in": "query",
"description": "The index to start the pagination from.",
"required": false,
"schema": {
"type": "integer"
}
},
{
"name": "limit",
"in": "query",
"description": "The number of agents to return.",
"required": false,
"schema": {
"type": "integer"
}
}
]
}, },
"post": { "post": {
"responses": { "responses": {
@ -691,7 +710,7 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"description": "Delete an agent by its ID.", "description": "Delete an agent by its ID and its associated sessions and turns.",
"parameters": [ "parameters": [
{ {
"name": "agent_id", "name": "agent_id",
@ -789,7 +808,7 @@
"tags": [ "tags": [
"Agents" "Agents"
], ],
"description": "Delete an agent session by its ID.", "description": "Delete an agent session by its ID and its associated turns.",
"parameters": [ "parameters": [
{ {
"name": "session_id", "name": "session_id",
@ -2410,11 +2429,11 @@
"get": { "get": {
"responses": { "responses": {
"200": { "200": {
"description": "A ListAgentSessionsResponse.", "description": "A PaginatedResponse.",
"content": { "content": {
"application/json": { "application/json": {
"schema": { "schema": {
"$ref": "#/components/schemas/ListAgentSessionsResponse" "$ref": "#/components/schemas/PaginatedResponse"
} }
} }
} }
@ -2445,6 +2464,24 @@
"schema": { "schema": {
"type": "string" "type": "string"
} }
},
{
"name": "start_index",
"in": "query",
"description": "The index to start the pagination from.",
"required": false,
"schema": {
"type": "integer"
}
},
{
"name": "limit",
"in": "query",
"description": "The number of sessions to return.",
"required": false,
"schema": {
"type": "integer"
}
} }
] ]
} }
@ -8996,38 +9033,6 @@
], ],
"title": "Job" "title": "Job"
}, },
"ListAgentSessionsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Session"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentSessionsResponse"
},
"ListAgentsResponse": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/Agent"
}
}
},
"additionalProperties": false,
"required": [
"data"
],
"title": "ListAgentsResponse"
},
"BucketResponse": { "BucketResponse": {
"type": "object", "type": "object",
"properties": { "properties": {

View file

@ -197,11 +197,11 @@ paths:
get: get:
responses: responses:
'200': '200':
description: A ListAgentsResponse. description: A PaginatedResponse.
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/ListAgentsResponse' $ref: '#/components/schemas/PaginatedResponse'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -215,7 +215,19 @@ paths:
tags: tags:
- Agents - Agents
description: List all agents. description: List all agents.
parameters: [] parameters:
- name: start_index
in: query
description: The index to start the pagination from.
required: false
schema:
type: integer
- name: limit
in: query
description: The number of agents to return.
required: false
schema:
type: integer
post: post:
responses: responses:
'200': '200':
@ -465,7 +477,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
description: Delete an agent by its ID. description: >-
Delete an agent by its ID and its associated sessions and turns.
parameters: parameters:
- name: agent_id - name: agent_id
in: path in: path
@ -534,7 +547,8 @@ paths:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Agents - Agents
description: Delete an agent session by its ID. description: >-
Delete an agent session by its ID and its associated turns.
parameters: parameters:
- name: session_id - name: session_id
in: path in: path
@ -1662,11 +1676,11 @@ paths:
get: get:
responses: responses:
'200': '200':
description: A ListAgentSessionsResponse. description: A PaginatedResponse.
content: content:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/ListAgentSessionsResponse' $ref: '#/components/schemas/PaginatedResponse'
'400': '400':
$ref: '#/components/responses/BadRequest400' $ref: '#/components/responses/BadRequest400'
'429': '429':
@ -1688,6 +1702,18 @@ paths:
required: true required: true
schema: schema:
type: string type: string
- name: start_index
in: query
description: The index to start the pagination from.
required: false
schema:
type: integer
- name: limit
in: query
description: The number of sessions to return.
required: false
schema:
type: integer
/v1/eval/benchmarks: /v1/eval/benchmarks:
get: get:
responses: responses:
@ -6217,28 +6243,6 @@ components:
- job_id - job_id
- status - status
title: Job title: Job
ListAgentSessionsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Session'
additionalProperties: false
required:
- data
title: ListAgentSessionsResponse
ListAgentsResponse:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/Agent'
additionalProperties: false
required:
- data
title: ListAgentsResponse
BucketResponse: BucketResponse:
type: object type: object
properties: properties:

View file

@ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
ResponseFormat, ResponseFormat,
@ -241,16 +242,6 @@ class Agent(BaseModel):
created_at: datetime created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: list[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: list[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon): class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None instructions: str | None = None
@ -526,7 +517,7 @@ class Agents(Protocol):
session_id: str, session_id: str,
agent_id: str, agent_id: str,
) -> None: ) -> None:
"""Delete an agent session by its ID. """Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete. :param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for. :param agent_id: The ID of the agent to delete the session for.
@ -538,17 +529,19 @@ class Agents(Protocol):
self, self,
agent_id: str, agent_id: str,
) -> None: ) -> None:
"""Delete an agent by its ID. """Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete. :param agent_id: The ID of the agent to delete.
""" """
... ...
@webmethod(route="/agents", method="GET") @webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse: async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents. """List all agents.
:returns: A ListAgentsResponse. :param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
""" """
... ...
@ -565,11 +558,15 @@ class Agents(Protocol):
async def list_agent_sessions( async def list_agent_sessions(
self, self,
agent_id: str, agent_id: str,
) -> ListAgentSessionsResponse: start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent. """List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for. :param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse. :param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
""" """
... ...

View file

@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def get_all(self) -> list[RoutableObjectWithProvider]: async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range() start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key) values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values) return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
return return
start_key, end_key = _get_registry_key_range() start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key) values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values) objects = _parse_registry_values(values)
async with self._locked_cache() as cache: async with self._locked_cache() as cache:

View file

@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
vector_io_api: VectorIO, vector_io_api: VectorIO,
persistence_store: KVStore, persistence_store: KVStore,
created_at: str,
): ):
self.agent_id = agent_id self.agent_id = agent_id
self.agent_config = agent_config self.agent_config = agent_config
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
self.created_at = created_at
ShieldRunnerMixin.__init__( ShieldRunnerMixin.__init__(
self, self,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import logging import logging
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
Agent, Agent,
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest, AgentTurnCreateRequest,
AgentTurnResumeRequest, AgentTurnResumeRequest,
Document, Document,
ListAgentSessionsResponse,
ListAgentsResponse,
OpenAIResponseInputMessage, OpenAIResponseInputMessage,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseObject, OpenAIResponseObject,
Session, Session,
Turn, Turn,
) )
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
Inference, Inference,
ToolConfig, ToolConfig,
@ -38,14 +37,15 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig, agent_config: AgentConfig,
) -> AgentCreateResponse: ) -> AgentCreateResponse:
agent_id = str(uuid.uuid4()) agent_id = str(uuid.uuid4())
created_at = datetime.now(timezone.utc)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set( await self.persistence_store.set(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
value=agent_config.model_dump_json(), value=agent_info.model_dump_json(),
) )
return AgentCreateResponse( return AgentCreateResponse(
agent_id=agent_id, agent_id=agent_id,
) )
async def _get_agent_impl(self, agent_id: str) -> ChatAgent: async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get( agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
) )
if not agent_config: if not agent_info_json:
raise ValueError(f"Could not find agent config for {agent_id}") raise ValueError(f"Could not find agent info for {agent_id}")
try: try:
agent_config = json.loads(agent_config) agent_info = AgentInfo.model_validate_json(agent_info_json)
except json.JSONDecodeError as e:
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e: except Exception as e:
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent( return ChatAgent(
agent_id=agent_id, agent_id=agent_id,
agent_config=agent_config, agent_config=agent_info,
inference_api=self.inference_api, inference_api=self.inference_api,
safety_api=self.safety_api, safety_api=self.safety_api,
vector_io_api=self.vector_io_api, vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api, tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api, tool_groups_api=self.tool_groups_api,
persistence_store=( persistence_store=(
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
), ),
created_at=agent_info.created_at,
) )
async def create_agent_session( async def create_agent_session(
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: list[str] | None = None, turn_ids: list[str] | None = None,
) -> Session: ) -> Session:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found")
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
) )
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}") agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None: async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}") await self.persistence_store.delete(f"agent:{agent_id}")
async def shutdown(self) -> None: async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
pass agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
async def list_agents(self) -> ListAgentsResponse: # Get the agent info using the key
pass agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent: async def get_agent(self, agent_id: str) -> Agent:
pass chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
)
return agent
async def list_agent_sessions( async def list_agent_sessions(
self, self, agent_id: str, start_index: int | None = None, limit: int | None = None
agent_id: str, ) -> PaginatedResponse:
) -> ListAgentSessionsResponse: agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass pass
# OpenAI responses # OpenAI responses

View file

@ -9,9 +9,7 @@ import logging
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from pydantic import BaseModel from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.request_headers import get_auth_attributes
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel): class AgentSessionInfo(Session):
session_id: str
session_name: str
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: str | None = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
access_attributes: AccessAttributes | None = None access_attributes: AccessAttributes | None = None
class AgentInfo(AgentConfig):
created_at: datetime
class AgentPersistence: class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore): def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id self.agent_id = agent_id
@ -46,6 +46,7 @@ class AgentPersistence:
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes, access_attributes=access_attributes,
turns=[],
) )
await self.kvstore.set( await self.kvstore.set(
@ -109,7 +110,7 @@ class AgentPersistence:
if not await self.get_session_if_accessible(session_id): if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied") raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range( values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:", start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
) )
@ -121,7 +122,6 @@ class AgentPersistence:
except Exception as e: except Exception as e:
log.error(f"Error parsing turn: {e}") log.error(f"Error parsing turn: {e}")
continue continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@ -170,3 +170,43 @@ class AgentPersistence:
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
) )
return int(value) if value else None return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

View file

@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore # Load existing datasets from kvstore
start_key = DATASETS_PREFIX start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff" end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key) stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets: for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset) dataset = Dataset.model_validate_json(dataset)

View file

@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
# Load existing benchmarks from kvstore # Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff" end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.range(start_key, end_key) stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks: for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark) benchmark = Benchmark.model_validate_json(benchmark)

View file

@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load existing banks from kvstore # Load existing banks from kvstore
start_key = VECTOR_DBS_PREFIX start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff" end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.range(start_key, end_key) stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs: for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data) vector_db = VectorDB.model_validate_json(vector_db_data)

View file

@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore # Load existing datasets from kvstore
start_key = DATASETS_PREFIX start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff" end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key) stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets: for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset) dataset = Dataset.model_validate_json(dataset)

View file

@ -16,4 +16,6 @@ class KVStore(Protocol):
async def delete(self, key: str) -> None: ... async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> list[str]: ... async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
# Return the keys (rather than the values) that fall between start_key and end_key.
# NOTE(review): the backends disagree on the upper bound -- the in-memory, MongoDB and
# Postgres stores exclude end_key, while the SQLite and Redis stores include it.
# Confirm which contract is intended.
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...

View file

@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
async def set(self, key: str, value: str) -> None: async def set(self, key: str, value: str) -> None:
self._store[key] = value self._store[key] = value
async def range(self, start_key: str, end_key: str) -> list[str]: async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key] return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Return every stored key ``k`` with ``start_key <= k < end_key``.

    Keys are returned in the store's own iteration order; nothing is sorted.
    """
    selected: list[str] = []
    for candidate in self._store.keys():
        if start_key <= candidate < end_key:
            selected.append(candidate)
    return selected
async def delete(self, key: str) -> None:
    """Remove ``key`` from the in-memory store.

    Deleting a key that is not present is a silent no-op. The database-backed
    stores behave the same way (their delete statements simply match nothing),
    so callers see one contract regardless of the configured backend.
    """
    # pop() with a default avoids the KeyError that `del` raises for a missing key.
    self._store.pop(key, None)
async def kvstore_impl(config: KVStoreConfig) -> KVStore: async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis.value: if config.type == KVStoreType.redis.value:

View file

@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
key = self._namespaced_key(key) key = self._namespaced_key(key)
await self.collection.delete_one({"key": key}) await self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> list[str]: async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)
query = { query = {
@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
async for doc in cursor: async for doc in cursor:
result.append(doc["value"]) result.append(doc["value"])
return result return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Return the stored keys ``k`` with ``start_key <= k < end_key``, ascending.

    Both bounds are namespaced before querying, and the keys are returned
    exactly as stored in the collection (i.e. in their namespaced form).
    """
    lower = self._namespaced_key(start_key)
    upper = self._namespaced_key(end_key)
    query = {"key": {"$gte": lower, "$lt": upper}}
    # Project only the "key" field; the values are not needed here.
    cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
    # The cursor is asynchronous (values_in_range drains it with `async for`),
    # so it must be consumed with an async comprehension; a plain `for` fails.
    return [doc["key"] async for doc in cursor]

View file

@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
(key,), (key,),
) )
async def range(self, start_key: str, end_key: str) -> list[str]: async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)
@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
(start_key, end_key), (start_key, end_key),
) )
return [row[0] for row in self.cursor.fetchall()] return [row[0] for row in self.cursor.fetchall()]
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Return the stored keys ``k`` with ``start_key <= k < end_key``.

    Both bounds are namespaced before querying; the keys come back exactly as
    stored in the table (i.e. in their namespaced form).
    """
    bounds = (self._namespaced_key(start_key), self._namespaced_key(end_key))
    # Only the table name is interpolated; the bounds are bound parameters.
    statement = f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s"
    self.cursor.execute(statement, bounds)
    records = self.cursor.fetchall()
    return [record[0] for record in records]

View file

@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key) key = self._namespaced_key(key)
await self.redis.delete(key) await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> list[str]: async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)
cursor = 0 cursor = 0
@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
] ]
return [] return []
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Get all keys in the given range.

    Runs ZRANGEBYLEX on ``self.namespace`` with both bounds inclusive and
    UTF-8-decodes each returned member. An empty or missing result yields [].

    NOTE(review): unlike values_in_range, the bounds are not passed through
    _namespaced_key, and the upper bound is inclusive here -- confirm both
    are intended.
    """
    members = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
    return [member.decode("utf-8") for member in members] if members else []

View file

@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit() await db.commit()
async def range(self, start_key: str, end_key: str) -> list[str]: async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
async with db.execute( async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
_, value, _ = row _, value, _ = row
result.append(value) result.append(value)
return result return result
async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
    """Get all keys in the given range.

    Opens a fresh connection to ``self.db_path`` and selects every key with
    ``start_key <= key <= end_key`` (both bounds inclusive).
    """
    async with aiosqlite.connect(self.db_path) as db:
        # Only the table name is interpolated; the bounds are bound parameters.
        statement = f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?"
        result_cursor = await db.execute(statement, (start_key, end_key))
        records = await result_cursor.fetchall()
        return [record[0] for record in records]

View file

@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from unittest.mock import AsyncMock
import pytest
import pytest_asyncio
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl
from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture
def mock_apis():
    """Provide one spec'd AsyncMock per API dependency, keyed by constructor role."""
    api_specs = {
        "inference_api": Inference,
        "vector_io_api": VectorIO,
        "safety_api": Safety,
        "tool_runtime_api": ToolRuntime,
        "tool_groups_api": ToolGroups,
    }
    return {role: AsyncMock(spec=api) for role, api in api_specs.items()}
@pytest.fixture
def config(tmp_path):
    """Build an agents config whose persistence store is a SQLite file under tmp_path."""
    sqlite_store = {
        "type": "sqlite",
        "db_path": str(tmp_path / "test.db"),
    }
    return MetaReferenceAgentsImplConfig(persistence_store=sqlite_store)
@pytest_asyncio.fixture
async def agents_impl(config, mock_apis):
    """Yield an initialized MetaReferenceAgentsImpl wired to the mocked APIs.

    The implementation is shut down once the consuming test has finished.
    """
    # Positional order expected by the MetaReferenceAgentsImpl constructor.
    api_order = (
        "inference_api",
        "vector_io_api",
        "safety_api",
        "tool_runtime_api",
        "tool_groups_api",
    )
    impl = MetaReferenceAgentsImpl(config, *(mock_apis[role] for role in api_order))
    await impl.initialize()
    yield impl
    await impl.shutdown()
@pytest.fixture
def sample_agent_config():
    """Return an AgentConfig filled with placeholder values for every field set below."""
    tool_parameter = {
        "name": "string",
        "parameter_type": "string",
        "description": "string",
        "required": True,
        "default": None,
    }
    client_tool = {
        "name": "string",
        "description": "string",
        "parameters": [tool_parameter],
        "metadata": {
            "property1": None,
            "property2": None,
        },
    }
    sampling_params = {
        "strategy": {"type": "greedy"},
        "max_tokens": 0,
        "repetition_penalty": 1.0,
    }
    tool_config = {
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "system_message_behavior": "append",
    }
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "property1": None,
            "property2": None,
        },
    }
    return AgentConfig(
        sampling_params=sampling_params,
        input_shields=["string"],
        output_shields=["string"],
        toolgroups=["string"],
        client_tools=[client_tool],
        tool_choice="auto",
        tool_prompt_format="json",
        tool_config=tool_config,
        max_infer_iters=10,
        model="string",
        instructions="string",
        enable_session_persistence=False,
        response_format=response_format,
    )
@pytest.mark.asyncio
async def test_create_agent(agents_impl, sample_agent_config):
    """Creating an agent returns an id and persists an AgentInfo record under ``agent:<id>``."""
    created = await agents_impl.create_agent(sample_agent_config)
    assert isinstance(created, AgentCreateResponse)
    assert created.agent_id is not None

    # The record must be readable straight from the persistence store.
    raw_record = await agents_impl.persistence_store.get(f"agent:{created.agent_id}")
    assert raw_record is not None

    persisted = AgentInfo.model_validate_json(raw_record)
    assert persisted.model == sample_agent_config.model
    assert persisted.created_at is not None
    assert isinstance(persisted.created_at, datetime)
@pytest.mark.asyncio
async def test_get_agent(agents_impl, sample_agent_config):
    """An agent fetched by id carries the same id, model and a creation timestamp."""
    created = await agents_impl.create_agent(sample_agent_config)
    agent_id = created.agent_id

    fetched = await agents_impl.get_agent(agent_id)
    assert isinstance(fetched, Agent)
    assert fetched.agent_id == agent_id
    assert fetched.agent_config.model == sample_agent_config.model
    assert fetched.created_at is not None
    assert isinstance(fetched.created_at, datetime)
@pytest.mark.asyncio
async def test_list_agents(agents_impl, sample_agent_config):
    """Listing agents returns a PaginatedResponse containing both created agents."""
    first = await agents_impl.create_agent(sample_agent_config)
    second = await agents_impl.create_agent(sample_agent_config)

    listing = await agents_impl.list_agents()
    assert isinstance(listing, PaginatedResponse)
    assert len(listing.data) == 2

    listed_ids = {entry["agent_id"] for entry in listing.data}
    assert first.agent_id in listed_ids
    assert second.agent_id in listed_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
    """A session can be created, read back and deleted under either persistence setting."""
    # Work on a copy so the shared fixture object is not mutated.
    agent_config = sample_agent_config.model_copy()
    agent_config.enable_session_persistence = enable_session_persistence
    created_agent = await agents_impl.create_agent(agent_config)
    agent_id = created_agent.agent_id

    created_session = await agents_impl.create_agent_session(agent_id, "test_session")
    session_id = created_session.session_id
    assert session_id is not None

    # The stored session must round-trip with its name, id and an empty turn list.
    stored = await agents_impl.get_agents_session(agent_id, created_session.session_id)
    assert stored.session_name == "test_session"
    assert stored.session_id == created_session.session_id
    assert stored.started_at is not None
    assert stored.turns == []

    await agents_impl.delete_agents_session(agent_id, created_session.session_id)

    # Once deleted, looking the session up again must fail.
    with pytest.raises(ValueError):
        await agents_impl.get_agents_session(agent_id, created_session.session_id)
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
    """Session listing reflects creations and deletions under either persistence setting."""
    # Work on a copy so the shared fixture object is not mutated.
    agent_config = sample_agent_config.model_copy()
    agent_config.enable_session_persistence = enable_session_persistence
    created_agent = await agents_impl.create_agent(agent_config)
    agent_id = created_agent.agent_id

    session1 = await agents_impl.create_agent_session(agent_id, "session1")
    session2 = await agents_impl.create_agent_session(agent_id, "session2")

    # Both sessions must show up in the listing.
    listing = await agents_impl.list_agent_sessions(agent_id)
    assert len(listing.data) == 2
    listed_ids = {entry["session_id"] for entry in listing.data}
    assert session1.session_id in listed_ids
    assert session2.session_id in listed_ids

    await agents_impl.delete_agents_session(agent_id, session1.session_id)

    # The deleted session can no longer be fetched...
    with pytest.raises(ValueError):
        await agents_impl.get_agents_session(agent_id, session1.session_id)

    # ...and only the surviving session is listed.
    listing = await agents_impl.list_agent_sessions(agent_id)
    assert len(listing.data) == 1
    assert session2.session_id in {entry["session_id"] for entry in listing.data}
@pytest.mark.asyncio
async def test_delete_agent(agents_impl, sample_agent_config):
    """After an agent is deleted, fetching it by id raises ValueError."""
    created = await agents_impl.create_agent(sample_agent_config)
    agent_id = created.agent_id

    await agents_impl.delete_agent(agent_id)

    with pytest.raises(ValueError):
        await agents_impl.get_agent(agent_id)

View file

@ -54,6 +54,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
turns=[],
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -85,6 +86,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), access_attributes=AccessAttributes(roles=["admin"]),
turns=[],
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(
@ -137,6 +139,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
session_name="Restricted Session", session_name="Restricted Session",
started_at=datetime.now(), started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]), access_attributes=AccessAttributes(roles=["admin"]),
turns=[],
) )
await agent_persistence.kvstore.set( await agent_persistence.kvstore.set(