initial cut at using kvstores for agent persistence

This commit is contained in:
Ashwin Bharambe 2024-09-21 21:16:26 -07:00
parent 61974e337f
commit 4eb0f30891
10 changed files with 153 additions and 120 deletions

View file

@ -3,3 +3,5 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .kvstore import * # noqa: F401, F403

View file

@ -5,20 +5,20 @@
# the root directory of this source tree.
from datetime import datetime
from typing import Any, List, Optional, Protocol
from typing import List, Optional, Protocol
from pydantic import BaseModel
class KVStoreValue(BaseModel):
key: str
value: Any
value: str
expiration: Optional[datetime] = None
class KVStore(Protocol):
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ...
async def get(self, key: str) -> Optional[KVStoreValue]: ...

View file

@ -14,7 +14,7 @@ from typing_extensions import Annotated
class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
pgvector = "pgvector"
postgres = "postgres"
class CommonConfig(BaseModel):
@ -37,8 +37,8 @@ class SqliteKVStoreImplConfig(CommonConfig):
)
class PGVectorKVStoreImplConfig(CommonConfig):
type: Literal[KVStoreType.pgvector.value] = KVStoreType.pgvector.value
class PostgresKVStoreImplConfig(CommonConfig):
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
@ -47,6 +47,6 @@ class PGVectorKVStoreImplConfig(CommonConfig):
KVStoreConfig = Annotated[
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PGVectorKVStoreImplConfig],
Field(discriminator="type"),
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig],
Field(discriminator="type", default=KVStoreType.sqlite.value),
]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from datetime import datetime, timedelta
from typing import Any, List, Optional
from typing import List, Optional
from redis.asyncio import Redis
@ -26,7 +26,7 @@ class RedisKVStoreImpl(KVStore):
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
await self.redis.set(key, value)
@ -50,11 +50,4 @@ class RedisKVStoreImpl(KVStore):
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
keys = await self.redis.keys(f"{start_key}*")
result = []
for key in keys:
if key <= end_key:
value = await self.get(key)
if value:
result.append(value)
return result
return await self.redis.zrangebylex(start_key, end_key)

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import Any, List, Optional
from typing import List, Optional
import aiosqlite
@ -33,12 +32,12 @@ class SqliteKVStoreImpl(KVStore):
await db.commit()
async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, json.dumps(value), expiration),
(key, value, expiration),
)
await db.commit()
@ -51,9 +50,7 @@ class SqliteKVStoreImpl(KVStore):
if row is None:
return None
value, expiration = row
return KVStoreValue(
key=key, value=json.loads(value), expiration=expiration
)
return KVStoreValue(key=key, value=value, expiration=expiration)
async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db:
@ -70,8 +67,6 @@ class SqliteKVStoreImpl(KVStore):
async for row in cursor:
key, value, expiration = row
result.append(
KVStoreValue(
key=key, value=json.loads(value), expiration=expiration
)
KVStoreValue(key=key, value=value, expiration=expiration)
)
return result