mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 19:04:19 +00:00
persist registered objects with distribution (#354)
* persist registered objects with distribution * linter fixes * comment * use annotate and field discriminator * workign tests * donot use global state * precommit failures fixed * add back Any * fix imports * remove unnecessary changes in ollama * precommit failures fixed * make kvstore configurable for dist and rename registry * add comment about registry list return * fix linter errors * use registry to hydrate * remove debug print * linter fixes * remove kvstore.db * rename distribution_registry_store --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
c9bf1d7d0b
commit
663883cc29
12 changed files with 401 additions and 46 deletions
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Literal, Optional, Protocol
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import URL
|
from llama_models.llama3.api.datatypes import URL
|
||||||
|
|
||||||
|
@ -32,6 +32,7 @@ class DatasetDef(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class DatasetDefWithProvider(DatasetDef):
|
class DatasetDefWithProvider(DatasetDef):
|
||||||
|
type: Literal["dataset"] = "dataset"
|
||||||
provider_id: str = Field(
|
provider_id: str = Field(
|
||||||
description="ID of the provider which serves this dataset",
|
description="ID of the provider which serves this dataset",
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -25,6 +25,7 @@ class ModelDef(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ModelDefWithProvider(ModelDef):
|
class ModelDefWithProvider(ModelDef):
|
||||||
|
type: Literal["model"] = "model"
|
||||||
provider_id: str = Field(
|
provider_id: str = Field(
|
||||||
description="The provider ID for this model",
|
description="The provider ID for this model",
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -53,6 +53,7 @@ class ScoringFnDef(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ScoringFnDefWithProvider(ScoringFnDef):
|
class ScoringFnDefWithProvider(ScoringFnDef):
|
||||||
|
type: Literal["scoring_fn"] = "scoring_fn"
|
||||||
provider_id: str = Field(
|
provider_id: str = Field(
|
||||||
description="ID of the provider which serves this dataset",
|
description="ID of the provider which serves this dataset",
|
||||||
)
|
)
|
||||||
|
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
@ -34,6 +34,7 @@ class ShieldDef(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class ShieldDefWithProvider(ShieldDef):
|
class ShieldDefWithProvider(ShieldDef):
|
||||||
|
type: Literal["shield"] = "shield"
|
||||||
provider_id: str = Field(
|
provider_id: str = Field(
|
||||||
description="The provider ID for this shield type",
|
description="The provider ID for this shield type",
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.memory import Memory
|
from llama_stack.apis.memory import Memory
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.scoring import Scoring
|
from llama_stack.apis.scoring import Scoring
|
||||||
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
|
||||||
|
|
||||||
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
|
||||||
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
|
||||||
|
@ -37,12 +38,16 @@ RoutableObject = Union[
|
||||||
ScoringFnDef,
|
ScoringFnDef,
|
||||||
]
|
]
|
||||||
|
|
||||||
RoutableObjectWithProvider = Union[
|
|
||||||
|
RoutableObjectWithProvider = Annotated[
|
||||||
|
Union[
|
||||||
ModelDefWithProvider,
|
ModelDefWithProvider,
|
||||||
ShieldDefWithProvider,
|
ShieldDefWithProvider,
|
||||||
MemoryBankDefWithProvider,
|
MemoryBankDefWithProvider,
|
||||||
DatasetDefWithProvider,
|
DatasetDefWithProvider,
|
||||||
ScoringFnDefWithProvider,
|
ScoringFnDefWithProvider,
|
||||||
|
],
|
||||||
|
Field(discriminator="type"),
|
||||||
]
|
]
|
||||||
|
|
||||||
RoutedProtocol = Union[
|
RoutedProtocol = Union[
|
||||||
|
@ -134,6 +139,12 @@ One or more providers to use for each API. The same provider_type (e.g., meta-re
|
||||||
can be instantiated multiple times (with different configs) if necessary.
|
can be instantiated multiple times (with different configs) if necessary.
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
metadata_store: Optional[KVStoreConfig] = Field(
|
||||||
|
default=None,
|
||||||
|
description="""
|
||||||
|
Configuration for the persistence store used by the distribution registry. If not specified,
|
||||||
|
a default SQLite store will be used.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BuildConfig(BaseModel):
|
class BuildConfig(BaseModel):
|
||||||
|
|
|
@ -26,6 +26,7 @@ from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
from llama_stack.apis.shields import Shields
|
from llama_stack.apis.shields import Shields
|
||||||
from llama_stack.apis.telemetry import Telemetry
|
from llama_stack.apis.telemetry import Telemetry
|
||||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||||
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
|
||||||
|
|
||||||
|
@ -65,7 +66,9 @@ class ProviderWithSpec(Provider):
|
||||||
|
|
||||||
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
# TODO: this code is not very straightforward to follow and needs one more round of refactoring
|
||||||
async def resolve_impls(
|
async def resolve_impls(
|
||||||
run_config: StackRunConfig, provider_registry: Dict[Api, Dict[str, ProviderSpec]]
|
run_config: StackRunConfig,
|
||||||
|
provider_registry: Dict[Api, Dict[str, ProviderSpec]],
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
) -> Dict[Api, Any]:
|
) -> Dict[Api, Any]:
|
||||||
"""
|
"""
|
||||||
Does two things:
|
Does two things:
|
||||||
|
@ -189,6 +192,7 @@ async def resolve_impls(
|
||||||
provider,
|
provider,
|
||||||
deps,
|
deps,
|
||||||
inner_impls,
|
inner_impls,
|
||||||
|
dist_registry,
|
||||||
)
|
)
|
||||||
# TODO: ugh slightly redesign this shady looking code
|
# TODO: ugh slightly redesign this shady looking code
|
||||||
if "inner-" in api_str:
|
if "inner-" in api_str:
|
||||||
|
@ -237,6 +241,7 @@ async def instantiate_provider(
|
||||||
provider: ProviderWithSpec,
|
provider: ProviderWithSpec,
|
||||||
deps: Dict[str, Any],
|
deps: Dict[str, Any],
|
||||||
inner_impls: Dict[str, Any],
|
inner_impls: Dict[str, Any],
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
):
|
):
|
||||||
protocols = api_protocol_map()
|
protocols = api_protocol_map()
|
||||||
additional_protocols = additional_protocols_map()
|
additional_protocols = additional_protocols_map()
|
||||||
|
@ -270,7 +275,7 @@ async def instantiate_provider(
|
||||||
method = "get_routing_table_impl"
|
method = "get_routing_table_impl"
|
||||||
|
|
||||||
config = None
|
config = None
|
||||||
args = [provider_spec.api, inner_impls, deps]
|
args = [provider_spec.api, inner_impls, deps, dist_registry]
|
||||||
else:
|
else:
|
||||||
method = "get_provider_impl"
|
method = "get_provider_impl"
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,9 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
|
|
||||||
from .routing_tables import (
|
from .routing_tables import (
|
||||||
DatasetsRoutingTable,
|
DatasetsRoutingTable,
|
||||||
MemoryBanksRoutingTable,
|
MemoryBanksRoutingTable,
|
||||||
|
@ -20,6 +23,7 @@ async def get_routing_table_impl(
|
||||||
api: Api,
|
api: Api,
|
||||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||||
_deps,
|
_deps,
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
api_to_tables = {
|
api_to_tables = {
|
||||||
"memory_banks": MemoryBanksRoutingTable,
|
"memory_banks": MemoryBanksRoutingTable,
|
||||||
|
@ -32,7 +36,7 @@ async def get_routing_table_impl(
|
||||||
if api.value not in api_to_tables:
|
if api.value not in api_to_tables:
|
||||||
raise ValueError(f"API {api.value} not found in router map")
|
raise ValueError(f"API {api.value} not found in router map")
|
||||||
|
|
||||||
impl = api_to_tables[api.value](impls_by_provider_id)
|
impl = api_to_tables[api.value](impls_by_provider_id, dist_registry)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from llama_stack.apis.shields import * # noqa: F403
|
||||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.distribution.store import DistributionRegistry
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,25 +47,23 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
|
||||||
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
Registry = Dict[str, List[RoutableObjectWithProvider]]
|
||||||
|
|
||||||
|
|
||||||
# TODO: this routing table maintains state in memory purely. We need to
|
|
||||||
# add persistence to it when we add dynamic registration of objects.
|
|
||||||
class CommonRoutingTableImpl(RoutingTable):
|
class CommonRoutingTableImpl(RoutingTable):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
impls_by_provider_id: Dict[str, RoutedProtocol],
|
impls_by_provider_id: Dict[str, RoutedProtocol],
|
||||||
|
dist_registry: DistributionRegistry,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.impls_by_provider_id = impls_by_provider_id
|
self.impls_by_provider_id = impls_by_provider_id
|
||||||
|
self.dist_registry = dist_registry
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
self.registry: Registry = {}
|
# Initialize the registry if not already done
|
||||||
|
await self.dist_registry.initialize()
|
||||||
|
|
||||||
def add_objects(
|
async def add_objects(
|
||||||
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
objs: List[RoutableObjectWithProvider], provider_id: str, cls
|
||||||
) -> None:
|
) -> None:
|
||||||
for obj in objs:
|
for obj in objs:
|
||||||
if obj.identifier not in self.registry:
|
|
||||||
self.registry[obj.identifier] = []
|
|
||||||
|
|
||||||
if cls is None:
|
if cls is None:
|
||||||
obj.provider_id = provider_id
|
obj.provider_id = provider_id
|
||||||
else:
|
else:
|
||||||
|
@ -74,34 +73,35 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
obj.provider_id = provider_id
|
obj.provider_id = provider_id
|
||||||
else:
|
else:
|
||||||
obj = cls(**obj.model_dump(), provider_id=provider_id)
|
obj = cls(**obj.model_dump(), provider_id=provider_id)
|
||||||
self.registry[obj.identifier].append(obj)
|
await self.dist_registry.register(obj)
|
||||||
|
|
||||||
|
# Register all objects from providers
|
||||||
for pid, p in self.impls_by_provider_id.items():
|
for pid, p in self.impls_by_provider_id.items():
|
||||||
api = get_impl_api(p)
|
api = get_impl_api(p)
|
||||||
if api == Api.inference:
|
if api == Api.inference:
|
||||||
p.model_store = self
|
p.model_store = self
|
||||||
models = await p.list_models()
|
models = await p.list_models()
|
||||||
add_objects(models, pid, ModelDefWithProvider)
|
await add_objects(models, pid, ModelDefWithProvider)
|
||||||
|
|
||||||
elif api == Api.safety:
|
elif api == Api.safety:
|
||||||
p.shield_store = self
|
p.shield_store = self
|
||||||
shields = await p.list_shields()
|
shields = await p.list_shields()
|
||||||
add_objects(shields, pid, ShieldDefWithProvider)
|
await add_objects(shields, pid, ShieldDefWithProvider)
|
||||||
|
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
p.memory_bank_store = self
|
p.memory_bank_store = self
|
||||||
memory_banks = await p.list_memory_banks()
|
memory_banks = await p.list_memory_banks()
|
||||||
add_objects(memory_banks, pid, None)
|
await add_objects(memory_banks, pid, None)
|
||||||
|
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
p.dataset_store = self
|
p.dataset_store = self
|
||||||
datasets = await p.list_datasets()
|
datasets = await p.list_datasets()
|
||||||
add_objects(datasets, pid, DatasetDefWithProvider)
|
await add_objects(datasets, pid, DatasetDefWithProvider)
|
||||||
|
|
||||||
elif api == Api.scoring:
|
elif api == Api.scoring:
|
||||||
p.scoring_function_store = self
|
p.scoring_function_store = self
|
||||||
scoring_functions = await p.list_scoring_functions()
|
scoring_functions = await p.list_scoring_functions()
|
||||||
add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
|
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
for p in self.impls_by_provider_id.values():
|
for p in self.impls_by_provider_id.values():
|
||||||
|
@ -124,39 +124,44 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown routing table type")
|
raise ValueError("Unknown routing table type")
|
||||||
|
|
||||||
if routing_key not in self.registry:
|
# Get objects from disk registry
|
||||||
|
objects = self.dist_registry.get_cached(routing_key)
|
||||||
|
if not objects:
|
||||||
apiname, objname = apiname_object()
|
apiname, objname = apiname_object()
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
|
f"`{routing_key}` not registered. Make sure there is an {apiname} provider serving this {objname}."
|
||||||
)
|
)
|
||||||
|
|
||||||
objs = self.registry[routing_key]
|
for obj in objects:
|
||||||
for obj in objs:
|
|
||||||
if not provider_id or provider_id == obj.provider_id:
|
if not provider_id or provider_id == obj.provider_id:
|
||||||
return self.impls_by_provider_id[obj.provider_id]
|
return self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
raise ValueError(f"Provider not found for `{routing_key}`")
|
raise ValueError(f"Provider not found for `{routing_key}`")
|
||||||
|
|
||||||
def get_object_by_identifier(
|
async def get_object_by_identifier(
|
||||||
self, identifier: str
|
self, identifier: str
|
||||||
) -> Optional[RoutableObjectWithProvider]:
|
) -> Optional[RoutableObjectWithProvider]:
|
||||||
objs = self.registry.get(identifier, [])
|
# Get from disk registry
|
||||||
if not objs:
|
objects = await self.dist_registry.get(identifier)
|
||||||
|
if not objects:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# kind of ill-defined behavior here, but we'll just return the first one
|
# kind of ill-defined behavior here, but we'll just return the first one
|
||||||
return objs[0]
|
return objects[0]
|
||||||
|
|
||||||
async def register_object(self, obj: RoutableObjectWithProvider):
|
async def register_object(self, obj: RoutableObjectWithProvider):
|
||||||
entries = self.registry.get(obj.identifier, [])
|
# Get existing objects from registry
|
||||||
for entry in entries:
|
existing_objects = await self.dist_registry.get(obj.identifier)
|
||||||
if entry.provider_id == obj.provider_id or not obj.provider_id:
|
|
||||||
|
# Check for existing registration
|
||||||
|
for existing_obj in existing_objects:
|
||||||
|
if existing_obj.provider_id == obj.provider_id or not obj.provider_id:
|
||||||
print(
|
print(
|
||||||
f"`{obj.identifier}` already registered with `{entry.provider_id}`"
|
f"`{obj.identifier}` already registered with `{existing_obj.provider_id}`"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# if provider_id is not specified, we'll pick an arbitrary one from existing entries
|
# if provider_id is not specified, pick an arbitrary one from existing entries
|
||||||
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
|
||||||
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
|
||||||
|
@ -166,12 +171,7 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
p = self.impls_by_provider_id[obj.provider_id]
|
p = self.impls_by_provider_id[obj.provider_id]
|
||||||
|
|
||||||
await register_object_with_provider(obj, p)
|
await register_object_with_provider(obj, p)
|
||||||
|
await self.dist_registry.register(obj)
|
||||||
if obj.identifier not in self.registry:
|
|
||||||
self.registry[obj.identifier] = []
|
|
||||||
self.registry[obj.identifier].append(obj)
|
|
||||||
|
|
||||||
# TODO: persist this to a store
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
|
@ -31,6 +31,8 @@ from llama_stack.distribution.distribution import (
|
||||||
get_provider_registry,
|
get_provider_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
|
|
||||||
from llama_stack.providers.utils.telemetry.tracing import (
|
from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
end_trace,
|
end_trace,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
|
@ -38,9 +40,10 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||||
from llama_stack.distribution.resolver import resolve_impls
|
from llama_stack.distribution.resolver import resolve_impls
|
||||||
|
from llama_stack.distribution.store import CachedDiskDistributionRegistry
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
|
@ -278,8 +281,23 @@ def main(
|
||||||
config = StackRunConfig(**yaml.safe_load(fp))
|
config = StackRunConfig(**yaml.safe_load(fp))
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
# instantiate kvstore for storing and retrieving distribution metadata
|
||||||
|
if config.metadata_store:
|
||||||
|
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
|
||||||
|
else:
|
||||||
|
dist_kvstore = asyncio.run(
|
||||||
|
kvstore_impl(
|
||||||
|
SqliteKVStoreConfig(
|
||||||
|
db_path=(
|
||||||
|
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
|
||||||
|
).as_posix()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
impls = asyncio.run(resolve_impls(config, get_provider_registry()))
|
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||||
|
|
||||||
|
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
|
|
||||||
|
|
7
llama_stack/distribution/store/__init__.py
Normal file
7
llama_stack/distribution/store/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .registry import * # noqa: F401 F403
|
135
llama_stack/distribution/store/registry.py
Normal file
135
llama_stack/distribution/store/registry.py
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Protocol
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionRegistry(Protocol):
|
||||||
|
async def get_all(self) -> List[RoutableObjectWithProvider]: ...
|
||||||
|
|
||||||
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]: ...
|
||||||
|
|
||||||
|
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]: ...
|
||||||
|
|
||||||
|
# The current data structure allows multiple objects with the same identifier but different providers.
|
||||||
|
# This is not ideal - we should have a single object that can be served by multiple providers,
|
||||||
|
# suggesting a data structure like (obj: Obj, providers: List[str]) rather than List[RoutableObjectWithProvider].
|
||||||
|
# The current approach could lead to inconsistencies if the same logical object has different data across providers.
|
||||||
|
async def register(self, obj: RoutableObjectWithProvider) -> bool: ...
|
||||||
|
|
||||||
|
|
||||||
|
KEY_FORMAT = "distributions:registry:{}"
|
||||||
|
|
||||||
|
|
||||||
|
class DiskDistributionRegistry(DistributionRegistry):
|
||||||
|
def __init__(self, kvstore: KVStore):
|
||||||
|
self.kvstore = kvstore
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
|
# Disk registry does not have a cache
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
|
start_key = KEY_FORMAT.format("")
|
||||||
|
end_key = KEY_FORMAT.format("\xff")
|
||||||
|
keys = await self.kvstore.range(start_key, end_key)
|
||||||
|
return [await self.get(key.split(":")[-1]) for key in keys]
|
||||||
|
|
||||||
|
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
|
json_str = await self.kvstore.get(KEY_FORMAT.format(identifier))
|
||||||
|
if not json_str:
|
||||||
|
return []
|
||||||
|
|
||||||
|
objects_data = json.loads(json_str)
|
||||||
|
return [
|
||||||
|
pydantic.parse_obj_as(
|
||||||
|
RoutableObjectWithProvider,
|
||||||
|
json.loads(obj_str),
|
||||||
|
)
|
||||||
|
for obj_str in objects_data
|
||||||
|
]
|
||||||
|
|
||||||
|
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
|
existing_objects = await self.get(obj.identifier)
|
||||||
|
# dont register if the object's providerid already exists
|
||||||
|
for eobj in existing_objects:
|
||||||
|
if eobj.provider_id == obj.provider_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
existing_objects.append(obj)
|
||||||
|
|
||||||
|
objects_json = [
|
||||||
|
obj.model_dump_json() for obj in existing_objects
|
||||||
|
] # Fixed variable name
|
||||||
|
await self.kvstore.set(
|
||||||
|
KEY_FORMAT.format(obj.identifier), json.dumps(objects_json)
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class CachedDiskDistributionRegistry(DiskDistributionRegistry):
|
||||||
|
def __init__(self, kvstore: KVStore):
|
||||||
|
super().__init__(kvstore)
|
||||||
|
self.cache: Dict[str, List[RoutableObjectWithProvider]] = {}
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
start_key = KEY_FORMAT.format("")
|
||||||
|
end_key = KEY_FORMAT.format("\xff")
|
||||||
|
|
||||||
|
keys = await self.kvstore.range(start_key, end_key)
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
identifier = key.split(":")[-1]
|
||||||
|
objects = await super().get(identifier)
|
||||||
|
if objects:
|
||||||
|
self.cache[identifier] = objects
|
||||||
|
|
||||||
|
def get_cached(self, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
|
return self.cache.get(identifier, [])
|
||||||
|
|
||||||
|
async def get_all(self) -> List[RoutableObjectWithProvider]:
|
||||||
|
return [item for sublist in self.cache.values() for item in sublist]
|
||||||
|
|
||||||
|
async def get(self, identifier: str) -> List[RoutableObjectWithProvider]:
|
||||||
|
if identifier in self.cache:
|
||||||
|
return self.cache[identifier]
|
||||||
|
|
||||||
|
objects = await super().get(identifier)
|
||||||
|
if objects:
|
||||||
|
self.cache[identifier] = objects
|
||||||
|
|
||||||
|
return objects
|
||||||
|
|
||||||
|
async def register(self, obj: RoutableObjectWithProvider) -> bool:
|
||||||
|
# First update disk
|
||||||
|
success = await super().register(obj)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
# Then update cache
|
||||||
|
if obj.identifier not in self.cache:
|
||||||
|
self.cache[obj.identifier] = []
|
||||||
|
|
||||||
|
# Check if provider already exists in cache
|
||||||
|
for cached_obj in self.cache[obj.identifier]:
|
||||||
|
if cached_obj.provider_id == obj.provider_id:
|
||||||
|
return success
|
||||||
|
|
||||||
|
# If not, update cache
|
||||||
|
self.cache[obj.identifier].append(obj)
|
||||||
|
|
||||||
|
return success
|
171
llama_stack/distribution/store/tests/test_registry.py
Normal file
171
llama_stack/distribution/store/tests/test_registry.py
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from llama_stack.distribution.store import * # noqa F403
|
||||||
|
from llama_stack.apis.inference import ModelDefWithProvider
|
||||||
|
from llama_stack.apis.memory_banks import VectorMemoryBankDef
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||||
|
from llama_stack.distribution.datatypes import * # noqa F403
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config():
|
||||||
|
config = SqliteKVStoreConfig(db_path="/tmp/test_registry.db")
|
||||||
|
if os.path.exists(config.db_path):
|
||||||
|
os.remove(config.db_path)
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def registry(config):
|
||||||
|
registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await registry.initialize()
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def cached_registry(config):
|
||||||
|
registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await registry.initialize()
|
||||||
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_bank():
|
||||||
|
return VectorMemoryBankDef(
|
||||||
|
identifier="test_bank",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=512,
|
||||||
|
overlap_size_in_tokens=64,
|
||||||
|
provider_id="test-provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_model():
|
||||||
|
return ModelDefWithProvider(
|
||||||
|
identifier="test_model",
|
||||||
|
llama_model="Llama3.2-3B-Instruct",
|
||||||
|
provider_id="test-provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_registry_initialization(registry):
|
||||||
|
# Test empty registry
|
||||||
|
results = await registry.get("nonexistent")
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_registration(registry, sample_bank, sample_model):
|
||||||
|
print(f"Registering {sample_bank}")
|
||||||
|
await registry.register(sample_bank)
|
||||||
|
print(f"Registering {sample_model}")
|
||||||
|
await registry.register(sample_model)
|
||||||
|
print("Getting bank")
|
||||||
|
results = await registry.get("test_bank")
|
||||||
|
assert len(results) == 1
|
||||||
|
result_bank = results[0]
|
||||||
|
assert result_bank.identifier == sample_bank.identifier
|
||||||
|
assert result_bank.embedding_model == sample_bank.embedding_model
|
||||||
|
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
|
||||||
|
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
||||||
|
assert result_bank.provider_id == sample_bank.provider_id
|
||||||
|
|
||||||
|
results = await registry.get("test_model")
|
||||||
|
assert len(results) == 1
|
||||||
|
result_model = results[0]
|
||||||
|
assert result_model.identifier == sample_model.identifier
|
||||||
|
assert result_model.llama_model == sample_model.llama_model
|
||||||
|
assert result_model.provider_id == sample_model.provider_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cached_registry_initialization(config, sample_bank, sample_model):
|
||||||
|
# First populate the disk registry
|
||||||
|
disk_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await disk_registry.initialize()
|
||||||
|
await disk_registry.register(sample_bank)
|
||||||
|
await disk_registry.register(sample_model)
|
||||||
|
|
||||||
|
# Test cached version loads from disk
|
||||||
|
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await cached_registry.initialize()
|
||||||
|
|
||||||
|
results = await cached_registry.get("test_bank")
|
||||||
|
assert len(results) == 1
|
||||||
|
result_bank = results[0]
|
||||||
|
assert result_bank.identifier == sample_bank.identifier
|
||||||
|
assert result_bank.embedding_model == sample_bank.embedding_model
|
||||||
|
assert result_bank.chunk_size_in_tokens == sample_bank.chunk_size_in_tokens
|
||||||
|
assert result_bank.overlap_size_in_tokens == sample_bank.overlap_size_in_tokens
|
||||||
|
assert result_bank.provider_id == sample_bank.provider_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cached_registry_updates(config):
|
||||||
|
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await cached_registry.initialize()
|
||||||
|
|
||||||
|
new_bank = VectorMemoryBankDef(
|
||||||
|
identifier="test_bank_2",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
overlap_size_in_tokens=32,
|
||||||
|
provider_id="baz",
|
||||||
|
)
|
||||||
|
await cached_registry.register(new_bank)
|
||||||
|
|
||||||
|
# Verify in cache
|
||||||
|
results = await cached_registry.get("test_bank_2")
|
||||||
|
assert len(results) == 1
|
||||||
|
result_bank = results[0]
|
||||||
|
assert result_bank.identifier == new_bank.identifier
|
||||||
|
assert result_bank.provider_id == new_bank.provider_id
|
||||||
|
|
||||||
|
# Verify persisted to disk
|
||||||
|
new_registry = DiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await new_registry.initialize()
|
||||||
|
results = await new_registry.get("test_bank_2")
|
||||||
|
assert len(results) == 1
|
||||||
|
result_bank = results[0]
|
||||||
|
assert result_bank.identifier == new_bank.identifier
|
||||||
|
assert result_bank.provider_id == new_bank.provider_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_duplicate_provider_registration(config):
|
||||||
|
cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config))
|
||||||
|
await cached_registry.initialize()
|
||||||
|
|
||||||
|
original_bank = VectorMemoryBankDef(
|
||||||
|
identifier="test_bank_2",
|
||||||
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
overlap_size_in_tokens=32,
|
||||||
|
provider_id="baz",
|
||||||
|
)
|
||||||
|
await cached_registry.register(original_bank)
|
||||||
|
|
||||||
|
duplicate_bank = VectorMemoryBankDef(
|
||||||
|
identifier="test_bank_2",
|
||||||
|
embedding_model="different-model",
|
||||||
|
chunk_size_in_tokens=128,
|
||||||
|
overlap_size_in_tokens=16,
|
||||||
|
provider_id="baz", # Same provider_id
|
||||||
|
)
|
||||||
|
await cached_registry.register(duplicate_bank)
|
||||||
|
|
||||||
|
results = await cached_registry.get("test_bank_2")
|
||||||
|
assert len(results) == 1 # Still only one result
|
||||||
|
assert (
|
||||||
|
results[0].embedding_model == original_bank.embedding_model
|
||||||
|
) # Original values preserved
|
Loading…
Add table
Add a link
Reference in a new issue