Refactor KVStore range function

This commit is contained in:
Minutis 2024-10-06 13:36:57 +03:00
parent 7abab7604b
commit 5ae082df7f
5 changed files with 26 additions and 15 deletions

View file

@ -68,10 +68,7 @@ class AgentPersistence:
) )
async def get_session_turns(self, session_id: str) -> List[Turn]: async def get_session_turns(self, session_id: str) -> List[Turn]:
values = await self.kvstore.range( values = await self.kvstore.get_match(key_to_match=f"session:{self.agent_id}:{session_id}:")
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = [] turns = []
for value in values: for value in values:
try: try:

View file

@ -18,4 +18,4 @@ class KVStore(Protocol):
async def delete(self, key: str) -> None: ... async def delete(self, key: str) -> None: ...
async def range(self, start_key: str, end_key: str) -> List[str]: ... async def get_match(self, key_to_match: str) -> List[str]: ...

View file

@ -25,11 +25,11 @@ class InmemoryKVStoreImpl(KVStore):
async def set(self, key: str, value: str) -> None: async def set(self, key: str, value: str) -> None:
self._store[key] = value self._store[key] = value
async def range(self, start_key: str, end_key: str) -> List[str]: async def get_match(self, key_to_match: str) -> List[str]:
return [ return [
self._store[key] self._store.get[key]
for key in self._store.keys() for key in self._store.keys()
if key >= start_key and key < end_key if key.startswith(key_to_match)
] ]

View file

@ -45,8 +45,22 @@ class RedisKVStoreImpl(KVStore):
key = self._namespaced_key(key) key = self._namespaced_key(key)
await self.redis.delete(key) await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[str]: async def get_match(self, key_to_match: str) -> List[str]:
start_key = self._namespaced_key(start_key) key_to_match = self._namespaced_key(key_to_match)
end_key = self._namespaced_key(end_key)
return await self.redis.zrangebylex(start_key, end_key) cursor = 0
keys = set()
while True:
cursor, keys_chunk = await self.redis.scan(cursor=cursor, match=f"{key_to_match}*", count=100)
keys.update(key.decode() for key in keys_chunk)
if cursor == 0:
break
if not keys:
return []
values = await self.redis.mget(*keys)
values = [value.decode() for value in values if value is not None]
return sorted(values)

View file

@ -60,11 +60,11 @@ class SqliteKVStoreImpl(KVStore):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit() await db.commit()
async def range(self, start_key: str, end_key: str) -> List[str]: async def get_match(self, key_to_match: str) -> List[str]:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
async with db.execute( async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", f"SELECT key, value, expiration FROM {self.table_name} WHERE key LIKE ?",
(start_key, end_key), (f"{key_to_match}%",),
) as cursor: ) as cursor:
result = [] result = []
async for row in cursor: async for row in cursor: