mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Agent persistence works
This commit is contained in:
parent
4eb0f30891
commit
59f1fe5af8
17 changed files with 136 additions and 90 deletions
|
@ -102,6 +102,7 @@ async def _run_agent(api, tool_definitions, user_prompts, attachments=None):
|
|||
tools=tool_definitions,
|
||||
tool_choice=ToolChoice.auto,
|
||||
tool_prompt_format=ToolPromptFormat.function_tag,
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
create_response = await api.create_agent(agent_config)
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import Any, Dict, List, Optional, Protocol, Union
|
|||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -168,14 +168,12 @@ Fully-qualified name of the module to import. The module is expected to have:
|
|||
|
||||
|
||||
class RemoteProviderConfig(BaseModel):
|
||||
url: str = Field(..., description="The URL for the provider")
|
||||
host: str = "localhost"
|
||||
port: int
|
||||
|
||||
@validator("url")
|
||||
@classmethod
|
||||
def validate_url(cls, url: str) -> str:
|
||||
if not url.startswith("http"):
|
||||
raise ValueError(f"URL must start with http: {url}")
|
||||
return url.rstrip("/")
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return f"http://{self.host}:{self.port}"
|
||||
|
||||
|
||||
def remote_provider_id(adapter_id: str) -> str:
|
||||
|
|
|
@ -15,3 +15,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
|
|||
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
|
||||
|
||||
BUILDS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "builds"
|
||||
|
||||
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
|
||||
|
|
|
@ -31,9 +31,6 @@ class ChromaIndex(EmbeddingIndex):
|
|||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
print(f"Adding chunk #{i} tokens={chunk.token_count}")
|
||||
|
||||
await self.collection.add(
|
||||
documents=[chunk.json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
|
|
|
@ -80,7 +80,6 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
values = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
print(f"Adding chunk #{i} tokens={chunk.token_count}")
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.document_id}:chunk-{i}",
|
||||
|
|
|
@ -8,18 +8,14 @@ from typing import Dict
|
|||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceImplConfig
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
config: MetaReferenceAgentsImplConfig, deps: Dict[Api, ProviderSpec]
|
||||
):
|
||||
from .agents import MetaReferenceAgentsImpl
|
||||
|
||||
assert isinstance(
|
||||
config, MetaReferenceImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = MetaReferenceAgentsImpl(
|
||||
config,
|
||||
deps[Api.inference],
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import shutil
|
||||
|
@ -25,7 +26,7 @@ from llama_stack.apis.inference import * # noqa: F403
|
|||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.api import KVStore
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .rag.context_retriever import generate_rag_query
|
||||
from .safety import SafetyException, ShieldRunnerMixin
|
||||
|
@ -46,6 +47,13 @@ def make_random_string(length: int = 8):
|
|||
)
|
||||
|
||||
|
||||
class AgentSessionInfo(BaseModel):
|
||||
session_id: str
|
||||
session_name: str
|
||||
memory_bank_id: Optional[str] = None
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class ChatAgent(ShieldRunnerMixin):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -136,21 +144,41 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
session = dict(
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=str(datetime.now()),
|
||||
started_at=datetime.now(),
|
||||
)
|
||||
await self.persistence_store.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=json.dumps(session),
|
||||
value=session_info.json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
async def _get_session_info(self, session_id: str) -> Optional[AgentSessionInfo]:
|
||||
value = await self.persistence_store.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
return AgentSessionInfo(**json.loads(value))
|
||||
|
||||
async def _add_memory_bank_to_session(self, session_id: str, bank_id: str):
|
||||
session_info = await self._get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
session_info.memory_bank_id = bank_id
|
||||
await self.persistence_store.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
)
|
||||
|
||||
async def _add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
await self.persistence_store.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=json.dumps(turn.dict()),
|
||||
value=turn.json(),
|
||||
)
|
||||
|
||||
async def _get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
|
@ -169,15 +197,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
return turns
|
||||
|
||||
async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]:
|
||||
value = await self.persistence_store.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
|
||||
return json.loads(value)
|
||||
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
|
@ -209,7 +228,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session=session,
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
|
@ -261,7 +280,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def run(
|
||||
self,
|
||||
session: Session,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
|
@ -282,7 +301,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield res
|
||||
|
||||
async for res in self._run(
|
||||
session, turn_id, input_messages, attachments, sampling_params, stream
|
||||
session_id, turn_id, input_messages, attachments, sampling_params, stream
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
@ -364,7 +383,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
async def _run(
|
||||
self,
|
||||
session: Session,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
input_messages: List[Message],
|
||||
attachments: List[Attachment],
|
||||
|
@ -389,7 +408,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session, input_messages, attachments
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
|
@ -645,17 +664,25 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
n_iter += 1
|
||||
|
||||
async def _ensure_memory_bank(self, session: Session) -> MemoryBank:
|
||||
if session.memory_bank is None:
|
||||
session.memory_bank = await self.memory_api.create_memory_bank(
|
||||
name=f"memory_bank_{session.session_id}",
|
||||
async def _ensure_memory_bank(self, session_id: str) -> str:
|
||||
session_info = await self._get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
if session_info.memory_bank_id is None:
|
||||
memory_bank = await self.memory_api.create_memory_bank(
|
||||
name=f"memory_bank_{session_id}",
|
||||
config=VectorMemoryBankConfig(
|
||||
embedding_model="sentence-transformer/all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
),
|
||||
)
|
||||
bank_id = memory_bank.bank_id
|
||||
await self._add_memory_bank_to_session(session_id, bank_id)
|
||||
else:
|
||||
bank_id = session_info.memory_bank_id
|
||||
|
||||
return session.memory_bank
|
||||
return bank_id
|
||||
|
||||
async def _should_retrieve_context(
|
||||
self, messages: List[Message], attachments: List[Attachment]
|
||||
|
@ -680,7 +707,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return None
|
||||
|
||||
async def _retrieve_context(
|
||||
self, session: Session, messages: List[Message], attachments: List[Attachment]
|
||||
self, session_id: str, messages: List[Message], attachments: List[Attachment]
|
||||
) -> Tuple[List[str], List[int]]: # (rag_context, bank_ids)
|
||||
bank_ids = []
|
||||
|
||||
|
@ -689,8 +716,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
bank_ids.extend(c.bank_id for c in memory.memory_bank_configs)
|
||||
|
||||
if attachments:
|
||||
bank = await self._ensure_memory_bank(session)
|
||||
bank_ids.append(bank.bank_id)
|
||||
bank_id = await self._ensure_memory_bank(session_id)
|
||||
bank_ids.append(bank_id)
|
||||
|
||||
documents = [
|
||||
MemoryBankDocument(
|
||||
|
@ -701,9 +728,11 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
for a in attachments
|
||||
]
|
||||
await self.memory_api.insert_documents(bank.bank_id, documents)
|
||||
elif session.memory_bank:
|
||||
bank_ids.append(session.memory_bank.bank_id)
|
||||
await self.memory_api.insert_documents(bank_id, documents)
|
||||
else:
|
||||
session_info = await self._get_session_info(session_id)
|
||||
if session_info.memory_bank_id:
|
||||
bank_ids.append(session_info.memory_bank_id)
|
||||
|
||||
if not bank_ids:
|
||||
# this can happen if the per-session memory bank is not yet populated
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
@ -14,7 +14,7 @@ from llama_stack.apis.memory import Memory
|
|||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.agents import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
|
@ -35,6 +35,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.safety_api = safety_api
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
@ -47,7 +48,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=json.dumps(agent_config.dict()),
|
||||
value=agent_config.json(),
|
||||
)
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
|
@ -80,7 +81,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
memory_api=self.memory_api,
|
||||
persistence_store=self.persistence_store,
|
||||
persistence_store=(
|
||||
self.persistence_store
|
||||
if agent_config.enable_session_persistence
|
||||
else self.in_memory_store
|
||||
),
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||
|
||||
|
||||
|
|
|
@ -184,6 +184,7 @@ async def chat_agent(mock_inference_api, mock_safety_api, mock_memory_api):
|
|||
# ),
|
||||
],
|
||||
tool_choice=ToolChoice.auto,
|
||||
enable_session_persistence=False,
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
)
|
||||
|
|
|
@ -42,7 +42,6 @@ class FaissIndex(EmbeddingIndex):
|
|||
indexlen = len(self.id_by_index)
|
||||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
logger.info(f"Adding chunk #{indexlen + i} tokens={chunk.token_count}")
|
||||
self.id_by_index[indexlen + i] = chunk.document_id
|
||||
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.providers.utils.kvstore import kvstore_dependencies
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
@ -19,11 +20,10 @@ def available_providers() -> List[ProviderSpec]:
|
|||
"pillow",
|
||||
"pandas",
|
||||
"scikit-learn",
|
||||
"torch",
|
||||
"transformers",
|
||||
],
|
||||
]
|
||||
+ kvstore_dependencies(),
|
||||
module="llama_stack.providers.impls.meta_reference.agents",
|
||||
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceImplConfig",
|
||||
config_class="llama_stack.providers.impls.meta_reference.agents.MetaReferenceAgentsImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.safety,
|
||||
|
|
|
@ -7,22 +7,15 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KVStoreValue(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
expiration: Optional[datetime] = None
|
||||
|
||||
|
||||
class KVStore(Protocol):
|
||||
# TODO: make the value type bytes instead of str
|
||||
async def set(
|
||||
self, key: str, value: str, expiration: Optional[datetime] = None
|
||||
) -> None: ...
|
||||
|
||||
async def get(self, key: str) -> Optional[KVStoreValue]: ...
|
||||
async def get(self, key: str) -> Optional[str]: ...
|
||||
|
||||
async def delete(self, key: str) -> None: ...
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: ...
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]: ...
|
||||
|
|
|
@ -7,9 +7,11 @@
|
|||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
||||
class KVStoreType(Enum):
|
||||
redis = "redis"
|
||||
|
@ -24,20 +26,21 @@ class CommonConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class RedisKVStoreImplConfig(CommonConfig):
|
||||
class RedisKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
|
||||
|
||||
class SqliteKVStoreImplConfig(CommonConfig):
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
||||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
|
||||
class PostgresKVStoreImplConfig(CommonConfig):
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
|
@ -47,6 +50,6 @@ class PostgresKVStoreImplConfig(CommonConfig):
|
|||
|
||||
|
||||
KVStoreConfig = Annotated[
|
||||
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig],
|
||||
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
|
||||
Field(discriminator="type", default=KVStoreType.sqlite.value),
|
||||
]
|
||||
|
|
|
@ -12,16 +12,37 @@ def kvstore_dependencies():
|
|||
return ["aiosqlite", "psycopg2-binary", "redis"]
|
||||
|
||||
|
||||
class InmemoryKVStoreImpl(KVStore):
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
return self._store.get(key)
|
||||
|
||||
async def set(self, key: str, value: str) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
return [
|
||||
self._store[key]
|
||||
for key in self._store.keys()
|
||||
if key >= start_key and key < end_key
|
||||
]
|
||||
|
||||
|
||||
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
||||
if config.type == KVStoreType.redis:
|
||||
if config.type == KVStoreType.redis.value:
|
||||
from .redis import RedisKVStoreImpl
|
||||
|
||||
impl = RedisKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.sqlite:
|
||||
elif config.type == KVStoreType.sqlite.value:
|
||||
from .sqlite import SqliteKVStoreImpl
|
||||
|
||||
impl = SqliteKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.pgvector:
|
||||
elif config.type == KVStoreType.postgres.value:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise ValueError(f"Unknown kvstore type {config.type}")
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from ..api import * # noqa: F403
|
||||
from ..config import RedisKVStoreImplConfig
|
||||
from ..config import RedisKVStoreConfig
|
||||
|
||||
|
||||
class RedisKVStoreImpl(KVStore):
|
||||
def __init__(self, config: RedisKVStoreImplConfig):
|
||||
def __init__(self, config: RedisKVStoreConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -33,20 +33,19 @@ class RedisKVStoreImpl(KVStore):
|
|||
if expiration:
|
||||
await self.redis.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> Optional[KVStoreValue]:
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
key = self._namespaced_key(key)
|
||||
value = await self.redis.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
ttl = await self.redis.ttl(key)
|
||||
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
|
||||
return KVStoreValue(key=key, value=value, expiration=expiration)
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]:
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
|
@ -16,9 +18,10 @@ from ..config import SqliteKVStoreConfig
|
|||
class SqliteKVStoreImpl(KVStore):
|
||||
def __init__(self, config: SqliteKVStoreConfig):
|
||||
self.db_path = config.db_path
|
||||
self.table_name = config.table_name
|
||||
self.table_name = "kvstore"
|
||||
|
||||
async def initialize(self):
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"""
|
||||
|
@ -41,7 +44,7 @@ class SqliteKVStoreImpl(KVStore):
|
|||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> Optional[KVStoreValue]:
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
|
@ -50,14 +53,14 @@ class SqliteKVStoreImpl(KVStore):
|
|||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
return KVStoreValue(key=key, value=value, expiration=expiration)
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]:
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
|
@ -65,8 +68,6 @@ class SqliteKVStoreImpl(KVStore):
|
|||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
key, value, expiration = row
|
||||
result.append(
|
||||
KVStoreValue(key=key, value=value, expiration=expiration)
|
||||
)
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue