forked from phoenix-oss/llama-stack-mirror
fix: allow use of models registered at runtime (#1980)
# What does this PR do? fix a bug where models registered at runtime could not be used. ``` $ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct $ curl http://localhost:8321/v1/openai/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "model": "test-model", "messages": [{"role": "user", "content": "What is the weather like in Boston today?"}] }' =(client)=> {"detail":"Internal server error: An unexpected error occurred."} =(server)=> TypeError: Missing required arguments; Expected either ('messages' and 'model') or ('messages', 'model' and 'stream') arguments to be given ``` *root cause:* test-model is not added to ModelRegistryHelper's alias_to_provider_id_map. as part of the fix, this adds tests for ModelRegistryHelper and defines its expected behavior. user visible behavior changes - | action | existing behavior | new behavior | | -- | -- | -- | | double register | success (but no change) | error | | register unknown | success (fail when used) | error | existing behavior for register unknown model and double register - ``` $ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct-unknown Successfully registered model test-model $ llama-stack-client models list | grep test-model │ llm │ test-model │ meta/llama-3.1-70b-instruct-unknown │ │ nv… │ $ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct Successfully registered model test-model $ llama-stack-client models list | grep test-model │ llm │ test-model │ meta/llama-3.1-70b-instruct-unknown │ │ nv… │ ``` new behavior for register unknown - ``` $ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct-unknown ╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ │ Failed to register model │ │ │ │ Error Type: BadRequestError │ │ Details: Error code: 400 - {'detail': "Invalid value: Model id │ │ 'meta/llama-3.1-70b-instruct-unknown' is not supported. Supported ids are: │ │ meta/llama-3.1-70b-instruct, snowflake/arctic-embed-l, meta/llama-3.2-1b-instruct, │ │ nvidia/nv-embedqa-mistral-7b-v2, meta/llama-3.2-90b-vision-instruct, meta/llama-3.2-3b-instruct, │ │ meta/llama-3.2-11b-vision-instruct, meta/llama-3.1-405b-instruct, meta/llama3-8b-instruct, │ │ meta/llama3-70b-instruct, nvidia/llama-3.2-nv-embedqa-1b-v2, meta/llama-3.1-8b-instruct, │ │ nvidia/nv-embedqa-e5-v5"} │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` new behavior for double register - ``` $ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct Successfully registered model test-model $ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.2-1b-instruct ╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ │ Failed to register model │ │ │ │ Error Type: BadRequestError │ │ Details: Error code: 400 - {'detail': "Invalid value: Model id 'test-model' is already │ │ registered. Please use a different id or unregister it first."} │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ ``` ## Test Plan ``` uv run pytest -v tests/unit/providers/utils/test_model_registry.py ```
This commit is contained in:
parent
64829947d0
commit
88a796ca5a
5 changed files with 231 additions and 27 deletions
|
@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type
|
|||
|
||||
|
||||
class ModelsProtocolPrivate(Protocol):
|
||||
"""
|
||||
Protocol for model management.
|
||||
|
||||
This allows users to register their preferred model identifiers.
|
||||
|
||||
Model registration requires -
|
||||
- a provider, used to route the registration request
|
||||
- a model identifier, user's intended name for the model during inference
|
||||
- a provider model identifier, a model identifier supported by the provider
|
||||
|
||||
Providers will only accept registration for provider model ids they support.
|
||||
|
||||
Example,
|
||||
register: provider x my-model-id x provider-model-id
|
||||
-> Error if provider does not support provider-model-id
|
||||
-> Error if my-model-id is already registered
|
||||
-> Success if provider supports provider-model-id
|
||||
inference: my-model-id x ...
|
||||
-> Provider uses provider-model-id for inference
|
||||
"""
|
||||
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
from ollama import AsyncClient
|
||||
from ollama import AsyncClient # type: ignore[attr-defined]
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
@ -333,7 +333,10 @@ class OllamaInferenceAdapter(
|
|||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.register_helper.register_model(model)
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
if model.model_type == ModelType.embedding:
|
||||
logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
|
||||
await self.client.pull(model.provider_resource_id)
|
||||
|
@ -342,9 +345,12 @@ class OllamaInferenceAdapter(
|
|||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m["model"] for m in response["models"]]
|
||||
if model.provider_resource_id not in available_models:
|
||||
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id is None:
|
||||
provider_resource_id = model.provider_resource_id
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
|
||||
if model.provider_resource_id in available_models_latest:
|
||||
if provider_resource_id in available_models_latest:
|
||||
logger.warning(
|
||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
|
@ -352,6 +358,7 @@ class OllamaInferenceAdapter(
|
|||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
|
||||
)
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
@ -372,7 +372,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||
# Changing this may lead to unpredictable behavior.
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
model = await self.register_helper.register_model(model)
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
res = await client.models.list()
|
||||
available_models = [m.id async for m in res]
|
||||
if model.provider_resource_id not in available_models:
|
||||
|
|
|
@ -75,40 +75,50 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
def get_provider_model_id(self, identifier: str) -> Optional[str]:
|
||||
return self.alias_to_provider_id_map.get(identifier, None)
|
||||
|
||||
# TODO: why keep a separate llama model mapping?
|
||||
def get_llama_model(self, provider_model_id: str) -> Optional[str]:
|
||||
return self.provider_id_to_llama_model_map.get(provider_model_id, None)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
|
||||
raise ValueError(
|
||||
f"Model '{model.provider_resource_id}' is not supported. Supported models are: {', '.join(self.alias_to_provider_id_map.keys())}"
|
||||
)
|
||||
provider_resource_id = self.get_provider_model_id(model.model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||
provider_resource_id = model.provider_resource_id
|
||||
else:
|
||||
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||
|
||||
if provider_resource_id:
|
||||
model.provider_resource_id = provider_resource_id
|
||||
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences
|
||||
raise ValueError(
|
||||
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
|
||||
)
|
||||
else:
|
||||
llama_model = model.metadata.get("llama_model")
|
||||
if llama_model is None:
|
||||
return model
|
||||
if llama_model:
|
||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||
if existing_llama_model:
|
||||
if existing_llama_model != llama_model:
|
||||
raise ValueError(
|
||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
raise ValueError(
|
||||
f"Invalid llama_model '{llama_model}' specified in metadata. "
|
||||
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
||||
)
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
|
||||
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||
if existing_llama_model:
|
||||
if existing_llama_model != llama_model:
|
||||
raise ValueError(
|
||||
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||
)
|
||||
else:
|
||||
if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||
raise ValueError(
|
||||
f"Invalid llama_model '{llama_model}' specified in metadata. "
|
||||
f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
|
||||
)
|
||||
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||
)
|
||||
self.alias_to_provider_id_map[model.model_id] = supported_model_id
|
||||
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
# TODO: should we block unregistering base supported provider model IDs?
|
||||
if model_id not in self.alias_to_provider_id_map:
|
||||
raise ValueError(f"Model id '{model_id}' is not registered.")
|
||||
|
||||
del self.alias_to_provider_id_map[model_id]
|
||||
|
|
163
tests/unit/providers/utils/test_model_registry.py
Normal file
163
tests/unit/providers/utils/test_model_registry.py
Normal file
|
@ -0,0 +1,163 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
#
|
||||
# ModelRegistryHelper provides mixin functionality for registering and
|
||||
# unregistering models. It maintains a mapping of model ID / aliases to
|
||||
# provider model IDs.
|
||||
#
|
||||
# Test cases -
|
||||
# - Looking up an alias that does not exist should return None.
|
||||
# - Registering a model + provider ID should add the model to the registry. If
|
||||
# provider ID is known or an alias for a provider ID.
|
||||
# - Registering an existing model should return an error. Unless it's a
|
||||
# dulicate entry.
|
||||
# - Unregistering a model should remove it from the registry.
|
||||
# - Unregistering a model that does not exist should return an error.
|
||||
# - Supported model ID and their aliases are registered during initialization.
|
||||
# Only aliases are added afterwards.
|
||||
#
|
||||
# Questions -
|
||||
# - Should we be allowed to register models w/o provider model IDs? No.
|
||||
# According to POST /v1/models, required params are
|
||||
# - identifier
|
||||
# - provider_resource_id
|
||||
# - provider_id
|
||||
# - type
|
||||
# - metadata
|
||||
# - model_type
|
||||
#
|
||||
# TODO: llama_model functionality
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models.models import Model
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_model() -> Model:
|
||||
return Model(
|
||||
provider_id="provider",
|
||||
identifier="known-model",
|
||||
provider_resource_id="known-provider-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_model2() -> Model:
|
||||
return Model(
|
||||
provider_id="provider",
|
||||
identifier="known-model2",
|
||||
provider_resource_id="known-provider-id2",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_provider_model(known_model: Model) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=known_model.provider_resource_id,
|
||||
aliases=[known_model.model_id],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def known_provider_model2(known_model2: Model) -> ProviderModelEntry:
|
||||
return ProviderModelEntry(
|
||||
provider_model_id=known_model2.provider_resource_id,
|
||||
# aliases=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unknown_model() -> Model:
|
||||
return Model(
|
||||
provider_id="provider",
|
||||
identifier="unknown-model",
|
||||
provider_resource_id="unknown-provider-id",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
|
||||
return ModelRegistryHelper([known_provider_model, known_provider_model2])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(unknown_model.model_id) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await helper.register_model(unknown_model)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
model = Model(
|
||||
provider_id=known_model.provider_id,
|
||||
identifier="new-model",
|
||||
provider_resource_id=known_model.provider_resource_id,
|
||||
)
|
||||
assert helper.get_provider_model_id(model.model_id) is None
|
||||
await helper.register_model(model)
|
||||
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
model = Model(
|
||||
provider_id=known_model.provider_id,
|
||||
identifier="new-model",
|
||||
provider_resource_id=known_model.model_id, # use known model's id as an alias for the supported model id
|
||||
)
|
||||
assert helper.get_provider_model_id(model.model_id) is None
|
||||
await helper.register_model(model)
|
||||
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
await helper.register_model(known_model)
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_existing_different(
|
||||
helper: ModelRegistryHelper, known_model: Model, known_model2: Model
|
||||
) -> None:
|
||||
known_model.provider_resource_id = known_model2.provider_resource_id
|
||||
with pytest.raises(ValueError):
|
||||
await helper.register_model(known_model)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
await helper.register_model(known_model) # duplicate entry
|
||||
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
|
||||
await helper.unregister_model(known_model.model_id)
|
||||
assert helper.get_provider_model_id(known_model.model_id) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
await helper.unregister_model(unknown_model.model_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
|
||||
await helper.unregister_model(known_model.provider_resource_id)
|
||||
assert helper.get_provider_model_id(known_model.provider_resource_id) is None
|
Loading…
Add table
Add a link
Reference in a new issue