mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
fix: allow lookup of models registered at runtime
adds tests for ModelRegistryHelper to stabilize behavior
This commit is contained in:
parent
e4d001c4e4
commit
9982aa64f0
3 changed files with 165 additions and 1 deletions
|
@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
class ModelsProtocolPrivate(Protocol):
|
class ModelsProtocolPrivate(Protocol):
|
||||||
|
"""
|
||||||
|
Protocol for model management.
|
||||||
|
|
||||||
|
This allows users to register their preferred model identifiers.
|
||||||
|
|
||||||
|
Model registration requires -
|
||||||
|
- a provider, used to route the registration request
|
||||||
|
- a model identifier, user's intended name for the model during inference
|
||||||
|
- a provider model identifier, a model identifier supported by the provider
|
||||||
|
|
||||||
|
Providers will only accept registration for provider model ids they support.
|
||||||
|
|
||||||
|
Example,
|
||||||
|
register: provider x my-model-id x provider-model-id
|
||||||
|
-> Error if provider does not support provider-model-id
|
||||||
|
-> Error if my-model-id is already registered
|
||||||
|
-> Success if provider supports provider-model-id
|
||||||
|
inference: my-model-id x ...
|
||||||
|
-> Provider uses provider-model-id for inference
|
||||||
|
"""
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model: ...
|
async def register_model(self, model: Model) -> Model: ...
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None: ...
|
async def unregister_model(self, model_id: str) -> None: ...
|
||||||
|
|
|
@ -59,6 +59,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
||||||
|
|
||||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
def __init__(self, model_entries: List[ProviderModelEntry]):
|
def __init__(self, model_entries: List[ProviderModelEntry]):
|
||||||
|
self.supported_model_ids = {entry.provider_model_id for entry in model_entries}
|
||||||
|
|
||||||
self.alias_to_provider_id_map = {}
|
self.alias_to_provider_id_map = {}
|
||||||
self.provider_id_to_llama_model_map = {}
|
self.provider_id_to_llama_model_map = {}
|
||||||
for entry in model_entries:
|
for entry in model_entries:
|
||||||
|
@ -79,6 +81,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
|
if model.provider_resource_id not in self.supported_model_ids:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model id '{model.provider_resource_id}' is not supported. Supported ids are: {', '.join(self.supported_model_ids)}"
|
||||||
|
)
|
||||||
|
if model.model_id in self.alias_to_provider_id_map:
|
||||||
|
# be idempotent
|
||||||
|
if model.provider_resource_id != self.alias_to_provider_id_map[model.model_id]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
||||||
|
)
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
# embedding models are always registered by their provider model id and do not need to be mapped to a llama model
|
# embedding models are always registered by their provider model id and do not need to be mapped to a llama model
|
||||||
provider_resource_id = model.provider_resource_id
|
provider_resource_id = model.provider_resource_id
|
||||||
|
@ -108,7 +120,12 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.alias_to_provider_id_map[model.model_id] = model.provider_resource_id
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
if model_id not in self.alias_to_provider_id_map:
|
||||||
|
raise ValueError(f"Model id '{model_id}' is not registered.")
|
||||||
|
|
||||||
|
del self.alias_to_provider_id_map[model_id]
|
||||||
|
|
126
tests/unit/providers/utils/test_model_registry.py
Normal file
126
tests/unit/providers/utils/test_model_registry.py
Normal file
|
@ -0,0 +1,126 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
#
|
||||||
|
# ModelRegistryHelper provides mixin functionality for registering and
|
||||||
|
# unregistering models. It maintains a mapping of model ID / aliases to
|
||||||
|
# provider model IDs.
|
||||||
|
#
|
||||||
|
# Test cases -
|
||||||
|
# - Looking up an alias that does not exist should return None.
|
||||||
|
# - Registering a model + provider ID should add the model to the registry, provided the provider ID is known.
|
||||||
|
# - Registering an existing model should return an error.
|
||||||
|
# - Unregistering a model should remove it from the registry.
|
||||||
|
# - Unregistering a model that does not exist should return an error.
|
||||||
|
# - Models can be registered during initialization or via register_model.
|
||||||
|
#
|
||||||
|
# Questions -
|
||||||
|
# - Should we be allowed to register models w/o provider model IDs? No.
|
||||||
|
# According to POST /v1/models, required params are
|
||||||
|
# - identifier
|
||||||
|
# - provider_resource_id
|
||||||
|
# - provider_id
|
||||||
|
# - type
|
||||||
|
# - metadata
|
||||||
|
# - model_type
|
||||||
|
#
|
||||||
|
# TODO: llama_model functionality
|
||||||
|
#
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.models.models import Model
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def known_model() -> Model:
|
||||||
|
return Model(
|
||||||
|
provider_id="provider",
|
||||||
|
identifier="known-model",
|
||||||
|
provider_resource_id="known-provider-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def known_provider_model(known_model: Model) -> ProviderModelEntry:
|
||||||
|
return ProviderModelEntry(
|
||||||
|
provider_model_id=known_model.provider_resource_id,
|
||||||
|
# aliases=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def unknown_model() -> Model:
|
||||||
|
return Model(
|
||||||
|
provider_id="provider",
|
||||||
|
identifier="unknown-model",
|
||||||
|
provider_resource_id="unknown-provider-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def helper(known_provider_model: ProviderModelEntry) -> ModelRegistryHelper:
|
||||||
|
return ModelRegistryHelper([known_provider_model])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||||
|
assert helper.get_provider_model_id(unknown_model.model_id) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await helper.register_model(unknown_model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||||
|
await helper.register_model(known_model)
|
||||||
|
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
await helper.register_model(known_model)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await helper.register_model(known_model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
await helper.register_model(known_model)
|
||||||
|
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||||
|
await helper.unregister_model(known_model.model_id)
|
||||||
|
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await helper.unregister_model(unknown_model.model_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_register_model_existing_from_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
known_model.identifier = known_model.provider_resource_id
|
||||||
|
await helper.register_model(known_model)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||||
|
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||||
|
await helper.unregister_model(known_model.provider_resource_id)
|
||||||
|
assert helper.get_provider_model_id(known_model.provider_resource_id) is None
|
Loading…
Add table
Add a link
Reference in a new issue