mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Refactor KVStore range function
This commit is contained in:
parent
7abab7604b
commit
5ae082df7f
5 changed files with 26 additions and 15 deletions
|
@ -68,10 +68,7 @@ class AgentPersistence:
|
|||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
values = await self.kvstore.range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
values = await self.kvstore.get_match(key_to_match=f"session:{self.agent_id}:{session_id}:")
|
||||
turns = []
|
||||
for value in values:
|
||||
try:
|
||||
|
|
|
@ -18,4 +18,4 @@ class KVStore(Protocol):
|
|||
|
||||
async def delete(self, key: str) -> None: ...
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]: ...
|
||||
async def get_match(self, key_to_match: str) -> List[str]: ...
|
||||
|
|
|
@ -25,11 +25,11 @@ class InmemoryKVStoreImpl(KVStore):
|
|||
async def set(self, key: str, value: str) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def get_match(self, key_to_match: str) -> List[str]:
|
||||
return [
|
||||
self._store[key]
|
||||
self._store.get[key]
|
||||
for key in self._store.keys()
|
||||
if key >= start_key and key < end_key
|
||||
if key.startswith(key_to_match)
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -45,8 +45,22 @@ class RedisKVStoreImpl(KVStore):
|
|||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
async def get_match(self, key_to_match: str) -> List[str]:
|
||||
key_to_match = self._namespaced_key(key_to_match)
|
||||
|
||||
return await self.redis.zrangebylex(start_key, end_key)
|
||||
cursor = 0
|
||||
keys = set()
|
||||
|
||||
while True:
|
||||
cursor, keys_chunk = await self.redis.scan(cursor=cursor, match=f"{key_to_match}*", count=100)
|
||||
keys.update(key.decode() for key in keys_chunk)
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
if not keys:
|
||||
return []
|
||||
|
||||
values = await self.redis.mget(*keys)
|
||||
values = [value.decode() for value in values if value is not None]
|
||||
|
||||
return sorted(values)
|
||||
|
|
|
@ -60,11 +60,11 @@ class SqliteKVStoreImpl(KVStore):
|
|||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async def get_match(self, key_to_match: str) -> List[str]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
(start_key, end_key),
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key LIKE ?",
|
||||
(f"{key_to_match}%",),
|
||||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue