Agent persistence works

This commit is contained in:
Ashwin Bharambe 2024-09-21 22:43:03 -07:00
parent 4eb0f30891
commit 59f1fe5af8
17 changed files with 136 additions and 90 deletions

View file

@ -102,6 +102,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
tools=tool_definitions, tools=tool_definitions,
tool_choice=ToolChoice.auto, tool_choice=ToolChoice.auto,
tool_prompt_format=ToolPromptFormat.function_tag, tool_prompt_format=ToolPromptFormat.function_tag,
enable_session_persistence=False,
) )
create_response = await api.create_agent(agent_config) create_response = await api.create_agent(agent_config)

View file

@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field
@json_schema_type @json_schema_type
@ -168,14 +168,12 @@ Fully-qualified name of the module to import. The module is expected to have:
class RemoteProviderConfig(BaseModel): class RemoteProviderConfig(BaseModel):
url: str = Field(..., description="The URL for the provider") host: str = "localhost"
port: int
@validator("url") @property
@classmethod def url(self) -> str:
def validate_url(cls, url: str) -> str: return f"http://{self.host}:{self.port}"
if not url.startswith("http"):
raise ValueError(f"URL must start with http: {url}")
return url.rstrip("/")
def remote_provider_id(adapter_id: str) -> str: def remote_provider_id(adapter_id: str) -> str:

View file

@ -15,3 +15,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds" BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"

View file

@ -31,9 +31,6 @@ class ChromaIndex(EmbeddingIndex):
embeddings embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
await self.collection.add( await self.collection.add(
documents=[chunk.json() for chunk in chunks], documents=[chunk.json() for chunk in chunks],
embeddings=embeddings, embeddings=embeddings,

View file

@ -80,7 +80,6 @@ class PGVectorIndex(EmbeddingIndex):
values = [] values = []
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
print(f"Adding chunk #{i} tokens={chunk.token_count}")
values.append( values.append(
( (
f"{chunk.document_id}:chunk-{i}", f"{chunk.document_id}:chunk-{i}",

View file

@ -8,18 +8,14 @@ from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceImplConfig from .config import MetaReferenceAgentsImplConfig
async def get_provider_impl( async def get_provider_impl(
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec] config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
): ):
from .agents import MetaReferenceAgentsImpl from .agents import MetaReferenceAgentsImpl
assert isinstance(
config, MetaReferenceImplConfig
), f"Unexpected config type: {type(config)}"
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(
config, config,
deps[Api.inference], deps[Api.inference],

View file

@ -6,6 +6,7 @@
import asyncio import asyncio
import copy import copy
import json
import os import os
import secrets import secrets
import shutil import shutil
@ -25,7 +26,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.api import KVStore from llama_stack.providers.utils.kvstore import KVStore
from .rag.context_retriever import generate_rag_query from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin from .safety import SafetyException, ShieldRunnerMixin
@ -46,6 +47,13 @@ def make_random_string(length: int = 8):
) )
class AgentSessionInfo(BaseModel):
session_id: str
session_name: str
memory_bank_id: Optional[str] = None
started_at: datetime
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):
def __init__( def __init__(
self, self,
@ -136,21 +144,41 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
session = dict( session_info = AgentSessionInfo(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
started_at=str(datetime.now()), started_at=datetime.now(),
) )
await self.persistence_store.set( await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}", key=f"session:{self.agent_id}:{session_id}",
value=json.dumps(session), value=session_info.json(),
) )
return session_id return session_id
async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
value = await self.persistence_store.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return AgentSessionInfo(**json.loads(value))
async def _add_memory_bank_to_session(self, session_id: str, bank_id: str):
session_info = await self._get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
session_info.memory_bank_id = bank_id
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.json(),
)
async def _add_turn_to_session(self, session_id: str, turn: Turn): async def _add_turn_to_session(self, session_id: str, turn: Turn):
await self.persistence_store.set( await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=json.dumps(turn.dict()), value=turn.json(),
) )
async def _get_session_turns(self, session_id: str) -> List[Turn]: async def _get_session_turns(self, session_id: str) -> List[Turn]:
@ -169,15 +197,6 @@ class ChatAgent(ShieldRunnerMixin):
return turns return turns
async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]:
value = await self.persistence_store.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return json.loads(value)
async def create_and_execute_turn( async def create_and_execute_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
@ -209,7 +228,7 @@ class ChatAgent(ShieldRunnerMixin):
steps = [] steps = []
output_message = None output_message = None
async for chunk in self.run( async for chunk in self.run(
session=session, session_id=request.session_id,
turn_id=turn_id, turn_id=turn_id,
input_messages=messages, input_messages=messages,
attachments=request.attachments or [], attachments=request.attachments or [],
@ -261,7 +280,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run( async def run(
self, self,
session: Session, session_id: str,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: List[Message],
attachments: List[Attachment], attachments: List[Attachment],
@ -282,7 +301,7 @@ class ChatAgent(ShieldRunnerMixin):
yield res yield res
async for res in self._run( async for res in self._run(
session, turn_id, input_messages, attachments, sampling_params, stream session_id, turn_id, input_messages, attachments, sampling_params, stream
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -364,7 +383,7 @@ class ChatAgent(ShieldRunnerMixin):
async def _run( async def _run(
self, self,
session: Session, session_id: str,
turn_id: str, turn_id: str,
input_messages: List[Message], input_messages: List[Message],
attachments: List[Attachment], attachments: List[Attachment],
@ -389,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it # TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation # or append with a sliding window. this is really a very simplistic implementation
rag_context, bank_ids = await self._retrieve_context( rag_context, bank_ids = await self._retrieve_context(
session, input_messages, attachments session_id, input_messages, attachments
) )
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
@ -645,17 +664,25 @@ class ChatAgent(ShieldRunnerMixin):
n_iter += 1 n_iter += 1
async def _ensure_memory_bank(self, session: Session) -> MemoryBank: async def _ensure_memory_bank(self, session_id: str) -> str:
if session.memory_bank is None: session_info = await self._get_session_info(session_id)
session.memory_bank = await self.memory_api.create_memory_bank( if session_info is None:
name=f"memory_bank_{session.session_id}", raise ValueError(f"Session {session_id} not found")
if session_info.memory_bank_id is None:
memory_bank = await self.memory_api.create_memory_bank(
name=f"memory_bank_{session_id}",
config=VectorMemoryBankConfig( config=VectorMemoryBankConfig(
embedding_model="sentence-transformer/all-MiniLM-L6-v2", embedding_model="sentence-transformer/all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
), ),
) )
bank_id = memory_bank.bank_id
await self._add_memory_bank_to_session(session_id, bank_id)
else:
bank_id = session_info.memory_bank_id
return session.memory_bank return bank_id
async def _should_retrieve_context( async def _should_retrieve_context(
self, messages: List[Message], attachments: List[Attachment] self, messages: List[Message], attachments: List[Attachment]
@ -680,7 +707,7 @@ class ChatAgent(ShieldRunnerMixin):
return None return None
async def _retrieve_context( async def _retrieve_context(
self, session: Session, messages: List[Message], attachments: List[Attachment] self, session_id: str, messages: List[Message], attachments: List[Attachment]
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids) ) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
bank_ids = [] bank_ids = []
@ -689,8 +716,8 @@ class ChatAgent(ShieldRunnerMixin):
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs) bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
if attachments: if attachments:
bank = await self._ensure_memory_bank(session) bank_id = await self._ensure_memory_bank(session_id)
bank_ids.append(bank.bank_id) bank_ids.append(bank_id)
documents = [ documents = [
MemoryBankDocument( MemoryBankDocument(
@ -701,9 +728,11 @@ class ChatAgent(ShieldRunnerMixin):
) )
for a in attachments for a in attachments
] ]
await self.memory_api.insert_documents(bank.bank_id, documents) await self.memory_api.insert_documents(bank_id, documents)
elif session.memory_bank: else:
bank_ids.append(session.memory_bank.bank_id) session_info = await self._get_session_info(session_id)
if session_info.memory_bank_id:
bank_ids.append(session_info.memory_bank_id)
if not bank_ids: if not bank_ids:
# this can happen if the per-session memory bank is not yet populated # this can happen if the per-session memory bank is not yet populated

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import logging import logging
import uuid import uuid
from typing import AsyncGenerator from typing import AsyncGenerator
@ -14,7 +14,7 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from .agent_instance import ChatAgent from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
@ -35,6 +35,7 @@ class MetaReferenceAgentsImpl(Agents):
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.safety_api = safety_api self.safety_api = safety_api
self.in_memory_store = InmemoryKVStoreImpl()
async def initialize(self) -> None: async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store) self.persistence_store = await kvstore_impl(self.config.persistence_store)
@ -47,7 +48,7 @@ class MetaReferenceAgentsImpl(Agents):
await self.persistence_store.set( await self.persistence_store.set(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
value=json.dumps(agent_config.dict()), value=agent_config.json(),
) )
return AgentCreateResponse( return AgentCreateResponse(
agent_id=agent_id, agent_id=agent_id,
@ -80,7 +81,11 @@ class MetaReferenceAgentsImpl(Agents):
inference_api=self.inference_api, inference_api=self.inference_api,
safety_api=self.safety_api, safety_api=self.safety_api,
memory_api=self.memory_api, memory_api=self.memory_api,
persistence_store=self.persistence_store, persistence_store=(
self.persistence_store
if agent_config.enable_session_persistence
else self.in_memory_store
),
) )
async def create_agent_session( async def create_agent_session(

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore import KVStoreConfig

View file

@ -184,6 +184,7 @@ async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
# ), # ),
], ],
tool_choice=ToolChoice.auto, tool_choice=ToolChoice.auto,
enable_session_persistence=False,
input_shields=[], input_shields=[],
output_shields=[], output_shields=[],
) )

View file

@ -42,7 +42,6 @@ class FaissIndex(EmbeddingIndex):
indexlen = len(self.id_by_index) indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk self.chunk_by_index[indexlen + i] = chunk
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
self.id_by_index[indexlen + i] = chunk.document_id self.id_by_index[indexlen + i] = chunk.document_id
self.index.add(np.array(embeddings).astype(np.float32)) self.index.add(np.array(embeddings).astype(np.float32))

View file

@ -7,6 +7,7 @@
from typing import List from typing import List
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.providers.utils.kvstore import kvstore_dependencies
def available_providers() -> List[ProviderSpec]: def available_providers() -> List[ProviderSpec]:
@ -19,11 +20,10 @@ def available_providers() -> List[ProviderSpec]:
"pillow", "pillow",
"pandas", "pandas",
"scikit-learn", "scikit-learn",
"torch", ]
"transformers", + kvstore_dependencies(),
],
module="llama_stack.providers.impls.meta_reference.agents", module="llama_stack.providers.impls.meta_reference.agents",
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig", config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceAgentsImplConfig",
api_dependencies=[ api_dependencies=[
Api.inference, Api.inference,
Api.safety, Api.safety,

View file

@ -7,22 +7,15 @@
from datetime import datetime from datetime import datetime
from typing import List, Optional, Protocol from typing import List, Optional, Protocol
from pydantic import BaseModel
class KVStoreValue(BaseModel):
key: str
value: str
expiration: Optional[datetime] = None
class KVStore(Protocol): class KVStore(Protocol):
# TODO: make the value type bytes instead of str
async def set( async def set(
self, key: str, value: str, expiration: Optional[datetime] = None self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ... ) -> None: ...
async def get(self, key: str) -> Optional[KVStoreValue]: ... async def get(self, key: str) -> Optional[str]: ...
async def delete(self, key: str) -> None: ... async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: ... async def range(self, start_key: str, end_key: str) -> List[str]: ...

View file

@ -7,9 +7,11 @@
from enum import Enum from enum import Enum
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class KVStoreType(Enum): class KVStoreType(Enum):
redis = "redis" redis = "redis"
@ -24,20 +26,21 @@ class CommonConfig(BaseModel):
) )
class RedisKVStoreImplConfig(CommonConfig): class RedisKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
host: str = "localhost" host: str = "localhost"
port: int = 6379 port: int = 6379
class SqliteKVStoreImplConfig(CommonConfig): class SqliteKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
db_path: str = Field( db_path: str = Field(
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
description="File path for the sqlite database", description="File path for the sqlite database",
) )
class PostgresKVStoreImplConfig(CommonConfig): class PostgresKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
host: str = "localhost" host: str = "localhost"
port: int = 5432 port: int = 5432
@ -47,6 +50,6 @@ class PostgresKVStoreImplConfig(CommonConfig):
KVStoreConfig = Annotated[ KVStoreConfig = Annotated[
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig], Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
Field(discriminator="type", default=KVStoreType.sqlite.value), Field(discriminator="type", default=KVStoreType.sqlite.value),
] ]

View file

@ -12,16 +12,37 @@ def kvstore_dependencies():
return ["aiosqlite", "psycopg2-binary", "redis"] return ["aiosqlite", "psycopg2-binary", "redis"]
class InmemoryKVStoreImpl(KVStore):
def __init__(self):
self._store = {}
async def initialize(self) -> None:
pass
async def get(self, key: str) -> Optional[str]:
return self._store.get(key)
async def set(self, key: str, value: str) -> None:
self._store[key] = value
async def range(self, start_key: str, end_key: str) -> List[str]:
return [
self._store[key]
for key in self._store.keys()
if key >= start_key and key < end_key
]
async def kvstore_impl(config: KVStoreConfig) -> KVStore: async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis: if config.type == KVStoreType.redis.value:
from .redis import RedisKVStoreImpl from .redis import RedisKVStoreImpl
impl = RedisKVStoreImpl(config) impl = RedisKVStoreImpl(config)
elif config.type == KVStoreType.sqlite: elif config.type == KVStoreType.sqlite.value:
from .sqlite import SqliteKVStoreImpl from .sqlite import SqliteKVStoreImpl
impl = SqliteKVStoreImpl(config) impl = SqliteKVStoreImpl(config)
elif config.type == KVStoreType.pgvector: elif config.type == KVStoreType.postgres.value:
raise NotImplementedError() raise NotImplementedError()
else: else:
raise ValueError(f"Unknown kvstore type {config.type}") raise ValueError(f"Unknown kvstore type {config.type}")

View file

@ -4,17 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime, timedelta from datetime import datetime
from typing import List, Optional from typing import List, Optional
from redis.asyncio import Redis from redis.asyncio import Redis
from ..api import * # noqa: F403 from ..api import * # noqa: F403
from ..config import RedisKVStoreImplConfig from ..config import RedisKVStoreConfig
class RedisKVStoreImpl(KVStore): class RedisKVStoreImpl(KVStore):
def __init__(self, config: RedisKVStoreImplConfig): def __init__(self, config: RedisKVStoreConfig):
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:
@ -33,20 +33,19 @@ class RedisKVStoreImpl(KVStore):
if expiration: if expiration:
await self.redis.expireat(key, expiration) await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[KVStoreValue]: async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key) key = self._namespaced_key(key)
value = await self.redis.get(key) value = await self.redis.get(key)
if value is None: if value is None:
return None return None
ttl = await self.redis.ttl(key) ttl = await self.redis.ttl(key)
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None return value
return KVStoreValue(key=key, value=value, expiration=expiration)
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:
key = self._namespaced_key(key) key = self._namespaced_key(key)
await self.redis.delete(key) await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
from datetime import datetime from datetime import datetime
from typing import List, Optional from typing import List, Optional
@ -16,9 +18,10 @@ from ..config import SqliteKVStoreConfig
class SqliteKVStoreImpl(KVStore): class SqliteKVStoreImpl(KVStore):
def __init__(self, config: SqliteKVStoreConfig): def __init__(self, config: SqliteKVStoreConfig):
self.db_path = config.db_path self.db_path = config.db_path
self.table_name = config.table_name self.table_name = "kvstore"
async def initialize(self): async def initialize(self):
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
await db.execute( await db.execute(
f""" f"""
@ -41,7 +44,7 @@ class SqliteKVStoreImpl(KVStore):
) )
await db.commit() await db.commit()
async def get(self, key: str) -> Optional[KVStoreValue]: async def get(self, key: str) -> Optional[str]:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
async with db.execute( async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
@ -50,14 +53,14 @@ class SqliteKVStoreImpl(KVStore):
if row is None: if row is None:
return None return None
value, expiration = row value, expiration = row
return KVStoreValue(key=key, value=value, expiration=expiration) return value
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit() await db.commit()
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: async def range(self, start_key: str, end_key: str) -> List[str]:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
async with db.execute( async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@ -65,8 +68,6 @@ class SqliteKVStoreImpl(KVStore):
) as cursor: ) as cursor:
result = [] result = []
async for row in cursor: async for row in cursor:
key, value, expiration = row _, value, _ = row
result.append( result.append(value)
KVStoreValue(key=key, value=value, expiration=expiration)
)
return result return result