From 8fe22230b8e8dccf139424dfa478a7a805b5eed0 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Fri, 1 Nov 2024 08:31:17 -0700 Subject: [PATCH] persist registered objects with distribution --- llama_stack/apis/datasets/datasets.py | 6 +-- llama_stack/apis/models/models.py | 2 +- .../scoring_functions/scoring_functions.py | 8 +-- llama_stack/apis/shields/shields.py | 2 +- .../distribution/routers/routing_tables.py | 4 +- llama_stack/distribution/server/server.py | 23 +++++++-- llama_stack/distribution/store/__init__.py | 3 ++ llama_stack/distribution/store/registry.py | 51 +++++++++++++++++++ .../adapters/inference/ollama/ollama.py | 2 +- 9 files changed, 85 insertions(+), 16 deletions(-) create mode 100644 llama_stack/distribution/store/__init__.py create mode 100644 llama_stack/distribution/store/registry.py diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 7a56049bf..9cfe1ca36 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -10,10 +10,10 @@ from llama_models.llama3.api.datatypes import URL from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field - from llama_stack.apis.common.type_system import ParamType +from pydantic import BaseModel, Field + @json_schema_type class DatasetDef(BaseModel): @@ -24,7 +24,7 @@ class DatasetDef(BaseModel): description="The schema definition for this dataset", ) url: URL - metadata: Dict[str, Any] = Field( + metadata: Dict[str, str] = Field( default_factory=dict, description="Any additional metadata for this dataset", ) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 994c8e995..2cb72cb3b 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -17,7 +17,7 @@ class ModelDef(BaseModel): llama_model: str = Field( description="Pointer to the underlying core Llama family model. Each model served by Llama Stack must have a core Llama model.", ) - metadata: Dict[str, Any] = Field( + metadata: Dict[str, str] = Field( default_factory=dict, description="Any additional metadata for this model", ) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 2e5bf0aef..5311f05f0 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,12 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod -from pydantic import BaseModel, Field -from llama_stack.apis.common.type_system import ParamType +from llama_stack.apis.common.type_system import JsonType, ParamType +from pydantic import BaseModel, Field @json_schema_type @@ -36,7 +36,7 @@ class LLMAsJudgeContext(BaseModel): class ScoringFnDef(BaseModel): identifier: str description: Optional[str] = None - metadata: Dict[str, Any] = Field( + metadata: Dict[str, str] = Field( default_factory=dict, description="Any additional metadata for this definition", ) diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 7f003faa2..3fb229e17 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -26,7 +26,7 @@ class ShieldDef(BaseModel): type: str = Field( description="The type of shield this is; the value is one of the ShieldType enum" ) - params: Dict[str, Any] = Field( + params: Dict[str, str] = Field( default_factory=dict, description="Any additional parameters needed for this shield", ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 4e462c54b..06fb49092 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -14,6 +14,7 @@ from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import llama_stack.distribution.store as distribution_store def get_impl_api(p: Any) -> Api: @@ -170,8 +171,7 @@ class CommonRoutingTableImpl(RoutingTable): if obj.identifier not in self.registry: self.registry[obj.identifier] = [] self.registry[obj.identifier].append(obj) - - # TODO: persist this to a store + await distribution_store.REGISTRY.register(obj) class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index b8fe4734e..ef1c5ec61 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -22,25 +22,28 @@ import yaml from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse -from pydantic import BaseModel, ValidationError -from termcolor import cprint -from typing_extensions import Annotated from llama_stack.distribution.distribution import ( builtin_automatically_routed_apis, get_provider_registry, ) +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + from llama_stack.providers.utils.telemetry.tracing import ( end_trace, setup_logger, SpanStatus, start_trace, ) +from pydantic import BaseModel, ValidationError +from termcolor import cprint +from typing_extensions import Annotated from llama_stack.distribution.datatypes import * # noqa: F403 - +import llama_stack.distribution.store as distribution_store from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -278,6 +281,18 @@ def main( config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() + # instantiate kvstore for storing and retrieving distribution metadata + dist_kvstore = asyncio.run( + kvstore_impl( + SqliteKVStoreConfig( + db_path=( + DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" + ).as_posix() + ) + ) + ) + + distribution_store.REGISTRY = distribution_store.DiskRegistry(dist_kvstore) impls = asyncio.run(resolve_impls(config, get_provider_registry())) if Api.telemetry in impls: diff --git a/llama_stack/distribution/store/__init__.py b/llama_stack/distribution/store/__init__.py new file mode 100644 index 000000000..75813f424 --- /dev/null +++ b/llama_stack/distribution/store/__init__.py @@ -0,0 +1,3 @@ +from .registry import DiskRegistry, Registry + +REGISTRY: Registry = None diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py new file mode 100644 index 000000000..444090a4e --- /dev/null +++ b/llama_stack/distribution/store/registry.py @@ -0,0 +1,51 @@ +import json + +from typing import Protocol + +from docs.openapi_generator.strong_typing.deserializer import create_deserializer + +from docs.openapi_generator.strong_typing.serialization import object_to_json + +from llama_stack.distribution.datatypes import RoutableObjectWithProvider + +from llama_stack.providers.utils.kvstore import KVStore + + +class Registry(Protocol): + async def get(self, identifier: str) -> [RoutableObjectWithProvider]: ... + async def register(self, obj: RoutableObjectWithProvider) -> None: ... + + +KEY_FORMAT = "distributions:registry:{}" +DESERIALIZER = create_deserializer(RoutableObjectWithProvider) + + +class DiskRegistry(Registry): + def __init__(self, kvstore: KVStore): + self.kvstore = kvstore + + async def get(self, identifier: str) -> [RoutableObjectWithProvider]: + # Get JSON string from kvstore + json_str = await self.kvstore.get(KEY_FORMAT.format(identifier)) + if not json_str: + return [] + + # Parse JSON string into list of objects + objects_data = json.loads(json_str) + + return [DESERIALIZER.parse(obj_str) for obj_str in objects_data] + + async def register(self, obj: RoutableObjectWithProvider) -> None: + # Get existing objects for this identifier + existing_objects = await self.get(obj.identifier) + + # Add new object to list + existing_objects.append(obj) + + # Convert all objects to JSON strings and store as JSON array + objects_json = [ + object_to_json(existing_object) for existing_object in existing_objects + ] + await self.kvstore.set( + KEY_FORMAT.format(obj.identifier), json.dumps(objects_json) + ) diff --git a/llama_stack/providers/adapters/inference/ollama/ollama.py b/llama_stack/providers/adapters/inference/ollama/ollama.py index 916241a7c..d1b51b9b1 100644 --- a/llama_stack/providers/adapters/inference/ollama/ollama.py +++ b/llama_stack/providers/adapters/inference/ollama/ollama.py @@ -81,7 +81,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): identifier=llama_model, llama_model=llama_model, metadata={ - "ollama_model": r["model"], + "ollama_model": str(r["model"]), }, ) )