diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index ed400efae..867ae3f98 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -17,7 +17,8 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR class KVStoreType(Enum): redis = "redis" sqlite = "sqlite" - postgres = "postgres" + postgres = "postgres", + mongodb = "mongodb" class CommonConfig(BaseModel): @@ -55,15 +56,15 @@ class SqliteKVStoreConfig(CommonConfig): @classmethod def sample_run_config( - cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db" + cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db" ): return { "type": "sqlite", "namespace": None, "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" - + __distro_dir__ - + "}/" - + db_name, + + __distro_dir__ + + "}/" + + db_name, } @@ -106,7 +107,30 @@ class PostgresKVStoreConfig(CommonConfig): return v +class MongoDBKVStoreConfig(CommonConfig): + type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value + host: str = "localhost" + port: int = 5432 + db: str = "llamastack" + user: str = None + password: Optional[str] = None + collection_name: str = "llamastack_kvstore" + + @classmethod + def sample_run_config(cls, collection_name: str = "llamastack_kvstore"): + return { + "type": "mongodb", + "namespace": None, + "host": "${env.MONGODB_HOST:localhost}", + "port": "${env.MONGODB_PORT:5432}", + "db": "${env.MONGODB_DB}", + "user": "${env.MONGODB_USER}", + "password": "${env.MONGODB_PASSWORD}", + "table_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}", + } + + KVStoreConfig = Annotated[ - Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig], + Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig], Field(discriminator="type", default=KVStoreType.sqlite.value), ] diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 469f400d0..deab90602 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -9,7 +9,7 @@ from .config import * # noqa: F403 def kvstore_dependencies(): - return ["aiosqlite", "psycopg2-binary", "redis"] + return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"] class InmemoryKVStoreImpl(KVStore): @@ -46,6 +46,10 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore: from .postgres import PostgresKVStoreImpl impl = PostgresKVStoreImpl(config) + elif config.type == KVStoreType.mongodb.value: + from .mongodb import MongoDBKVStoreImpl + + impl = MongoDBKVStoreImpl(config) else: raise ValueError(f"Unknown kvstore type {config.type}") diff --git a/llama_stack/providers/utils/kvstore/mongodb/__init__.py b/llama_stack/providers/utils/kvstore/mongodb/__init__.py new file mode 100644 index 000000000..4f7fe46e7 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/mongodb/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .mongodb import MongoDBKVStoreImpl diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py new file mode 100644 index 000000000..f749f26d0 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -0,0 +1,87 @@ +import datetime +import logging +from datetime import datetime +from typing import Optional, List + +from pymongo import MongoClient +from pymongo.errors import ConfigurationError + +from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig + +log = logging.getLogger(__name__) + + +class MongoDBKVStoreImpl(KVStore): + def __init__(self, config: MongoDBKVStoreConfig): + self.config = config + self.conn = None + self.collection = None + + async def initialize(self) -> None: + try: + conn_creds = { + "host": self.config.host, + "port": self.config.port, + "username": self.config.user, + "password": self.config.password, + } + conn_creds = {k: v for k, v in conn_creds if v is not None} + + try: + self.conn = MongoClient(**conn_creds) + self.collection = self.conn[self.config.db] + except (ConnectionError, ConfigurationError) as e: + raise Exception(f"Failed to connect to MongoDB: {e}") + except Exception as e: + log.exception("Could not connect to MongoDB database server") + raise RuntimeError("Could not connect to MongoDB database server") from e + + def _namespaced_key(self, key: str) -> str: + if not self.config.namespace: + return key + return f"{self.config.namespace}:{key}" + + async def set( + self, key: str, value: str, expiration: Optional[datetime] = None + ) -> None: + key = self._namespaced_key(key) + update_query = { + "$set": { + "value": value, + "expiration": expiration + } + } + self.collection.update_one( + {"key": key}, + update_query, + upsert=True + ) + + async def get(self, key: str) -> Optional[str]: + key = self._namespaced_key(key) + query = { + "key": key, + "$or": [ + {"expiration": {"$exists": False}}, + {"expiration": {"$gt": datetime.now(datetime.UTC)}}, + ], + } + result = self.collection.find_one(query, {"value": 1, "_id": 0}) + return result["value"] if result else None + + async def delete(self, key: str) -> None: + key = self._namespaced_key(key) + self.collection.delete_one({"key": key}) + + async def range(self, start_key: str, end_key: str) -> List[str]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + query = { + "key": {"$gte": start_key, "$lt": end_key}, + "$or": [ + {"expiration": {"$exists": False}}, + {"expiration": {"$gt": datetime.now(datetime.UTC)}}, + ], + } + cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1) + return [doc["value"] for doc in cursor]