Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
feat: implementation for agent/session list and describe (#1606)
Create a new agent:

```
curl --request POST \
  --url http://localhost:8321/v1/agents \
  --header 'Accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
    "agent_config": {
      "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1 },
      "input_shields": [ "string" ],
      "output_shields": [ "string" ],
      "toolgroups": [ "string" ],
      "client_tools": [
        {
          "name": "string",
          "description": "string",
          "parameters": [
            { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null }
          ],
          "metadata": { "property1": null, "property2": null }
        }
      ],
      "tool_choice": "auto",
      "tool_prompt_format": "json",
      "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" },
      "max_infer_iters": 10,
      "model": "string",
      "instructions": "string",
      "enable_session_persistence": false,
      "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } }
    }
  }'
```

Get agent:

```
curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f
{"agent_id":"9abad4ab-2c77-45f9-9d16-46b79d2bea1f","agent_config":{"sampling_params":{"strategy":{"type":"greedy"},"max_tokens":0,"repetition_penalty":1.0},"input_shields":["string"],"output_shields":["string"],"toolgroups":["string"],"client_tools":[{"name":"string","description":"string","parameters":[{"name":"string","parameter_type":"string","description":"string","required":true,"default":null}],"metadata":{"property1":null,"property2":null}}],"tool_choice":"auto","tool_prompt_format":"json","tool_config":{"tool_choice":"auto","tool_prompt_format":"json","system_message_behavior":"append"},"max_infer_iters":10,"model":"string","instructions":"string","enable_session_persistence":false,"response_format":{"type":"json_schema","json_schema":{"property1":null,"property2":null}}},"created_at":"2025-03-12T16:18:28.369144Z"}
```

List agents:

```
curl http://127.0.0.1:8321/v1/agents | jq
{
  "data": [
    {
      "agent_id": "9abad4ab-2c77-45f9-9d16-46b79d2bea1f",
      "agent_config": {
        "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 },
        "input_shields": [ "string" ],
        "output_shields": [ "string" ],
        "toolgroups": [ "string" ],
        "client_tools": [
          {
            "name": "string",
            "description": "string",
            "parameters": [
              { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null }
            ],
            "metadata": { "property1": null, "property2": null }
          }
        ],
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" },
        "max_infer_iters": 10,
        "model": "string",
        "instructions": "string",
        "enable_session_persistence": false,
        "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } }
      },
      "created_at": "2025-03-12T16:18:28.369144Z"
    },
    {
      "agent_id": "a6643aaa-96dd-46db-a405-333dc504b168",
      "agent_config": {
        "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 },
        "input_shields": [ "string" ],
        "output_shields": [ "string" ],
        "toolgroups": [ "string" ],
        "client_tools": [
          {
            "name": "string",
            "description": "string",
            "parameters": [
              { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null }
            ],
            "metadata": { "property1": null, "property2": null }
          }
        ],
        "tool_choice": "auto",
        "tool_prompt_format": "json",
        "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" },
        "max_infer_iters": 10,
        "model": "string",
        "instructions": "string",
        "enable_session_persistence": false,
        "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } }
      },
      "created_at": "2025-03-12T16:17:12.811273Z"
    }
  ]
}
```

Create sessions:

```
curl --request POST \
  --url http://localhost:8321/v1/agents/{agent_id}/session \
  --header 'Accept: application/json' \
  --header 'Content-Type: application/json' \
  --data '{
    "session_name": "string"
  }'
```

List sessions:

```
curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f/sessions | jq
[
  {
    "session_id": "2b15c4fc-e348-46c1-ae32-f6d424441ac1",
    "session_name": "string",
    "turns": [],
    "started_at": "2025-03-12T17:19:17.784328"
  },
  {
    "session_id": "9432472d-d483-4b73-b682-7b1d35d64111",
    "session_name": "string",
    "turns": [],
    "started_at": "2025-03-12T17:19:19.885834"
  }
]
```

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent 40e71758d9
commit c91e3552a3

19 changed files with 510 additions and 130 deletions
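The two list endpoints in this change return a `PaginatedResponse` instead of the old dedicated list models, so clients can walk agents and sessions page by page. A minimal client-side sketch of that, assuming the `start_index` and `limit` parameters from the Python signatures are exposed as query parameters and that the paginated payload carries its records under a `data` key (neither is spelled out in this commit):

```
# Hypothetical pagination walk over the new list endpoints.
# Assumptions: start_index/limit map directly to query parameters, and the
# PaginatedResponse body exposes the records as a "data" list.
import requests

BASE_URL = "http://localhost:8321/v1"


def list_agents(start_index: int | None = None, limit: int | None = None) -> list[dict]:
    """Fetch one page of agents from GET /v1/agents."""
    params = {k: v for k, v in {"start_index": start_index, "limit": limit}.items() if v is not None}
    resp = requests.get(f"{BASE_URL}/agents", params=params, timeout=30)
    resp.raise_for_status()
    return resp.json().get("data", [])


def list_agent_sessions(agent_id: str, start_index: int | None = None, limit: int | None = None) -> list[dict]:
    """Fetch one page of sessions from GET /v1/agents/{agent_id}/sessions."""
    params = {k: v for k, v in {"start_index": start_index, "limit": limit}.items() if v is not None}
    resp = requests.get(f"{BASE_URL}/agents/{agent_id}/sessions", params=params, timeout=30)
    resp.raise_for_status()
    return resp.json().get("data", [])


if __name__ == "__main__":
    for agent in list_agents(start_index=0, limit=10):
        print(agent["agent_id"], agent["created_at"])
```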
@@ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,

@@ -241,16 +242,6 @@ class Agent(BaseModel):
     created_at: datetime


-@json_schema_type
-class ListAgentsResponse(BaseModel):
-    data: list[Agent]
-
-
-@json_schema_type
-class ListAgentSessionsResponse(BaseModel):
-    data: list[Session]
-
-
 class AgentConfigOverridablePerTurn(AgentConfigCommon):
     instructions: str | None = None

@@ -526,7 +517,7 @@ class Agents(Protocol):
         session_id: str,
         agent_id: str,
     ) -> None:
-        """Delete an agent session by its ID.
+        """Delete an agent session by its ID and its associated turns.

         :param session_id: The ID of the session to delete.
         :param agent_id: The ID of the agent to delete the session for.

@@ -538,17 +529,19 @@ class Agents(Protocol):
         self,
         agent_id: str,
     ) -> None:
-        """Delete an agent by its ID.
+        """Delete an agent by its ID and its associated sessions and turns.

         :param agent_id: The ID of the agent to delete.
         """
         ...

     @webmethod(route="/agents", method="GET")
-    async def list_agents(self) -> ListAgentsResponse:
+    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
         """List all agents.

-        :returns: A ListAgentsResponse.
+        :param start_index: The index to start the pagination from.
+        :param limit: The number of agents to return.
+        :returns: A PaginatedResponse.
         """
         ...

@@ -565,11 +558,15 @@ class Agents(Protocol):
     async def list_agent_sessions(
         self,
         agent_id: str,
-    ) -> ListAgentSessionsResponse:
+        start_index: int | None = None,
+        limit: int | None = None,
+    ) -> PaginatedResponse:
         """List all session(s) of a given agent.

         :param agent_id: The ID of the agent to list sessions for.
-        :returns: A ListAgentSessionsResponse.
+        :param start_index: The index to start the pagination from.
+        :param limit: The number of sessions to return.
+        :returns: A PaginatedResponse.
         """
         ...
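For reference, the dedicated `ListAgentsResponse`/`ListAgentSessionsResponse` wrappers removed above are replaced by the generic `PaginatedResponse` imported at the top of the hunk, so list results now carry plain dict records. A rough stand-in for that shape (only `data` is visible in this diff; `has_more` is an assumed convenience field, not confirmed here):

```
# Rough stand-in for the paginated shape returned by list_agents() and
# list_agent_sessions() after this change; "has_more" is an assumption.
from typing import Any

from pydantic import BaseModel


class PaginatedResponseSketch(BaseModel):
    data: list[dict[str, Any]]
    has_more: bool = False


page = PaginatedResponseSketch(data=[{"agent_id": "abc", "created_at": "2025-03-12T16:18:28Z"}])
print(page.data[0]["agent_id"])
```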
@@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):

     async def get_all(self) -> list[RoutableObjectWithProvider]:
         start_key, end_key = _get_registry_key_range()
-        values = await self.kvstore.range(start_key, end_key)
+        values = await self.kvstore.values_in_range(start_key, end_key)
         return _parse_registry_values(values)

     async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:

@@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
             return

         start_key, end_key = _get_registry_key_range()
-        values = await self.kvstore.range(start_key, end_key)
+        values = await self.kvstore.values_in_range(start_key, end_key)
         objects = _parse_registry_values(values)

         async with self._locked_cache() as cache:
@@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
         tool_groups_api: ToolGroups,
         vector_io_api: VectorIO,
         persistence_store: KVStore,
+        created_at: str,
     ):
         self.agent_id = agent_id
         self.agent_config = agent_config

@@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
         self.storage = AgentPersistence(agent_id, persistence_store)
         self.tool_runtime_api = tool_runtime_api
         self.tool_groups_api = tool_groups_api
+        self.created_at = created_at

         ShieldRunnerMixin.__init__(
             self,
@@ -4,10 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import json
 import logging
 import uuid
 from collections.abc import AsyncGenerator
+from datetime import datetime, timezone

 from llama_stack.apis.agents import (
     Agent,

@@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
     AgentTurnCreateRequest,
     AgentTurnResumeRequest,
     Document,
-    ListAgentSessionsResponse,
-    ListAgentsResponse,
     OpenAIResponseInputMessage,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     Session,
     Turn,
 )
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     Inference,
     ToolConfig,

@@ -38,14 +37,15 @@ from llama_stack.apis.inference import (
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
+from llama_stack.providers.utils.datasetio.pagination import paginate_records
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

 from .agent_instance import ChatAgent
 from .config import MetaReferenceAgentsImplConfig
 from .openai_responses import OpenAIResponsesImpl
+from .persistence import AgentInfo

 logger = logging.getLogger()
 logger.setLevel(logging.INFO)


 class MetaReferenceAgentsImpl(Agents):

@@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
         agent_config: AgentConfig,
     ) -> AgentCreateResponse:
         agent_id = str(uuid.uuid4())
+        created_at = datetime.now(timezone.utc)
+
+        agent_info = AgentInfo(
+            **agent_config.model_dump(),
+            created_at=created_at,
+        )
+
+        # Store the agent info
         await self.persistence_store.set(
             key=f"agent:{agent_id}",
-            value=agent_config.model_dump_json(),
+            value=agent_info.model_dump_json(),
         )
+
         return AgentCreateResponse(
             agent_id=agent_id,
         )

     async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
-        agent_config = await self.persistence_store.get(
+        agent_info_json = await self.persistence_store.get(
             key=f"agent:{agent_id}",
         )
-        if not agent_config:
-            raise ValueError(f"Could not find agent config for {agent_id}")
+        if not agent_info_json:
+            raise ValueError(f"Could not find agent info for {agent_id}")

         try:
-            agent_config = json.loads(agent_config)
-        except json.JSONDecodeError as e:
-            raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
-
-        try:
-            agent_config = AgentConfig(**agent_config)
+            agent_info = AgentInfo.model_validate_json(agent_info_json)
         except Exception as e:
-            raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e
+            raise ValueError(f"Could not validate agent info for {agent_id}") from e

         return ChatAgent(
             agent_id=agent_id,
-            agent_config=agent_config,
+            agent_config=agent_info,
             inference_api=self.inference_api,
             safety_api=self.safety_api,
             vector_io_api=self.vector_io_api,
             tool_runtime_api=self.tool_runtime_api,
             tool_groups_api=self.tool_groups_api,
             persistence_store=(
-                self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store
+                self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
             ),
+            created_at=agent_info.created_at,
         )

     async def create_agent_session(

@@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
         turn_ids: list[str] | None = None,
     ) -> Session:
         agent = await self._get_agent_impl(agent_id)
+
         session_info = await agent.storage.get_session_info(session_id)
         if session_info is None:
             raise ValueError(f"Session {session_id} not found")

@@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
         )

     async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
-        await self.persistence_store.delete(f"session:{agent_id}:{session_id}")
+        agent = await self._get_agent_impl(agent_id)
+        session_info = await agent.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        # Delete turns first, then the session
+        await agent.storage.delete_session_turns(session_id)
+        await agent.storage.delete_session(session_id)

     async def delete_agent(self, agent_id: str) -> None:
+        # First get all sessions for this agent
+        agent = await self._get_agent_impl(agent_id)
+        sessions = await agent.storage.list_sessions()
+
+        # Delete all sessions
+        for session in sessions:
+            await self.delete_agents_session(agent_id, session.session_id)
+
+        # Finally delete the agent itself
         await self.persistence_store.delete(f"agent:{agent_id}")

-    async def shutdown(self) -> None:
-        pass
-
-    async def list_agents(self) -> ListAgentsResponse:
-        pass
+    async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
+        agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
+        agent_list: list[Agent] = []
+        for agent_key in agent_keys:
+            agent_id = agent_key.split(":")[1]
+
+            # Get the agent info using the key
+            agent_info_json = await self.persistence_store.get(agent_key)
+            if not agent_info_json:
+                logger.error(f"Could not find agent info for key {agent_key}")
+                continue
+
+            try:
+                agent_info = AgentInfo.model_validate_json(agent_info_json)
+                agent_list.append(
+                    Agent(
+                        agent_id=agent_id,
+                        agent_config=agent_info,
+                        created_at=agent_info.created_at,
+                    )
+                )
+            except Exception as e:
+                logger.error(f"Error parsing agent info for {agent_id}: {e}")
+                continue
+
+        # Convert Agent objects to dictionaries
+        agent_dicts = [agent.model_dump() for agent in agent_list]
+        return paginate_records(agent_dicts, start_index, limit)

     async def get_agent(self, agent_id: str) -> Agent:
-        pass
+        chat_agent = await self._get_agent_impl(agent_id)
+        agent = Agent(
+            agent_id=agent_id,
+            agent_config=chat_agent.agent_config,
+            created_at=chat_agent.created_at,
+        )
+        return agent

     async def list_agent_sessions(
-        self,
-        agent_id: str,
-    ) -> ListAgentSessionsResponse:
+        self, agent_id: str, start_index: int | None = None, limit: int | None = None
+    ) -> PaginatedResponse:
+        agent = await self._get_agent_impl(agent_id)
+        sessions = await agent.storage.list_sessions()
+        # Convert Session objects to dictionaries
+        session_dicts = [session.model_dump() for session in sessions]
+        return paginate_records(session_dicts, start_index, limit)
+
+    async def shutdown(self) -> None:
         pass

     # OpenAI responses
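Both new list methods above funnel their dicts through `paginate_records`, whose implementation is not part of this diff. A minimal, illustrative stand-in with the behavior the call sites appear to rely on (slice a list of dict records by `start_index`/`limit`); the real helper in `llama_stack.providers.utils` and its exact return model may differ:

```
# Illustrative stand-in for paginate_records as used above: slice a list of
# dicts by start_index/limit and report whether more records remain.
from typing import Any


def paginate_dicts(records: list[dict[str, Any]], start_index: int | None, limit: int | None) -> dict[str, Any]:
    start = start_index or 0
    end = len(records) if limit is None else start + limit
    page = records[start:end]
    return {"data": page, "has_more": end < len(records)}


# Example: five agent dicts, second page of size two.
print(paginate_dicts([{"agent_id": str(i)} for i in range(5)], start_index=2, limit=2))
```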
@@ -9,9 +9,7 @@ import logging
 import uuid
 from datetime import datetime, timezone

-from pydantic import BaseModel
-
-from llama_stack.apis.agents import ToolExecutionStep, Turn
+from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
 from llama_stack.distribution.access_control import check_access
 from llama_stack.distribution.datatypes import AccessAttributes
 from llama_stack.distribution.request_headers import get_auth_attributes

@@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
 log = logging.getLogger(__name__)


-class AgentSessionInfo(BaseModel):
-    session_id: str
-    session_name: str
+class AgentSessionInfo(Session):
     # TODO: is this used anywhere?
     vector_db_id: str | None = None
     started_at: datetime
     access_attributes: AccessAttributes | None = None


+class AgentInfo(AgentConfig):
+    created_at: datetime
+
+
 class AgentPersistence:
     def __init__(self, agent_id: str, kvstore: KVStore):
         self.agent_id = agent_id

@@ -46,6 +46,7 @@ class AgentPersistence:
             session_name=name,
             started_at=datetime.now(timezone.utc),
             access_attributes=access_attributes,
+            turns=[],
         )

         await self.kvstore.set(

@@ -109,7 +110,7 @@ class AgentPersistence:
         if not await self.get_session_if_accessible(session_id):
             raise ValueError(f"Session {session_id} not found or access denied")

-        values = await self.kvstore.range(
+        values = await self.kvstore.values_in_range(
            start_key=f"session:{self.agent_id}:{session_id}:",
            end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
        )

@@ -121,7 +122,6 @@ class AgentPersistence:
             except Exception as e:
                 log.error(f"Error parsing turn: {e}")
                 continue
-        turns.sort(key=lambda x: (x.completed_at or datetime.min))
         return turns

     async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:

@@ -170,3 +170,43 @@ class AgentPersistence:
             key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
         )
         return int(value) if value else None
+
+    async def list_sessions(self) -> list[Session]:
+        values = await self.kvstore.values_in_range(
+            start_key=f"session:{self.agent_id}:",
+            end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
+        )
+        sessions = []
+        for value in values:
+            try:
+                session_info = Session(**json.loads(value))
+                sessions.append(session_info)
+            except Exception as e:
+                log.error(f"Error parsing session info: {e}")
+                continue
+        return sessions
+
+    async def delete_session_turns(self, session_id: str) -> None:
+        """Delete all turns and their associated data for a session.
+
+        Args:
+            session_id: The ID of the session whose turns should be deleted.
+        """
+        turns = await self.get_session_turns(session_id)
+        for turn in turns:
+            await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
+
+    async def delete_session(self, session_id: str) -> None:
+        """Delete a session and all its associated turns.
+
+        Args:
+            session_id: The ID of the session to delete.
+
+        Raises:
+            ValueError: If the session does not exist.
+        """
+        session_info = await self.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
+
+        await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
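The persistence changes above all hinge on the key scheme embedded in the f-strings; laid out explicitly as a sketch (agent and session IDs reused from the commit-message examples, the turn ID is purely illustrative):

```
# Key layout used by AgentPersistence, reconstructed from the f-strings above.
agent_id = "9abad4ab-2c77-45f9-9d16-46b79d2bea1f"
session_id = "2b15c4fc-e348-46c1-ae32-f6d424441ac1"
turn_id = "example-turn-id"  # illustrative value

agent_key = f"agent:{agent_id}"                          # AgentInfo JSON
session_key = f"session:{agent_id}:{session_id}"         # Session JSON
turn_key = f"session:{agent_id}:{session_id}:{turn_id}"  # Turn JSON

# list_sessions() scans values_in_range over the f"session:{agent_id}:" prefix,
# so turn keys fall inside the same range; values that do not validate as a
# Session are skipped by the except branch. delete_session_turns() runs before
# delete_session() so no turn keys are left behind under the session prefix.
print(agent_key, session_key, turn_key, sep="\n")
```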
@@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         # Load existing datasets from kvstore
         start_key = DATASETS_PREFIX
         end_key = f"{DATASETS_PREFIX}\xff"
-        stored_datasets = await self.kvstore.range(start_key, end_key)
+        stored_datasets = await self.kvstore.values_in_range(start_key, end_key)

         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
@@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
         # Load existing benchmarks from kvstore
         start_key = EVAL_TASKS_PREFIX
         end_key = f"{EVAL_TASKS_PREFIX}\xff"
-        stored_benchmarks = await self.kvstore.range(start_key, end_key)
+        stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)

         for benchmark in stored_benchmarks:
             benchmark = Benchmark.model_validate_json(benchmark)
@@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         # Load existing banks from kvstore
         start_key = VECTOR_DBS_PREFIX
         end_key = f"{VECTOR_DBS_PREFIX}\xff"
-        stored_vector_dbs = await self.kvstore.range(start_key, end_key)
+        stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)

         for vector_db_data in stored_vector_dbs:
             vector_db = VectorDB.model_validate_json(vector_db_data)
@@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         # Load existing datasets from kvstore
         start_key = DATASETS_PREFIX
         end_key = f"{DATASETS_PREFIX}\xff"
-        stored_datasets = await self.kvstore.range(start_key, end_key)
+        stored_datasets = await self.kvstore.values_in_range(start_key, end_key)

         for dataset in stored_datasets:
             dataset = Dataset.model_validate_json(dataset)
@@ -16,4 +16,6 @@ class KVStore(Protocol):

     async def delete(self, key: str) -> None: ...

-    async def range(self, start_key: str, end_key: str) -> list[str]: ...
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ...
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ...
@@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore):
     async def set(self, key: str, value: str) -> None:
         self._store[key] = value

-    async def range(self, start_key: str, end_key: str) -> list[str]:
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
         return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key]

+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Get all keys in the given range."""
+        return [key for key in self._store.keys() if key >= start_key and key < end_key]
+
+    async def delete(self, key: str) -> None:
+        del self._store[key]
+

 async def kvstore_impl(config: KVStoreConfig) -> KVStore:
     if config.type == KVStoreType.redis.value:
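The `"\xff"` sentinels used by the call sites (for example `keys_in_range("agent:", "agent:\xff")` in `list_agents` above) work because keys are compared lexicographically over the range, as the in-memory implementation makes explicit. A small self-contained check of that idea, using illustrative keys:

```
# Why "agent:" .. "agent:\xff" captures exactly the agent keys under
# lexicographic comparison (mirrors the in-memory keys_in_range logic above).
store = {
    "agent:111": "{}",
    "agent:222": "{}",
    "benchmark:xyz": "{}",
    "session:111:aaa": "{}",
}

start_key, end_key = "agent:", "agent:\xff"
keys = [key for key in store if start_key <= key < end_key]
assert keys == ["agent:111", "agent:222"]  # session/benchmark keys fall outside the range
print(keys)
```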
@@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore):
         key = self._namespaced_key(key)
         await self.collection.delete_one({"key": key})

-    async def range(self, start_key: str, end_key: str) -> list[str]:
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)
         query = {

@@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore):
         async for doc in cursor:
             result.append(doc["value"])
         return result
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+        query = {"key": {"$gte": start_key, "$lt": end_key}}
+        cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1)
+        return [doc["key"] for doc in cursor]
@@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore):
             (key,),
         )

-    async def range(self, start_key: str, end_key: str) -> list[str]:
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)

@@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore):
             (start_key, end_key),
         )
         return [row[0] for row in self.cursor.fetchall()]
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        start_key = self._namespaced_key(start_key)
+        end_key = self._namespaced_key(end_key)
+
+        self.cursor.execute(
+            f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s",
+            (start_key, end_key),
+        )
+        return [row[0] for row in self.cursor.fetchall()]
@@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore):
         key = self._namespaced_key(key)
         await self.redis.delete(key)

-    async def range(self, start_key: str, end_key: str) -> list[str]:
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
         start_key = self._namespaced_key(start_key)
         end_key = self._namespaced_key(end_key)
         cursor = 0

@@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore):
         ]

         return []
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Get all keys in the given range."""
+        matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}")
+        if not matching_keys:
+            return []
+        return [k.decode("utf-8") for k in matching_keys]
@@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore):
             await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
             await db.commit()

-    async def range(self, start_key: str, end_key: str) -> list[str]:
+    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
         async with aiosqlite.connect(self.db_path) as db:
             async with db.execute(
                 f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",

@@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore):
                     _, value, _ = row
                     result.append(value)
                 return result
+
+    async def keys_in_range(self, start_key: str, end_key: str) -> list[str]:
+        """Get all keys in the given range."""
+        async with aiosqlite.connect(self.db_path) as db:
+            cursor = await db.execute(
+                f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?",
+                (start_key, end_key),
+            )
+            rows = await cursor.fetchall()
+            return [row[0] for row in rows]