diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py
index 611590a29..5d083b8db 100644
--- a/llama_stack/apis/datasets/datasets.py
+++ b/llama_stack/apis/datasets/datasets.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Optional, Protocol
+from typing import Dict, List, Literal, Optional, Protocol
 
 from llama_models.llama3.api.datatypes import URL
 
@@ -32,6 +32,7 @@ class DatasetDef(BaseModel):
 
 @json_schema_type
 class DatasetDefWithProvider(DatasetDef):
+    type: Literal["dataset"] = "dataset"
     provider_id: str = Field(
         description="ID of the provider which serves this dataset",
     )
diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py
index 09a0e3d9e..5c75a1ca0 100644
--- a/llama_stack/apis/models/models.py
+++ b/llama_stack/apis/models/models.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Optional, Protocol, runtime_checkable
+from typing import Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
@@ -25,6 +25,7 @@ class ModelDef(BaseModel):
 
 @json_schema_type
 class ModelDefWithProvider(ModelDef):
+    type: Literal["model"] = "model"
     provider_id: str = Field(
         description="The provider ID for this model",
     )
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
index af0b0c1ee..0a4ff6316 100644
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ b/llama_stack/apis/scoring_functions/scoring_functions.py
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Dict, List, Optional, Protocol, runtime_checkable
+from typing import Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
-from pydantic import BaseModel, Field
 
 from llama_stack.apis.common.type_system import JsonType, ParamType
+from pydantic import BaseModel, Field
 
 
 @json_schema_type
@@ -53,6 +53,7 @@ class ScoringFnDef(BaseModel):
 
 @json_schema_type
 class ScoringFnDefWithProvider(ScoringFnDef):
+    type: Literal["scoring_fn"] = "scoring_fn"
     provider_id: str = Field(
         description="ID of the provider which serves this dataset",
     )
diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py
index e59f90f22..826e7c30e 100644
--- a/llama_stack/apis/shields/shields.py
+++ b/llama_stack/apis/shields/shields.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from enum import Enum
-from typing import Dict, List, Optional, Protocol, runtime_checkable
+from typing import Dict, List, Literal, Optional, Protocol, runtime_checkable
 
 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
@@ -34,6 +34,7 @@ class ShieldDef(BaseModel):
 
 @json_schema_type
 class ShieldDefWithProvider(ShieldDef):
+    type: Literal["shield"] = "shield"
     provider_id: str = Field(
         description="The provider ID for this shield type",
     )
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 9ad82cd79..9d2a10adb 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -37,12 +37,16 @@ RoutableObject = Union[
     ScoringFnDef,
 ]
 
-RoutableObjectWithProvider = Union[
-    ModelDefWithProvider,
-    ShieldDefWithProvider,
-    MemoryBankDefWithProvider,
-    DatasetDefWithProvider,
-    ScoringFnDefWithProvider,
+
+RoutableObjectWithProvider = Annotated[
+    Union[
+        ModelDefWithProvider,
+        ShieldDefWithProvider,
+        MemoryBankDefWithProvider,
+        DatasetDefWithProvider,
+        ScoringFnDefWithProvider,
+    ],
+    Field(discriminator="type"),
 ]
 
 RoutedProtocol = Union[
diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py
index 5e9ed381f..35ccd5178 100644
--- a/llama_stack/distribution/store/registry.py
+++ b/llama_stack/distribution/store/registry.py
@@ -2,9 +2,7 @@ import json
 
 from typing import Protocol
 
-from docs.openapi_generator.strong_typing.deserializer import create_deserializer
-
-from docs.openapi_generator.strong_typing.serialization import object_to_json
+import pydantic
 
 from llama_stack.distribution.datatypes import RoutableObjectWithProvider
 
@@ -17,7 +15,6 @@ class Registry(Protocol):
 
 
 KEY_FORMAT = "distributions:registry:{}"
-DESERIALIZER = create_deserializer(RoutableObjectWithProvider)
 
 
 class DiskRegistry(Registry):
@@ -33,20 +30,26 @@ class DiskRegistry(Registry):
         # Parse JSON string into list of objects
         objects_data = json.loads(json_str)
 
-        return [DESERIALIZER.parse(obj_str) for obj_str in objects_data]
+        # Each stored entry is itself a JSON string (see register), so decode
+        # it to a dict before validating against the discriminated union.
+        return [
+            pydantic.parse_obj_as(
+                RoutableObjectWithProvider,
+                json.loads(obj_str),
+            )
+            for obj_str in objects_data
+        ]
 
     # TODO: make it thread safe using CAS
     async def register(self, obj: RoutableObjectWithProvider) -> None:
-        # Get existing objects for this identifier
         existing_objects = await self.get(obj.identifier)
 
-        # Add new object to list
         existing_objects.append(obj)
 
-        # Convert all objects to JSON strings and store as JSON array
-        objects_json = [
-            object_to_json(existing_object) for existing_object in existing_objects
-        ]
+        # Serialize every stored object, not just the newly registered one.
+        objects_json = [
+            existing_object.model_dump_json() for existing_object in existing_objects
+        ]
         await self.kvstore.set(
             KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
         )
diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py
new file mode 100644
index 000000000..cd9b07a79
--- /dev/null
+++ b/llama_stack/distribution/store/tests/test_registry.py
@@ -0,0 +1,30 @@
+import pytest
+import pytest_asyncio
+from llama_stack.distribution.store import *
+from llama_stack.apis.memory_banks import GraphMemoryBankDef, VectorMemoryBankDef
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+from llama_stack.distribution.datatypes import *  # noqa: F403
+
+
+@pytest.mark.asyncio
+async def test_registry():
+    registry = DiskRegistry(await kvstore_impl(SqliteKVStoreConfig()))
+    bank = VectorMemoryBankDef(
+        identifier="test_bank",
+        embedding_model="all-MiniLM-L6-v2",
+        chunk_size_in_tokens=512,
+        overlap_size_in_tokens=64,
+        provider_id="bar",
+    )
+
+    await registry.register(bank)
+    # get() returns a list of every object stored under the identifier; the
+    # object registered above is appended last.
+    results = await registry.get("test_bank")
+    result_bank = results[-1]
+    assert result_bank.identifier == bank.identifier
+    assert result_bank.embedding_model == bank.embedding_model
+    assert result_bank.chunk_size_in_tokens == bank.chunk_size_in_tokens
+    assert result_bank.overlap_size_in_tokens == bank.overlap_size_in_tokens
+    assert result_bank.provider_id == bank.provider_id