From 59f1fe5af80bd4a13995eb45f156e19a606bc57d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sat, 21 Sep 2024 22:43:03 -0700 Subject: [PATCH] Agent persistence works --- llama_stack/apis/agents/client.py | 1 + llama_stack/distribution/datatypes.py | 14 ++- llama_stack/distribution/utils/config_dirs.py | 2 + .../adapters/memory/chroma/chroma.py | 3 - .../adapters/memory/pgvector/pgvector.py | 1 - .../impls/meta_reference/agents/__init__.py | 8 +- .../meta_reference/agents/agent_instance.py | 89 ++++++++++++------- .../impls/meta_reference/agents/agents.py | 13 ++- .../impls/meta_reference/agents/config.py | 2 + .../agents/tests/test_chat_agent.py | 1 + .../impls/meta_reference/memory/faiss.py | 1 - llama_stack/providers/registry/agents.py | 8 +- llama_stack/providers/utils/kvstore/api.py | 13 +-- llama_stack/providers/utils/kvstore/config.py | 13 +-- .../providers/utils/kvstore/kvstore.py | 27 +++++- .../providers/utils/kvstore/redis/redis.py | 13 ++- .../providers/utils/kvstore/sqlite/sqlite.py | 17 ++-- 17 files changed, 136 insertions(+), 90 deletions(-) diff --git a/llama_stack/apis/agents/client.py b/llama_stack/apis/agents/client.py index c5cba3541..8f6d61228 100644 --- a/llama_stack/apis/agents/client.py +++ b/llama_stack/apis/agents/client.py @@ -102,6 +102,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None): tools=tool_definitions, tool_choice=ToolChoice.auto, tool_prompt_format=ToolPromptFormat.function_tag, + enable_session_persistence=False, ) create_response = await api.create_agent(agent_config) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 52522886e..a3ff86cdf 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field @json_schema_type @@ -168,14 +168,12 @@ Fully-qualified name of the module to import. The module is expected to have: class RemoteProviderConfig(BaseModel): - url: str = Field(..., description="The URL for the provider") + host: str = "localhost" + port: int - @validator("url") - @classmethod - def validate_url(cls, url: str) -> str: - if not url.startswith("http"): - raise ValueError(f"URL must start with http: {url}") - return url.rstrip("/") + @property + def url(self) -> str: + return f"http://{self.host}:{self.port}" def remote_provider_id(adapter_id: str) -> str: diff --git a/llama_stack/distribution/utils/config_dirs.py b/llama_stack/distribution/utils/config_dirs.py index adf3876a3..3785f4507 100644 --- a/llama_stack/distribution/utils/config_dirs.py +++ b/llama_stack/distribution/utils/config_dirs.py @@ -15,3 +15,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions" DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds" + +RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime" diff --git a/llama_stack/providers/adapters/memory/chroma/chroma.py b/llama_stack/providers/adapters/memory/chroma/chroma.py index 15f5810a9..0a5f5bcd6 100644 --- a/llama_stack/providers/adapters/memory/chroma/chroma.py +++ b/llama_stack/providers/adapters/memory/chroma/chroma.py @@ -31,9 +31,6 @@ class ChromaIndex(EmbeddingIndex): embeddings ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" - for i, chunk in enumerate(chunks): - print(f"Adding chunk #{i} tokens={chunk.token_count}") - await self.collection.add( documents=[chunk.json() for chunk in chunks], embeddings=embeddings, diff --git a/llama_stack/providers/adapters/memory/pgvector/pgvector.py b/llama_stack/providers/adapters/memory/pgvector/pgvector.py index a5c84a1b2..9cf0771ab 100644 --- a/llama_stack/providers/adapters/memory/pgvector/pgvector.py +++ b/llama_stack/providers/adapters/memory/pgvector/pgvector.py @@ -80,7 +80,6 @@ class PGVectorIndex(EmbeddingIndex): values = [] for i, chunk in enumerate(chunks): - print(f"Adding chunk #{i} tokens={chunk.token_count}") values.append( ( f"{chunk.document_id}:chunk-{i}", diff --git a/llama_stack/providers/impls/meta_reference/agents/__init__.py b/llama_stack/providers/impls/meta_reference/agents/__init__.py index b6f3e6456..c0844be3b 100644 --- a/llama_stack/providers/impls/meta_reference/agents/__init__.py +++ b/llama_stack/providers/impls/meta_reference/agents/__init__.py @@ -8,18 +8,14 @@ from typing import Dict from llama_stack.distribution.datatypes import Api, ProviderSpec -from .config import MetaReferenceImplConfig +from .config import MetaReferenceAgentsImplConfig async def get_provider_impl( - config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec] + config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec] ): from .agents import MetaReferenceAgentsImpl - assert isinstance( - config, MetaReferenceImplConfig - ), f"Unexpected config type: {type(config)}" - impl = MetaReferenceAgentsImpl( config, deps[Api.inference], diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index b7f56917e..4199b2e96 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -6,6 +6,7 @@ import asyncio import copy +import json import os import secrets import shutil @@ -25,7 +26,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.providers.utils.api import KVStore +from llama_stack.providers.utils.kvstore import KVStore from .rag.context_retriever import generate_rag_query from .safety import SafetyException, ShieldRunnerMixin @@ -46,6 +47,13 @@ def make_random_string(length: int = 8): ) +class AgentSessionInfo(BaseModel): + session_id: str + session_name: str + memory_bank_id: Optional[str] = None + started_at: datetime + + class ChatAgent(ShieldRunnerMixin): def __init__( self, @@ -136,21 +144,41 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: session_id = str(uuid.uuid4()) - session = dict( + session_info = AgentSessionInfo( session_id=session_id, session_name=name, - started_at=str(datetime.now()), + started_at=datetime.now(), ) await self.persistence_store.set( key=f"session:{self.agent_id}:{session_id}", - value=json.dumps(session), + value=session_info.json(), ) return session_id + async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]: + value = await self.persistence_store.get( + key=f"session:{self.agent_id}:{session_id}", + ) + if not value: + return None + + return AgentSessionInfo(**json.loads(value)) + + async def _add_memory_bank_to_session(self, session_id: str, bank_id: str): + session_info = await self._get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + session_info.memory_bank_id = bank_id + await self.persistence_store.set( + key=f"session:{self.agent_id}:{session_id}", + value=session_info.json(), + ) + async def _add_turn_to_session(self, session_id: str, turn: Turn): await self.persistence_store.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", - value=json.dumps(turn.dict()), + value=turn.json(), ) async def _get_session_turns(self, session_id: str) -> List[Turn]: @@ -169,15 +197,6 @@ class ChatAgent(ShieldRunnerMixin): return turns - async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]: - value = await self.persistence_store.get( - key=f"session:{self.agent_id}:{session_id}", - ) - if not value: - return None - - return json.loads(value) - async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: @@ -209,7 +228,7 @@ class ChatAgent(ShieldRunnerMixin): steps = [] output_message = None async for chunk in self.run( - session=session, + session_id=request.session_id, turn_id=turn_id, input_messages=messages, attachments=request.attachments or [], @@ -261,7 +280,7 @@ class ChatAgent(ShieldRunnerMixin): async def run( self, - session: Session, + session_id: str, turn_id: str, input_messages: List[Message], attachments: List[Attachment], @@ -282,7 +301,7 @@ class ChatAgent(ShieldRunnerMixin): yield res async for res in self._run( - session, turn_id, input_messages, attachments, sampling_params, stream + session_id, turn_id, input_messages, attachments, sampling_params, stream ): if isinstance(res, bool): return @@ -364,7 +383,7 @@ class ChatAgent(ShieldRunnerMixin): async def _run( self, - session: Session, + session_id: str, turn_id: str, input_messages: List[Message], attachments: List[Attachment], @@ -389,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin): # TODO: find older context from the session and either replace it # or append with a sliding window. this is really a very simplistic implementation rag_context, bank_ids = await self._retrieve_context( - session, input_messages, attachments + session_id, input_messages, attachments ) step_id = str(uuid.uuid4()) @@ -645,17 +664,25 @@ class ChatAgent(ShieldRunnerMixin): n_iter += 1 - async def _ensure_memory_bank(self, session: Session) -> MemoryBank: - if session.memory_bank is None: - session.memory_bank = await self.memory_api.create_memory_bank( - name=f"memory_bank_{session.session_id}", + async def _ensure_memory_bank(self, session_id: str) -> str: + session_info = await self._get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + if session_info.memory_bank_id is None: + memory_bank = await self.memory_api.create_memory_bank( + name=f"memory_bank_{session_id}", config=VectorMemoryBankConfig( embedding_model="sentence-transformer/all-MiniLM-L6-v2", chunk_size_in_tokens=512, ), ) + bank_id = memory_bank.bank_id + await self._add_memory_bank_to_session(session_id, bank_id) + else: + bank_id = session_info.memory_bank_id - return session.memory_bank + return bank_id async def _should_retrieve_context( self, messages: List[Message], attachments: List[Attachment] @@ -680,7 +707,7 @@ class ChatAgent(ShieldRunnerMixin): return None async def _retrieve_context( - self, session: Session, messages: List[Message], attachments: List[Attachment] + self, session_id: str, messages: List[Message], attachments: List[Attachment] ) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids) bank_ids = [] @@ -689,8 +716,8 @@ class ChatAgent(ShieldRunnerMixin): bank_ids.extend(c.bank_id for c in memory.memory_bank_configs) if attachments: - bank = await self._ensure_memory_bank(session) - bank_ids.append(bank.bank_id) + bank_id = await self._ensure_memory_bank(session_id) + bank_ids.append(bank_id) documents = [ MemoryBankDocument( @@ -701,9 +728,11 @@ class ChatAgent(ShieldRunnerMixin): ) for a in attachments ] - await self.memory_api.insert_documents(bank.bank_id, documents) - elif session.memory_bank: - bank_ids.append(session.memory_bank.bank_id) + await self.memory_api.insert_documents(bank_id, documents) + else: + session_info = await self._get_session_info(session_id) + if session_info.memory_bank_id: + bank_ids.append(session_info.memory_bank_id) if not bank_ids: # this can happen if the per-session memory bank is not yet populated diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index a7ab548d2..0673cd16f 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - +import json import logging import uuid from typing import AsyncGenerator @@ -14,7 +14,7 @@ from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.utils.kvstore import kvstore_impl +from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig @@ -35,6 +35,7 @@ class MetaReferenceAgentsImpl(Agents): self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api + self.in_memory_store = InmemoryKVStoreImpl() async def initialize(self) -> None: self.persistence_store = await kvstore_impl(self.config.persistence_store) @@ -47,7 +48,7 @@ class MetaReferenceAgentsImpl(Agents): await self.persistence_store.set( key=f"agent:{agent_id}", - value=json.dumps(agent_config.dict()), + value=agent_config.json(), ) return AgentCreateResponse( agent_id=agent_id, @@ -80,7 +81,11 @@ class MetaReferenceAgentsImpl(Agents): inference_api=self.inference_api, safety_api=self.safety_api, memory_api=self.memory_api, - persistence_store=self.persistence_store, + persistence_store=( + self.persistence_store + if agent_config.enable_session_persistence + else self.in_memory_store + ), ) async def create_agent_session( diff --git a/llama_stack/providers/impls/meta_reference/agents/config.py b/llama_stack/providers/impls/meta_reference/agents/config.py index d4f15968e..0146cb436 100644 --- a/llama_stack/providers/impls/meta_reference/agents/config.py +++ b/llama_stack/providers/impls/meta_reference/agents/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from pydantic import BaseModel + from llama_stack.providers.utils.kvstore import KVStoreConfig diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py index 6e5505b6e..9d941edc9 100644 --- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py +++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py @@ -184,6 +184,7 @@ async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api): # ), ], tool_choice=ToolChoice.auto, + enable_session_persistence=False, input_shields=[], output_shields=[], ) diff --git a/llama_stack/providers/impls/meta_reference/memory/faiss.py b/llama_stack/providers/impls/meta_reference/memory/faiss.py index ee716430e..30b7245e6 100644 --- a/llama_stack/providers/impls/meta_reference/memory/faiss.py +++ b/llama_stack/providers/impls/meta_reference/memory/faiss.py @@ -42,7 +42,6 @@ class FaissIndex(EmbeddingIndex): indexlen = len(self.id_by_index) for i, chunk in enumerate(chunks): self.chunk_by_index[indexlen + i] = chunk - logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}") self.id_by_index[indexlen + i] = chunk.document_id self.index.add(np.array(embeddings).astype(np.float32)) diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 3195c92da..5d4684d82 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -7,6 +7,7 @@ from typing import List from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.providers.utils.kvstore import kvstore_dependencies def available_providers() -> List[ProviderSpec]: @@ -19,11 +20,10 @@ def available_providers() -> List[ProviderSpec]: "pillow", "pandas", "scikit-learn", - "torch", - "transformers", - ], + ] + + kvstore_dependencies(), module="llama_stack.providers.impls.meta_reference.agents", - config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig", + config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceAgentsImplConfig", api_dependencies=[ Api.inference, Api.safety, diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py index 5ddc8db5b..ba5b206c0 100644 --- a/llama_stack/providers/utils/kvstore/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -7,22 +7,15 @@ from datetime import datetime from typing import List, Optional, Protocol -from pydantic import BaseModel - - -class KVStoreValue(BaseModel): - key: str - value: str - expiration: Optional[datetime] = None - class KVStore(Protocol): + # TODO: make the value type bytes instead of str async def set( self, key: str, value: str, expiration: Optional[datetime] = None ) -> None: ... - async def get(self, key: str) -> Optional[KVStoreValue]: ... + async def get(self, key: str) -> Optional[str]: ... async def delete(self, key: str) -> None: ... - async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: ... + async def range(self, start_key: str, end_key: str) -> List[str]: ... diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 0a99c78cd..5893e4c4a 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -7,9 +7,11 @@ from enum import Enum from typing import Literal, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Annotated +from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR + class KVStoreType(Enum): redis = "redis" @@ -24,20 +26,21 @@ class CommonConfig(BaseModel): ) -class RedisKVStoreImplConfig(CommonConfig): +class RedisKVStoreConfig(CommonConfig): type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value host: str = "localhost" port: int = 6379 -class SqliteKVStoreImplConfig(CommonConfig): +class SqliteKVStoreConfig(CommonConfig): type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value db_path: str = Field( + default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(), description="File path for the sqlite database", ) -class PostgresKVStoreImplConfig(CommonConfig): +class PostgresKVStoreConfig(CommonConfig): type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value host: str = "localhost" port: int = 5432 @@ -47,6 +50,6 @@ class PostgresKVStoreImplConfig(CommonConfig): KVStoreConfig = Annotated[ - Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig], + Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig], Field(discriminator="type", default=KVStoreType.sqlite.value), ] diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index a7fa0af75..a3cabc206 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -12,16 +12,37 @@ def kvstore_dependencies(): return ["aiosqlite", "psycopg2-binary", "redis"] +class InmemoryKVStoreImpl(KVStore): + def __init__(self): + self._store = {} + + async def initialize(self) -> None: + pass + + async def get(self, key: str) -> Optional[str]: + return self._store.get(key) + + async def set(self, key: str, value: str) -> None: + self._store[key] = value + + async def range(self, start_key: str, end_key: str) -> List[str]: + return [ + self._store[key] + for key in self._store.keys() + if key >= start_key and key < end_key + ] + + async def kvstore_impl(config: KVStoreConfig) -> KVStore: - if config.type == KVStoreType.redis: + if config.type == KVStoreType.redis.value: from .redis import RedisKVStoreImpl impl = RedisKVStoreImpl(config) - elif config.type == KVStoreType.sqlite: + elif config.type == KVStoreType.sqlite.value: from .sqlite import SqliteKVStoreImpl impl = SqliteKVStoreImpl(config) - elif config.type == KVStoreType.pgvector: + elif config.type == KVStoreType.postgres.value: raise NotImplementedError() else: raise ValueError(f"Unknown kvstore type {config.type}") diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py index 977da06d3..fb264b15c 100644 --- a/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -4,17 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from datetime import datetime, timedelta +from datetime import datetime from typing import List, Optional from redis.asyncio import Redis from ..api import * # noqa: F403 -from ..config import RedisKVStoreImplConfig +from ..config import RedisKVStoreConfig class RedisKVStoreImpl(KVStore): - def __init__(self, config: RedisKVStoreImplConfig): + def __init__(self, config: RedisKVStoreConfig): self.config = config async def initialize(self) -> None: @@ -33,20 +33,19 @@ class RedisKVStoreImpl(KVStore): if expiration: await self.redis.expireat(key, expiration) - async def get(self, key: str) -> Optional[KVStoreValue]: + async def get(self, key: str) -> Optional[str]: key = self._namespaced_key(key) value = await self.redis.get(key) if value is None: return None ttl = await self.redis.ttl(key) - expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None - return KVStoreValue(key=key, value=value, expiration=expiration) + return value async def delete(self, key: str) -> None: key = self._namespaced_key(key) await self.redis.delete(key) - async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: + async def range(self, start_key: str, end_key: str) -> List[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index 9d7fd4d86..1c5311d10 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os + from datetime import datetime from typing import List, Optional @@ -16,9 +18,10 @@ from ..config import SqliteKVStoreConfig class SqliteKVStoreImpl(KVStore): def __init__(self, config: SqliteKVStoreConfig): self.db_path = config.db_path - self.table_name = config.table_name + self.table_name = "kvstore" async def initialize(self): + os.makedirs(os.path.dirname(self.db_path), exist_ok=True) async with aiosqlite.connect(self.db_path) as db: await db.execute( f""" @@ -41,7 +44,7 @@ class SqliteKVStoreImpl(KVStore): ) await db.commit() - async def get(self, key: str) -> Optional[KVStoreValue]: + async def get(self, key: str) -> Optional[str]: async with aiosqlite.connect(self.db_path) as db: async with db.execute( f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) @@ -50,14 +53,14 @@ class SqliteKVStoreImpl(KVStore): if row is None: return None value, expiration = row - return KVStoreValue(key=key, value=value, expiration=expiration) + return value async def delete(self, key: str) -> None: async with aiosqlite.connect(self.db_path) as db: await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.commit() - async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: + async def range(self, start_key: str, end_key: str) -> List[str]: async with aiosqlite.connect(self.db_path) as db: async with db.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", @@ -65,8 +68,6 @@ class SqliteKVStoreImpl(KVStore): ) as cursor: result = [] async for row in cursor: - key, value, expiration = row - result.append( - KVStoreValue(key=key, value=value, expiration=expiration) - ) + _, value, _ = row + result.append(value) return result