mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
initial cut at using kvstores for agent persistence
This commit is contained in:
parent
61974e337f
commit
4eb0f30891
10 changed files with 153 additions and 120 deletions
|
@ -276,6 +276,8 @@ class AgentConfigCommon(BaseModel):
|
||||||
default=ToolPromptFormat.json
|
default=ToolPromptFormat.json
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_infer_iters: int = 10
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class AgentConfig(AgentConfigCommon):
|
class AgentConfig(AgentConfigCommon):
|
||||||
|
|
|
@ -83,10 +83,12 @@ def prompt_for_discriminated_union(
|
||||||
if isinstance(typ, FieldInfo):
|
if isinstance(typ, FieldInfo):
|
||||||
inner_type = typ.annotation
|
inner_type = typ.annotation
|
||||||
discriminator = typ.discriminator
|
discriminator = typ.discriminator
|
||||||
|
default_value = typ.default
|
||||||
else:
|
else:
|
||||||
args = get_args(typ)
|
args = get_args(typ)
|
||||||
inner_type = args[0]
|
inner_type = args[0]
|
||||||
discriminator = args[1].discriminator
|
discriminator = args[1].discriminator
|
||||||
|
default_value = args[1].default
|
||||||
|
|
||||||
union_types = get_args(inner_type)
|
union_types = get_args(inner_type)
|
||||||
# Find the discriminator field in each union type
|
# Find the discriminator field in each union type
|
||||||
|
@ -99,9 +101,14 @@ def prompt_for_discriminated_union(
|
||||||
type_map[value] = t
|
type_map[value] = t
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
discriminator_value = input(
|
prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
|
||||||
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): "
|
if default_value is not None:
|
||||||
)
|
prompt += f" (default: {default_value})"
|
||||||
|
|
||||||
|
discriminator_value = input(f"{prompt}: ")
|
||||||
|
if discriminator_value == "" and default_value is not None:
|
||||||
|
discriminator_value = default_value
|
||||||
|
|
||||||
if discriminator_value in type_map:
|
if discriminator_value in type_map:
|
||||||
chosen_type = type_map[discriminator_value]
|
chosen_type = type_map[discriminator_value]
|
||||||
print(f"\nConfiguring {chosen_type.__name__}:")
|
print(f"\nConfiguring {chosen_type.__name__}:")
|
||||||
|
|
|
@ -25,10 +25,19 @@ from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.api import KVStore
|
||||||
|
|
||||||
from .rag.context_retriever import generate_rag_query
|
from .rag.context_retriever import generate_rag_query
|
||||||
from .safety import SafetyException, ShieldRunnerMixin
|
from .safety import SafetyException, ShieldRunnerMixin
|
||||||
from .tools.base import BaseTool
|
from .tools.base import BaseTool
|
||||||
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool
|
from .tools.builtin import (
|
||||||
|
CodeInterpreterTool,
|
||||||
|
interpret_content_as_attachment,
|
||||||
|
PhotogenTool,
|
||||||
|
SearchTool,
|
||||||
|
WolframAlphaTool,
|
||||||
|
)
|
||||||
|
from .tools.safety import with_safety
|
||||||
|
|
||||||
|
|
||||||
def make_random_string(length: int = 8):
|
def make_random_string(length: int = 8):
|
||||||
|
@ -40,23 +49,44 @@ def make_random_string(length: int = 8):
|
||||||
class ChatAgent(ShieldRunnerMixin):
|
class ChatAgent(ShieldRunnerMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
agent_id: str,
|
||||||
agent_config: AgentConfig,
|
agent_config: AgentConfig,
|
||||||
inference_api: Inference,
|
inference_api: Inference,
|
||||||
memory_api: Memory,
|
memory_api: Memory,
|
||||||
safety_api: Safety,
|
safety_api: Safety,
|
||||||
builtin_tools: List[SingleMessageBuiltinTool],
|
persistence_store: KVStore,
|
||||||
max_infer_iters: int = 10,
|
|
||||||
):
|
):
|
||||||
|
self.agent_id = agent_id
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.memory_api = memory_api
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
|
self.persistence_store = persistence_store
|
||||||
self.max_infer_iters = max_infer_iters
|
|
||||||
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
|
||||||
|
|
||||||
self.tempdir = tempfile.mkdtemp()
|
self.tempdir = tempfile.mkdtemp()
|
||||||
self.sessions = {}
|
|
||||||
|
builtin_tools = []
|
||||||
|
for tool_defn in agent_config.tools:
|
||||||
|
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
||||||
|
tool = WolframAlphaTool(tool_defn.api_key)
|
||||||
|
elif isinstance(tool_defn, SearchToolDefinition):
|
||||||
|
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
|
||||||
|
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
||||||
|
tool = CodeInterpreterTool()
|
||||||
|
elif isinstance(tool_defn, PhotogenToolDefinition):
|
||||||
|
tool = PhotogenTool(dump_dir=self.tempdir)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
builtin_tools.append(
|
||||||
|
with_safety(
|
||||||
|
tool,
|
||||||
|
safety_api,
|
||||||
|
tool_defn.input_shields,
|
||||||
|
tool_defn.output_shields,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.tools_dict = {t.get_name(): t for t in builtin_tools}
|
||||||
|
|
||||||
ShieldRunnerMixin.__init__(
|
ShieldRunnerMixin.__init__(
|
||||||
self,
|
self,
|
||||||
|
@ -80,7 +110,6 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
msg.context = None
|
msg.context = None
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
# messages.extend(turn.input_messages)
|
|
||||||
for step in turn.steps:
|
for step in turn.steps:
|
||||||
if step.step_type == StepType.inference.value:
|
if step.step_type == StepType.inference.value:
|
||||||
messages.append(step.model_response)
|
messages.append(step.model_response)
|
||||||
|
@ -105,31 +134,64 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# print_dialog(messages)
|
# print_dialog(messages)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def create_session(self, name: str) -> Session:
|
async def create_session(self, name: str) -> str:
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
session = Session(
|
session = dict(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
session_name=name,
|
session_name=name,
|
||||||
turns=[],
|
started_at=str(datetime.now()),
|
||||||
started_at=datetime.now(),
|
|
||||||
)
|
)
|
||||||
self.sessions[session_id] = session
|
await self.persistence_store.set(
|
||||||
return session
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
|
value=json.dumps(session),
|
||||||
|
)
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
async def _add_turn_to_session(self, session_id: str, turn: Turn):
|
||||||
|
await self.persistence_store.set(
|
||||||
|
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||||
|
value=json.dumps(turn.dict()),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_session_turns(self, session_id: str) -> List[Turn]:
|
||||||
|
values = await self.persistence_store.range(
|
||||||
|
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||||
|
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||||
|
)
|
||||||
|
turns = []
|
||||||
|
for value in values:
|
||||||
|
try:
|
||||||
|
turn = Turn(**json.loads(value))
|
||||||
|
turns.append(turn)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing turn: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return turns
|
||||||
|
|
||||||
|
async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]:
|
||||||
|
value = await self.persistence_store.get(
|
||||||
|
key=f"session:{self.agent_id}:{session_id}",
|
||||||
|
)
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return json.loads(value)
|
||||||
|
|
||||||
async def create_and_execute_turn(
|
async def create_and_execute_turn(
|
||||||
self, request: AgentTurnCreateRequest
|
self, request: AgentTurnCreateRequest
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
assert (
|
session_info = await self._get_session_info(request.session_id)
|
||||||
request.session_id in self.sessions
|
if session_info is None:
|
||||||
), f"Session {request.session_id} not found"
|
raise ValueError(f"Session {request.session_id} not found")
|
||||||
|
|
||||||
session = self.sessions[request.session_id]
|
turns = await self._get_session_turns(request.session_id)
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
if len(session.turns) == 0 and self.agent_config.instructions != "":
|
if len(turns) == 0 and self.agent_config.instructions != "":
|
||||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||||
|
|
||||||
for i, turn in enumerate(session.turns):
|
for i, turn in enumerate(turns):
|
||||||
messages.extend(self.turn_to_messages(turn))
|
messages.extend(self.turn_to_messages(turn))
|
||||||
|
|
||||||
messages.extend(request.messages)
|
messages.extend(request.messages)
|
||||||
|
@ -186,7 +248,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
completed_at=datetime.now(),
|
completed_at=datetime.now(),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
)
|
)
|
||||||
session.turns.append(turn)
|
await self._add_turn_to_session(request.session_id, turn)
|
||||||
|
|
||||||
chunk = AgentTurnResponseStreamChunk(
|
chunk = AgentTurnResponseStreamChunk(
|
||||||
event=AgentTurnResponseEvent(
|
event=AgentTurnResponseEvent(
|
||||||
|
@ -455,7 +517,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if n_iter >= self.max_infer_iters:
|
if n_iter >= self.agent_config.max_infer_iters:
|
||||||
cprint("Done with MAX iterations, exiting.")
|
cprint("Done with MAX iterations, exiting.")
|
||||||
yield message
|
yield message
|
||||||
break
|
break
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import tempfile
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
@ -15,39 +14,15 @@ from llama_stack.apis.memory import Memory
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.agents import * # noqa: F403
|
from llama_stack.apis.agents import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
from .agent_instance import ChatAgent
|
from .agent_instance import ChatAgent
|
||||||
from .config import MetaReferenceAgentsImplConfig
|
from .config import MetaReferenceAgentsImplConfig
|
||||||
from .tools.builtin import (
|
|
||||||
CodeInterpreterTool,
|
|
||||||
PhotogenTool,
|
|
||||||
SearchTool,
|
|
||||||
WolframAlphaTool,
|
|
||||||
)
|
|
||||||
from .tools.safety import with_safety
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
AGENT_INSTANCES_BY_ID = {}
|
|
||||||
|
|
||||||
|
|
||||||
class KVStore(Protocol):
|
|
||||||
def get(self, key: str) -> str:
|
|
||||||
...
|
|
||||||
|
|
||||||
def set(self, key: str, value: str) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
|
||||||
if config.type == KVStoreType.redis:
|
|
||||||
from .kvstore_impls.redis import RedisKVStoreImpl
|
|
||||||
return RedisKVStoreImpl(config)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgentsImpl(Agents):
|
class MetaReferenceAgentsImpl(Agents):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -60,10 +35,9 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
self.memory_api = memory_api
|
self.memory_api = memory_api
|
||||||
self.safety_api = safety_api
|
self.safety_api = safety_api
|
||||||
self.kvstore = kvstore_impl(config.kvstore)
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||||
|
|
||||||
async def create_agent(
|
async def create_agent(
|
||||||
self,
|
self,
|
||||||
|
@ -71,38 +45,42 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
) -> AgentCreateResponse:
|
) -> AgentCreateResponse:
|
||||||
agent_id = str(uuid.uuid4())
|
agent_id = str(uuid.uuid4())
|
||||||
|
|
||||||
builtin_tools = []
|
await self.persistence_store.set(
|
||||||
for tool_defn in agent_config.tools:
|
key=f"agent:{agent_id}",
|
||||||
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
value=json.dumps(agent_config.dict()),
|
||||||
tool = WolframAlphaTool(tool_defn.api_key)
|
)
|
||||||
elif isinstance(tool_defn, SearchToolDefinition):
|
return AgentCreateResponse(
|
||||||
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
|
agent_id=agent_id,
|
||||||
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
|
)
|
||||||
tool = CodeInterpreterTool()
|
|
||||||
elif isinstance(tool_defn, PhotogenToolDefinition):
|
|
||||||
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
builtin_tools.append(
|
async def get_agent(self, agent_id: str) -> ChatAgent:
|
||||||
with_safety(
|
agent_config = await self.persistence_store.get(
|
||||||
tool,
|
key=f"agent:{agent_id}",
|
||||||
self.safety_api,
|
)
|
||||||
tool_defn.input_shields,
|
if not agent_config:
|
||||||
tool_defn.output_shields,
|
raise ValueError(f"Could not find agent config for {agent_id}")
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent(
|
try:
|
||||||
|
agent_config = json.loads(agent_config)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not JSON decode agent config for {agent_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent_config = AgentConfig(**agent_config)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not validate(?) agent config for {agent_id}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return ChatAgent(
|
||||||
|
agent_id=agent_id,
|
||||||
agent_config=agent_config,
|
agent_config=agent_config,
|
||||||
inference_api=self.inference_api,
|
inference_api=self.inference_api,
|
||||||
safety_api=self.safety_api,
|
safety_api=self.safety_api,
|
||||||
memory_api=self.memory_api,
|
memory_api=self.memory_api,
|
||||||
builtin_tools=builtin_tools,
|
persistence_store=self.persistence_store,
|
||||||
)
|
|
||||||
|
|
||||||
return AgentCreateResponse(
|
|
||||||
agent_id=agent_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agent_session(
|
async def create_agent_session(
|
||||||
|
@ -110,12 +88,11 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_name: str,
|
session_name: str,
|
||||||
) -> AgentSessionCreateResponse:
|
) -> AgentSessionCreateResponse:
|
||||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
agent = await self.get_agent(agent_id)
|
||||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
|
||||||
|
|
||||||
session = agent.create_session(session_name)
|
session_id = await agent.create_session(session_name)
|
||||||
return AgentSessionCreateResponse(
|
return AgentSessionCreateResponse(
|
||||||
session_id=session.session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def create_agent_turn(
|
async def create_agent_turn(
|
||||||
|
@ -131,6 +108,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
attachments: Optional[List[Attachment]] = None,
|
attachments: Optional[List[Attachment]] = None,
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
agent = await self.get_agent(agent_id)
|
||||||
|
|
||||||
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
# wrapper request to make it easier to pass around (internal only, not exposed to API)
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
|
@ -140,12 +119,5 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_id = request.agent_id
|
|
||||||
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
|
|
||||||
agent = AGENT_INSTANCES_BY_ID[agent_id]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
request.session_id in agent.sessions
|
|
||||||
), f"Session {request.session_id} not found"
|
|
||||||
async for event in agent.create_and_execute_turn(request):
|
async for event in agent.create_and_execute_turn(request):
|
||||||
yield event
|
yield event
|
||||||
|
|
|
@ -8,4 +8,4 @@ from llama_stack.providers.utils.kvstore import KVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgentsImplConfig(BaseModel):
|
class MetaReferenceAgentsImplConfig(BaseModel):
|
||||||
kv_store: KVStoreConfig
|
persistence_store: KVStoreConfig
|
||||||
|
|
|
@ -3,3 +3,5 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .kvstore import * # noqa: F401, F403
|
||||||
|
|
|
@ -5,20 +5,20 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional, Protocol
|
from typing import List, Optional, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class KVStoreValue(BaseModel):
|
class KVStoreValue(BaseModel):
|
||||||
key: str
|
key: str
|
||||||
value: Any
|
value: str
|
||||||
expiration: Optional[datetime] = None
|
expiration: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
class KVStore(Protocol):
|
class KVStore(Protocol):
|
||||||
async def set(
|
async def set(
|
||||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
self, key: str, value: str, expiration: Optional[datetime] = None
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[KVStoreValue]: ...
|
async def get(self, key: str) -> Optional[KVStoreValue]: ...
|
||||||
|
|
|
@ -14,7 +14,7 @@ from typing_extensions import Annotated
|
||||||
class KVStoreType(Enum):
|
class KVStoreType(Enum):
|
||||||
redis = "redis"
|
redis = "redis"
|
||||||
sqlite = "sqlite"
|
sqlite = "sqlite"
|
||||||
pgvector = "pgvector"
|
postgres = "postgres"
|
||||||
|
|
||||||
|
|
||||||
class CommonConfig(BaseModel):
|
class CommonConfig(BaseModel):
|
||||||
|
@ -37,8 +37,8 @@ class SqliteKVStoreImplConfig(CommonConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PGVectorKVStoreImplConfig(CommonConfig):
|
class PostgresKVStoreImplConfig(CommonConfig):
|
||||||
type: Literal[KVStoreType.pgvector.value] = KVStoreType.pgvector.value
|
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||||
host: str = "localhost"
|
host: str = "localhost"
|
||||||
port: int = 5432
|
port: int = 5432
|
||||||
db: str = "llamastack"
|
db: str = "llamastack"
|
||||||
|
@ -47,6 +47,6 @@ class PGVectorKVStoreImplConfig(CommonConfig):
|
||||||
|
|
||||||
|
|
||||||
KVStoreConfig = Annotated[
|
KVStoreConfig = Annotated[
|
||||||
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PGVectorKVStoreImplConfig],
|
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig],
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type", default=KVStoreType.sqlite.value),
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from redis.asyncio import Redis
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class RedisKVStoreImpl(KVStore):
|
||||||
return f"{self.config.namespace}:{key}"
|
return f"{self.config.namespace}:{key}"
|
||||||
|
|
||||||
async def set(
|
async def set(
|
||||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
self, key: str, value: str, expiration: Optional[datetime] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
key = self._namespaced_key(key)
|
key = self._namespaced_key(key)
|
||||||
await self.redis.set(key, value)
|
await self.redis.set(key, value)
|
||||||
|
@ -50,11 +50,4 @@ class RedisKVStoreImpl(KVStore):
|
||||||
start_key = self._namespaced_key(start_key)
|
start_key = self._namespaced_key(start_key)
|
||||||
end_key = self._namespaced_key(end_key)
|
end_key = self._namespaced_key(end_key)
|
||||||
|
|
||||||
keys = await self.redis.keys(f"{start_key}*")
|
return await self.redis.zrangebylex(start_key, end_key)
|
||||||
result = []
|
|
||||||
for key in keys:
|
|
||||||
if key <= end_key:
|
|
||||||
value = await self.get(key)
|
|
||||||
if value:
|
|
||||||
result.append(value)
|
|
||||||
return result
|
|
||||||
|
|
|
@ -4,9 +4,8 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
@ -33,12 +32,12 @@ class SqliteKVStoreImpl(KVStore):
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def set(
|
async def set(
|
||||||
self, key: str, value: Any, expiration: Optional[datetime] = None
|
self, key: str, value: str, expiration: Optional[datetime] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
await db.execute(
|
await db.execute(
|
||||||
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
|
||||||
(key, json.dumps(value), expiration),
|
(key, value, expiration),
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
@ -51,9 +50,7 @@ class SqliteKVStoreImpl(KVStore):
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
value, expiration = row
|
value, expiration = row
|
||||||
return KVStoreValue(
|
return KVStoreValue(key=key, value=value, expiration=expiration)
|
||||||
key=key, value=json.loads(value), expiration=expiration
|
|
||||||
)
|
|
||||||
|
|
||||||
async def delete(self, key: str) -> None:
|
async def delete(self, key: str) -> None:
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
@ -70,8 +67,6 @@ class SqliteKVStoreImpl(KVStore):
|
||||||
async for row in cursor:
|
async for row in cursor:
|
||||||
key, value, expiration = row
|
key, value, expiration = row
|
||||||
result.append(
|
result.append(
|
||||||
KVStoreValue(
|
KVStoreValue(key=key, value=value, expiration=expiration)
|
||||||
key=key, value=json.loads(value), expiration=expiration
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue