mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-29 04:32:01 +00:00
fix: allow lookup of models registered at runtime
adds tests for ModelRegistryHelper to stabilize behavior
This commit is contained in:
parent
e4d001c4e4
commit
9982aa64f0
3 changed files with 165 additions and 1 deletions
|
|
@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
"""
|
||||
Protocol for model management.
|
||||
|
||||
This allows users to register their preferred model identifiers.
|
||||
|
||||
Model registration requires -
|
||||
- a provider, used to route the registration request
|
||||
- a model identifier, user's intended name for the model during inference
|
||||
- a provider model identifier, a model identifier supported by the provider
|
||||
|
||||
Providers will only accept registration for provider model ids they support.
|
||||
|
||||
Example,
|
||||
register: provider x my-model-id x provider-model-id
|
||||
-> Error if provider does not support provider-model-id
|
||||
-> Error if my-model-id is already registered
|
||||
-> Success if provider supports provider-model-id
|
||||
inference: my-model-id x ...
|
||||
-> Provider uses provider-model-id for inference
|
||||
"""
|
||||
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
|||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
def __init__(self, model_entries: List[ProviderModelEntry]):
|
||||
self.supported_model_ids = {entry.provider_model_id for entry in model_entries}
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
|
|
@ -79,6 +81,16 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if model.provider_resource_id not in self.supported_model_ids:
|
||||
raise ValueError(
|
||||
f"Model id '{model.provider_resource_id}' is not supported. Supported ids are: {', '.join(self.supported_model_ids)}"
|
||||
)
|
||||
if model.model_id in self.alias_to_provider_id_map:
|
||||
# be idempotent
|
||||
if model.provider_resource_id != self.alias_to_provider_id_map[model.model_id]:
|
||||
raise ValueError(
|
||||
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
||||
)
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and do not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
|
|
@ -108,7 +120,12 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
|
||||
self.alias_to_provider_id_map[model.model_id] = model.provider_resource_id
|
||||
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
if model_id not in self.alias_to_provider_id_map:
|
||||
raise ValueError(f"Model id '{model_id}' is not registered.")
|
||||
|
||||
del self.alias_to_provider_id_map[model_id]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue