From b74f25035cf5b93eaf32d0ae3c9b7a6a0b6d131a Mon Sep 17 00:00:00 2001
From: Shrinit Goyal <64660979+shrinitg@users.noreply.github.com>
Date: Thu, 20 Feb 2025 12:00:50 +0530
Subject: [PATCH] Added support for mongoDB KV store (#543)
Added the support for mongoDB as KV store
validated in mongodb, it is able to store agent data, session data and
turn data
this is how run.yaml would look:
```
config:
persistence_store:
type: mongodb
namespace: null
host: localhost
port: 27017
db: llamastack
user: ""
password: ""
collection_name: llamastack_kvstore
```
---------
Co-authored-by: shrinitgoyal
---
llama_stack/providers/utils/kvstore/config.py | 26 ++++++-
.../providers/utils/kvstore/kvstore.py | 6 +-
.../utils/kvstore/mongodb/__init__.py | 7 ++
.../utils/kvstore/mongodb/mongodb.py | 69 +++++++++++++++++++
4 files changed, 106 insertions(+), 2 deletions(-)
create mode 100644 llama_stack/providers/utils/kvstore/mongodb/__init__.py
create mode 100644 llama_stack/providers/utils/kvstore/mongodb/mongodb.py
diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py
index 85327c131..b9403df32 100644
--- a/llama_stack/providers/utils/kvstore/config.py
+++ b/llama_stack/providers/utils/kvstore/config.py
@@ -18,6 +18,7 @@ class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
postgres = "postgres"
+ mongodb = "mongodb"
class CommonConfig(BaseModel):
@@ -101,7 +102,30 @@ class PostgresKVStoreConfig(CommonConfig):
return v
+class MongoDBKVStoreConfig(CommonConfig):
+ type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
+ host: str = "localhost"
+ port: int = 27017
+ db: str = "llamastack"
+ user: str = None
+ password: Optional[str] = None
+ collection_name: str = "llamastack_kvstore"
+
+ @classmethod
+ def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
+ return {
+ "type": "mongodb",
+ "namespace": None,
+ "host": "${env.MONGODB_HOST:localhost}",
+ "port": "${env.MONGODB_PORT:5432}",
+ "db": "${env.MONGODB_DB}",
+ "user": "${env.MONGODB_USER}",
+ "password": "${env.MONGODB_PASSWORD}",
+ "collection_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
+ }
+
+
KVStoreConfig = Annotated[
- Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
+ Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig],
Field(discriminator="type", default=KVStoreType.sqlite.value),
]
diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py
index 32b4e40dd..6bc175260 100644
--- a/llama_stack/providers/utils/kvstore/kvstore.py
+++ b/llama_stack/providers/utils/kvstore/kvstore.py
@@ -11,7 +11,7 @@ from .config import KVStoreConfig, KVStoreType
def kvstore_dependencies():
- return ["aiosqlite", "psycopg2-binary", "redis"]
+ return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
class InmemoryKVStoreImpl(KVStore):
@@ -44,6 +44,10 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore:
from .postgres import PostgresKVStoreImpl
impl = PostgresKVStoreImpl(config)
+ elif config.type == KVStoreType.mongodb.value:
+ from .mongodb import MongoDBKVStoreImpl
+
+ impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")
diff --git a/llama_stack/providers/utils/kvstore/mongodb/__init__.py b/llama_stack/providers/utils/kvstore/mongodb/__init__.py
new file mode 100644
index 000000000..4f7fe46e7
--- /dev/null
+++ b/llama_stack/providers/utils/kvstore/mongodb/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .mongodb import MongoDBKVStoreImpl
diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
new file mode 100644
index 000000000..625fce929
--- /dev/null
+++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import logging
+from datetime import datetime
+from typing import List, Optional
+
+from pymongo import MongoClient
+
+from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
+
+log = logging.getLogger(__name__)
+
+
+class MongoDBKVStoreImpl(KVStore):
+ def __init__(self, config: MongoDBKVStoreConfig):
+ self.config = config
+ self.conn = None
+ self.collection = None
+
+ async def initialize(self) -> None:
+ try:
+ conn_creds = {
+ "host": self.config.host,
+ "port": self.config.port,
+ "username": self.config.user,
+ "password": self.config.password,
+ }
+ conn_creds = {k: v for k, v in conn_creds.items() if v is not None}
+ self.conn = MongoClient(**conn_creds)
+ self.collection = self.conn[self.config.db][self.config.collection_name]
+ except Exception as e:
+ log.exception("Could not connect to MongoDB database server")
+ raise RuntimeError("Could not connect to MongoDB database server") from e
+
+ def _namespaced_key(self, key: str) -> str:
+ if not self.config.namespace:
+ return key
+ return f"{self.config.namespace}:{key}"
+
+ async def set(
+ self, key: str, value: str, expiration: Optional[datetime] = None
+ ) -> None:
+
+ key = self._namespaced_key(key)
+ update_query = {"$set": {"value": value, "expiration": expiration}}
+ self.collection.update_one({"key": key}, update_query, upsert=True)
+
+ async def get(self, key: str) -> Optional[str]:
+ key = self._namespaced_key(key)
+ query = {"key": key}
+ result = self.collection.find_one(query, {"value": 1, "_id": 0})
+ return result["value"] if result else None
+
+ async def delete(self, key: str) -> None:
+ key = self._namespaced_key(key)
+ self.collection.delete_one({"key": key})
+
+ async def range(self, start_key: str, end_key: str) -> List[str]:
+ start_key = self._namespaced_key(start_key)
+ end_key = self._namespaced_key(end_key)
+ query = {
+ "key": {"$gte": start_key, "$lt": end_key},
+ }
+ cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
+ return [doc["value"] for doc in cursor]