diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index c3141f807..3482acb31 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type class ModelsProtocolPrivate(Protocol): + """ + Protocol for model management. + + This allows users to register their preferred model identifiers. + + Model registration requires - + - a provider, used to route the registration request + - a model identifier, user's intended name for the model during inference + - a provider model identifier, a model identifier supported by the provider + + Providers will only accept registration for provider model ids they support. + + Example, + register: provider x my-model-id x provider-model-id + -> Error if provider does not support provider-model-id + -> Error if my-model-id is already registered + -> Success if provider supports provider-model-id + inference: my-model-id x ... + -> Provider uses provider-model-id for inference + """ + async def register_model(self, model: Model) -> Model: ... async def unregister_model(self, model_id: str) -> None: ... diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index e915b3098..37c187181 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -8,7 +8,7 @@ from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union import httpx -from ollama import AsyncClient +from ollama import AsyncClient # type: ignore[attr-defined] from openai import AsyncOpenAI from llama_stack.apis.common.content_types import ( @@ -333,7 +333,10 @@ class OllamaInferenceAdapter( return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: - model = await self.register_helper.register_model(model) + try: + model = await self.register_helper.register_model(model) + except ValueError: + pass # Ignore statically unknown model, will check live listing if model.model_type == ModelType.embedding: logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") await self.client.pull(model.provider_resource_id) @@ -342,9 +345,12 @@ class OllamaInferenceAdapter( # - models not currently running are run by the ollama server as needed response = await self.client.list() available_models = [m["model"] for m in response["models"]] - if model.provider_resource_id not in available_models: + provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id) + if provider_resource_id is None: + provider_resource_id = model.provider_resource_id + if provider_resource_id not in available_models: available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]] - if model.provider_resource_id in available_models_latest: + if provider_resource_id in available_models_latest: logger.warning( f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'" ) @@ -352,6 +358,7 @@ class OllamaInferenceAdapter( raise ValueError( f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}" ) + model.provider_resource_id = provider_resource_id return model diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8cfef2ee0..ac268c86c 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -372,7 +372,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors. # Changing this may lead to unpredictable behavior. client = self._create_client() if self.client is None else self.client - model = await self.register_helper.register_model(model) + try: + model = await self.register_helper.register_model(model) + except ValueError: + pass # Ignore statically unknown model, will check live listing res = await client.models.list() available_models = [m.id async for m in res] if model.provider_resource_id not in available_models: diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 4d7063953..c5199b0a8 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -75,40 +75,50 @@ class ModelRegistryHelper(ModelsProtocolPrivate): def get_provider_model_id(self, identifier: str) -> Optional[str]: return self.alias_to_provider_id_map.get(identifier, None) + # TODO: why keep a separate llama model mapping? def get_llama_model(self, provider_model_id: str) -> Optional[str]: return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: + if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)): + raise ValueError( + f"Model '{model.provider_resource_id}' is not supported. Supported models are: {', '.join(self.alias_to_provider_id_map.keys())}" + ) + provider_resource_id = self.get_provider_model_id(model.model_id) if model.model_type == ModelType.embedding: # embedding models are always registered by their provider model id and does not need to be mapped to a llama model provider_resource_id = model.provider_resource_id - else: - provider_resource_id = self.get_provider_model_id(model.provider_resource_id) - if provider_resource_id: - model.provider_resource_id = provider_resource_id + if provider_resource_id != supported_model_id: # be idemopotent, only reject differences + raise ValueError( + f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first." + ) else: llama_model = model.metadata.get("llama_model") - if llama_model is None: - return model + if llama_model: + existing_llama_model = self.get_llama_model(model.provider_resource_id) + if existing_llama_model: + if existing_llama_model != llama_model: + raise ValueError( + f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" + ) + else: + if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: + raise ValueError( + f"Invalid llama_model '{llama_model}' specified in metadata. " + f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}" + ) + self.provider_id_to_llama_model_map[model.provider_resource_id] = ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] + ) - existing_llama_model = self.get_llama_model(model.provider_resource_id) - if existing_llama_model: - if existing_llama_model != llama_model: - raise ValueError( - f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" - ) - else: - if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: - raise ValueError( - f"Invalid llama_model '{llama_model}' specified in metadata. " - f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}" - ) - self.provider_id_to_llama_model_map[model.provider_resource_id] = ( - ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] - ) + self.alias_to_provider_id_map[model.model_id] = supported_model_id return model async def unregister_model(self, model_id: str) -> None: - pass + # TODO: should we block unregistering base supported provider model IDs? + if model_id not in self.alias_to_provider_id_map: + raise ValueError(f"Model id '{model_id}' is not registered.") + + del self.alias_to_provider_id_map[model_id] diff --git a/tests/unit/providers/utils/test_model_registry.py b/tests/unit/providers/utils/test_model_registry.py new file mode 100644 index 000000000..67f8a138f --- /dev/null +++ b/tests/unit/providers/utils/test_model_registry.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# +# ModelRegistryHelper provides mixin functionality for registering and +# unregistering models. It maintains a mapping of model ID / aliases to +# provider model IDs. +# +# Test cases - +# - Looking up an alias that does not exist should return None. +# - Registering a model + provider ID should add the model to the registry. If +# provider ID is known or an alias for a provider ID. +# - Registering an existing model should return an error. Unless it's a +# dulicate entry. +# - Unregistering a model should remove it from the registry. +# - Unregistering a model that does not exist should return an error. +# - Supported model ID and their aliases are registered during initialization. +# Only aliases are added afterwards. +# +# Questions - +# - Should we be allowed to register models w/o provider model IDs? No. +# According to POST /v1/models, required params are +# - identifier +# - provider_resource_id +# - provider_id +# - type +# - metadata +# - model_type +# +# TODO: llama_model functionality +# + +import pytest + +from llama_stack.apis.models.models import Model +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry + + +@pytest.fixture +def known_model() -> Model: + return Model( + provider_id="provider", + identifier="known-model", + provider_resource_id="known-provider-id", + ) + + +@pytest.fixture +def known_model2() -> Model: + return Model( + provider_id="provider", + identifier="known-model2", + provider_resource_id="known-provider-id2", + ) + + +@pytest.fixture +def known_provider_model(known_model: Model) -> ProviderModelEntry: + return ProviderModelEntry( + provider_model_id=known_model.provider_resource_id, + aliases=[known_model.model_id], + ) + + +@pytest.fixture +def known_provider_model2(known_model2: Model) -> ProviderModelEntry: + return ProviderModelEntry( + provider_model_id=known_model2.provider_resource_id, + # aliases=[], + ) + + +@pytest.fixture +def unknown_model() -> Model: + return Model( + provider_id="provider", + identifier="unknown-model", + provider_resource_id="unknown-provider-id", + ) + + +@pytest.fixture +def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper: + return ModelRegistryHelper([known_provider_model, known_provider_model2]) + + +@pytest.mark.asyncio +async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: + assert helper.get_provider_model_id(unknown_model.model_id) is None + + +@pytest.mark.asyncio +async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: + with pytest.raises(ValueError): + await helper.register_model(unknown_model) + + +@pytest.mark.asyncio +async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None: + model = Model( + provider_id=known_model.provider_id, + identifier="new-model", + provider_resource_id=known_model.provider_resource_id, + ) + assert helper.get_provider_model_id(model.model_id) is None + await helper.register_model(model) + assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id + + +@pytest.mark.asyncio +async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None: + model = Model( + provider_id=known_model.provider_id, + identifier="new-model", + provider_resource_id=known_model.model_id, # use known model's id as an alias for the supported model id + ) + assert helper.get_provider_model_id(model.model_id) is None + await helper.register_model(model) + assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id + + +@pytest.mark.asyncio +async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None: + await helper.register_model(known_model) + assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id + + +@pytest.mark.asyncio +async def test_register_model_existing_different( + helper: ModelRegistryHelper, known_model: Model, known_model2: Model +) -> None: + known_model.provider_resource_id = known_model2.provider_resource_id + with pytest.raises(ValueError): + await helper.register_model(known_model) + + +@pytest.mark.asyncio +async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None: + await helper.register_model(known_model) # duplicate entry + assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id + await helper.unregister_model(known_model.model_id) + assert helper.get_provider_model_id(known_model.model_id) is None + + +@pytest.mark.asyncio +async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: + with pytest.raises(ValueError): + await helper.unregister_model(unknown_model.model_id) + + +@pytest.mark.asyncio +async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: + assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id + + +@pytest.mark.asyncio +async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: + assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id + await helper.unregister_model(known_model.provider_resource_id) + assert helper.get_provider_model_id(known_model.provider_resource_id) is None