From b74f25035cf5b93eaf32d0ae3c9b7a6a0b6d131a Mon Sep 17 00:00:00 2001 From: Shrinit Goyal <64660979+shrinitg@users.noreply.github.com> Date: Thu, 20 Feb 2025 12:00:50 +0530 Subject: [PATCH] Added support for mongoDB KV store (#543) Added the support for mongoDB as KV store validated in mongodb, it is able to store agent data, session data and turn data image this is how run.yaml would look: ``` config: persistence_store: type: mongodb namespace: null host: localhost port: 27017 db: llamastack user: "" password: "" collection_name: llamastack_kvstore ``` --------- Co-authored-by: shrinitgoyal --- llama_stack/providers/utils/kvstore/config.py | 26 ++++++- .../providers/utils/kvstore/kvstore.py | 6 +- .../utils/kvstore/mongodb/__init__.py | 7 ++ .../utils/kvstore/mongodb/mongodb.py | 69 +++++++++++++++++++ 4 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 llama_stack/providers/utils/kvstore/mongodb/__init__.py create mode 100644 llama_stack/providers/utils/kvstore/mongodb/mongodb.py diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 85327c131..b9403df32 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -18,6 +18,7 @@ class KVStoreType(Enum): redis = "redis" sqlite = "sqlite" postgres = "postgres" + mongodb = "mongodb" class CommonConfig(BaseModel): @@ -101,7 +102,30 @@ class PostgresKVStoreConfig(CommonConfig): return v +class MongoDBKVStoreConfig(CommonConfig): + type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value + host: str = "localhost" + port: int = 27017 + db: str = "llamastack" + user: str = None + password: Optional[str] = None + collection_name: str = "llamastack_kvstore" + + @classmethod + def sample_run_config(cls, collection_name: str = "llamastack_kvstore"): + return { + "type": "mongodb", + "namespace": None, + "host": "${env.MONGODB_HOST:localhost}", + "port": "${env.MONGODB_PORT:5432}", + "db": "${env.MONGODB_DB}", + "user": "${env.MONGODB_USER}", + "password": "${env.MONGODB_PASSWORD}", + "collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}", + } + + KVStoreConfig = Annotated[ - Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig], + Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig], Field(discriminator="type", default=KVStoreType.sqlite.value), ] diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 32b4e40dd..6bc175260 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -11,7 +11,7 @@ from .config import KVStoreConfig, KVStoreType def kvstore_dependencies(): - return ["aiosqlite", "psycopg2-binary", "redis"] + return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"] class InmemoryKVStoreImpl(KVStore): @@ -44,6 +44,10 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore: from .postgres import PostgresKVStoreImpl impl = PostgresKVStoreImpl(config) + elif config.type == KVStoreType.mongodb.value: + from .mongodb import MongoDBKVStoreImpl + + impl = MongoDBKVStoreImpl(config) else: raise ValueError(f"Unknown kvstore type {config.type}") diff --git a/llama_stack/providers/utils/kvstore/mongodb/__init__.py b/llama_stack/providers/utils/kvstore/mongodb/__init__.py new file mode 100644 index 000000000..4f7fe46e7 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/mongodb/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .mongodb import MongoDBKVStoreImpl diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py new file mode 100644 index 000000000..625fce929 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +from datetime import datetime +from typing import List, Optional + +from pymongo import MongoClient + +from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig + +log = logging.getLogger(__name__) + + +class MongoDBKVStoreImpl(KVStore): + def __init__(self, config: MongoDBKVStoreConfig): + self.config = config + self.conn = None + self.collection = None + + async def initialize(self) -> None: + try: + conn_creds = { + "host": self.config.host, + "port": self.config.port, + "username": self.config.user, + "password": self.config.password, + } + conn_creds = {k: v for k, v in conn_creds.items() if v is not None} + self.conn = MongoClient(**conn_creds) + self.collection = self.conn[self.config.db][self.config.collection_name] + except Exception as e: + log.exception("Could not connect to MongoDB database server") + raise RuntimeError("Could not connect to MongoDB database server") from e + + def _namespaced_key(self, key: str) -> str: + if not self.config.namespace: + return key + return f"{self.config.namespace}:{key}" + + async def set( + self, key: str, value: str, expiration: Optional[datetime] = None + ) -> None: + + key = self._namespaced_key(key) + update_query = {"$set": {"value": value, "expiration": expiration}} + self.collection.update_one({"key": key}, update_query, upsert=True) + + async def get(self, key: str) -> Optional[str]: + key = self._namespaced_key(key) + query = {"key": key} + result = self.collection.find_one(query, {"value": 1, "_id": 0}) + return result["value"] if result else None + + async def delete(self, key: str) -> None: + key = self._namespaced_key(key) + self.collection.delete_one({"key": key}) + + async def range(self, start_key: str, end_key: str) -> List[str]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + query = { + "key": {"$gte": start_key, "$lt": end_key}, + } + cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1) + return [doc["value"] for doc in cursor]