Added support for mongoDB KV store

This commit is contained in:
shrinitgoyal 2024-11-28 13:31:27 +05:30
parent b1a63df8cd
commit 23df7db896
4 changed files with 129 additions and 7 deletions

View file

@ -17,7 +17,8 @@ from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
postgres = "postgres"
postgres = "postgres",
mongodb = "mongodb"
class CommonConfig(BaseModel):
@ -55,15 +56,15 @@ class SqliteKVStoreConfig(CommonConfig):
@classmethod
def sample_run_config(
cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
cls, __distro_dir__: str = "runtime", db_name: str = "kvstore.db"
):
return {
"type": "sqlite",
"namespace": None,
"db_path": "${env.SQLITE_STORE_DIR:~/.llama/"
+ __distro_dir__
+ "}/"
+ db_name,
+ __distro_dir__
+ "}/"
+ db_name,
}
@ -106,7 +107,30 @@ class PostgresKVStoreConfig(CommonConfig):
return v
class MongoDBKVStoreConfig(CommonConfig):
type: Literal[KVStoreType.mongodb.value] = KVStoreType.mongodb.value
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
user: str = None
password: Optional[str] = None
collection_name: str = "llamastack_kvstore"
@classmethod
def sample_run_config(cls, collection_name: str = "llamastack_kvstore"):
return {
"type": "mongodb",
"namespace": None,
"host": "${env.MONGODB_HOST:localhost}",
"port": "${env.MONGODB_PORT:5432}",
"db": "${env.MONGODB_DB}",
"user": "${env.MONGODB_USER}",
"password": "${env.MONGODB_PASSWORD}",
"table_name": "${env.MONGODB_COLLECTION_NAME:" + collection_name + "}",
}
KVStoreConfig = Annotated[
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig, MongoDBKVStoreConfig],
Field(discriminator="type", default=KVStoreType.sqlite.value),
]

View file

@ -9,7 +9,7 @@ from .config import * # noqa: F403
def kvstore_dependencies():
return ["aiosqlite", "psycopg2-binary", "redis"]
return ["aiosqlite", "psycopg2-binary", "redis", "pymongo"]
class InmemoryKVStoreImpl(KVStore):
@ -46,6 +46,10 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore:
from .postgres import PostgresKVStoreImpl
impl = PostgresKVStoreImpl(config)
elif config.type == KVStoreType.mongodb.value:
from .mongodb import MongoDBKVStoreImpl
impl = MongoDBKVStoreImpl(config)
else:
raise ValueError(f"Unknown kvstore type {config.type}")

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .mongodb import MongoDBKVStoreImpl

View file

@ -0,0 +1,87 @@
import datetime
import logging
from datetime import datetime
from typing import Optional, List
from pymongo import MongoClient
from pymongo.errors import ConfigurationError
from llama_stack.providers.utils.kvstore import KVStore, MongoDBKVStoreConfig
log = logging.getLogger(__name__)
class MongoDBKVStoreImpl(KVStore):
def __init__(self, config: MongoDBKVStoreConfig):
self.config = config
self.conn = None
self.collection = None
async def initialize(self) -> None:
try:
conn_creds = {
"host": self.config.host,
"port": self.config.port,
"username": self.config.user,
"password": self.config.password,
}
conn_creds = {k: v for k, v in conn_creds if v is not None}
try:
self.conn = MongoClient(**conn_creds)
self.collection = self.conn[self.config.db]
except (ConnectionError, ConfigurationError) as e:
raise Exception(f"Failed to connect to MongoDB: {e}")
except Exception as e:
log.exception("Could not connect to MongoDB database server")
raise RuntimeError("Could not connect to MongoDB database server") from e
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
update_query = {
"$set": {
"value": value,
"expiration": expiration
}
}
self.collection.update_one(
{"key": key},
update_query,
upsert=True
)
async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key)
query = {
"key": key,
"$or": [
{"expiration": {"$exists": False}},
{"expiration": {"$gt": datetime.now(datetime.UTC)}},
],
}
result = self.collection.find_one(query, {"value": 1, "_id": 0})
return result["value"] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.collection.delete_one({"key": key})
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
query = {
"key": {"$gte": start_key, "$lt": end_key},
"$or": [
{"expiration": {"$exists": False}},
{"expiration": {"$gt": datetime.now(datetime.UTC)}},
],
}
cursor = self.collection.find(query, {"value": 1, "_id": 0}).sort("key", 1)
return [doc["value"] for doc in cursor]