fix: allow use of models registered at runtime (#1980)

# What does this PR do?

fix a bug where models registered at runtime could not be used.

```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct

$ curl http://localhost:8321/v1/openai/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
  "model": "test-model",
  "messages": [{"role": "user", "content": "What is the weather like in Boston today?"}]
}'

=(client)=> {"detail":"Internal server error: An unexpected error occurred."}
=(server)=> TypeError: Missing required arguments; Expected either ('messages' and 'model') or ('messages', 'model' and 'stream') arguments to be given
```

*root cause:* registering test-model never adds it to ModelRegistryHelper's
alias_to_provider_id_map, so the id cannot be resolved to a provider model id
when it is used.

as part of the fix, this adds tests for ModelRegistryHelper and defines
its expected behavior.
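
The root cause can be sketched in a few lines. This is a simplified illustration, not the real ModelRegistryHelper: a plain dict stands in for `alias_to_provider_id_map`, and `register_before_fix` / `register_after_fix` are hypothetical names used only here.

```python
from typing import Dict, Optional

# Stand-in for ModelRegistryHelper.alias_to_provider_id_map, seeded at
# initialization with the ids the provider supports (assumed for this sketch).
alias_to_provider_id_map: Dict[str, str] = {
    "meta/llama-3.1-70b-instruct": "meta/llama-3.1-70b-instruct",
}


def get_provider_model_id(identifier: str) -> Optional[str]:
    # Resolve a user-facing id (or alias) to the provider's model id.
    return alias_to_provider_id_map.get(identifier)


def register_before_fix(model_id: str, provider_model_id: str) -> None:
    # Hypothetical simplification of the old behavior: nothing is stored
    # under the user's model_id.
    pass


def register_after_fix(model_id: str, provider_model_id: str) -> None:
    # Hypothetical simplification of the new behavior: reject unsupported
    # provider ids, then record model_id as an alias of the supported id.
    supported = get_provider_model_id(provider_model_id)
    if supported is None:
        raise ValueError(f"Model id '{provider_model_id}' is not supported.")
    alias_to_provider_id_map[model_id] = supported


register_before_fix("test-model", "meta/llama-3.1-70b-instruct")
lookup_before = get_provider_model_id("test-model")  # alias was never stored

register_after_fix("test-model", "meta/llama-3.1-70b-instruct")
lookup_after = get_provider_model_id("test-model")  # alias now resolves
```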

user-visible behavior changes -

| action | existing behavior | new behavior |
| -- | -- | -- |
| double register | success (but no change) | error (unless the registration is identical) |
| register unknown | success (fail when used) | error |

existing behavior for register unknown model and double register -
```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct-unknown
Successfully registered model test-model

$ llama-stack-client models list | grep test-model
│ llm │ test-model                               │ meta/llama-3.1-70b-instruct-unknown │     │ nv… │

$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct       
Successfully registered model test-model

$ llama-stack-client models list | grep test-model
│ llm │ test-model                               │ meta/llama-3.1-70b-instruct-unknown │     │ nv… │
```

new behavior for register unknown -
```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct-unknown
╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Failed to register model                                                                         │
│                                                                                                  │
│ Error Type: BadRequestError                                                                      │
│ Details: Error code: 400 - {'detail': "Invalid value: Model id                                   │
│ 'meta/llama-3.1-70b-instruct-unknown' is not supported. Supported ids are:                       │
│ meta/llama-3.1-70b-instruct, snowflake/arctic-embed-l, meta/llama-3.2-1b-instruct,               │
│ nvidia/nv-embedqa-mistral-7b-v2, meta/llama-3.2-90b-vision-instruct, meta/llama-3.2-3b-instruct, │
│ meta/llama-3.2-11b-vision-instruct, meta/llama-3.1-405b-instruct, meta/llama3-8b-instruct,       │
│ meta/llama3-70b-instruct, nvidia/llama-3.2-nv-embedqa-1b-v2, meta/llama-3.1-8b-instruct,         │
│ nvidia/nv-embedqa-e5-v5"}                                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
```

new behavior for double register -
```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct
Successfully registered model test-model

$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.2-1b-instruct 
╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Failed to register model                                                                         │
│                                                                                                  │
│ Error Type: BadRequestError                                                                      │
│ Details: Error code: 400 - {'detail': "Invalid value: Model id 'test-model' is already           │
│ registered. Please use a different id or unregister it first."}                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
```


## Test Plan

```
uv run pytest -v tests/unit/providers/utils/test_model_registry.py
```
Matthew Farrellee 2025-05-01 15:00:58 -04:00 committed by GitHub
parent 64829947d0
commit 88a796ca5a
5 changed files with 231 additions and 27 deletions

@@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type
 class ModelsProtocolPrivate(Protocol):
+    """
+    Protocol for model management.
+
+    This allows users to register their preferred model identifiers.
+
+    Model registration requires -
+     - a provider, used to route the registration request
+     - a model identifier, user's intended name for the model during inference
+     - a provider model identifier, a model identifier supported by the provider
+
+    Providers will only accept registration for provider model ids they support.
+
+    Example,
+      register: provider x my-model-id x provider-model-id
+       -> Error if provider does not support provider-model-id
+       -> Error if my-model-id is already registered
+       -> Success if provider supports provider-model-id
+      inference: my-model-id x ...
+       -> Provider uses provider-model-id for inference
+    """
+
     async def register_model(self, model: Model) -> Model: ...

     async def unregister_model(self, model_id: str) -> None: ...

@@ -8,7 +8,7 @@
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 import httpx
-from ollama import AsyncClient
+from ollama import AsyncClient  # type: ignore[attr-defined]
 from openai import AsyncOpenAI

 from llama_stack.apis.common.content_types import (
@@ -333,7 +333,10 @@ class OllamaInferenceAdapter(
         return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
-        model = await self.register_helper.register_model(model)
+        try:
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            pass  # Ignore statically unknown model, will check live listing
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
             await self.client.pull(model.provider_resource_id)
@@ -342,9 +345,12 @@ class OllamaInferenceAdapter(
         # - models not currently running are run by the ollama server as needed
         response = await self.client.list()
         available_models = [m["model"] for m in response["models"]]
-        if model.provider_resource_id not in available_models:
+        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
+        if provider_resource_id is None:
+            provider_resource_id = model.provider_resource_id
+        if provider_resource_id not in available_models:
             available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
-            if model.provider_resource_id in available_models_latest:
+            if provider_resource_id in available_models_latest:
                 logger.warning(
                     f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
                 )
@@ -352,6 +358,7 @@ class OllamaInferenceAdapter(
             raise ValueError(
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )
+        model.provider_resource_id = provider_resource_id

         return model

@@ -372,7 +372,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
         # Changing this may lead to unpredictable behavior.
         client = self._create_client() if self.client is None else self.client
-        model = await self.register_helper.register_model(model)
+        try:
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            pass  # Ignore statically unknown model, will check live listing
         res = await client.models.list()
         available_models = [m.id async for m in res]
         if model.provider_resource_id not in available_models:

@@ -75,40 +75,50 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def get_provider_model_id(self, identifier: str) -> Optional[str]:
         return self.alias_to_provider_id_map.get(identifier, None)

+    # TODO: why keep a separate llama model mapping?
     def get_llama_model(self, provider_model_id: str) -> Optional[str]:
         return self.provider_id_to_llama_model_map.get(provider_model_id, None)

     async def register_model(self, model: Model) -> Model:
+        if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not supported. Supported models are: {', '.join(self.alias_to_provider_id_map.keys())}"
+            )
+        provider_resource_id = self.get_provider_model_id(model.model_id)
         if model.model_type == ModelType.embedding:
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
             provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
-            model.provider_resource_id = provider_resource_id
+            if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
+                raise ValueError(
+                    f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
+                )
         else:
             llama_model = model.metadata.get("llama_model")
-            if llama_model is None:
-                return model
-
-            existing_llama_model = self.get_llama_model(model.provider_resource_id)
-            if existing_llama_model:
-                if existing_llama_model != llama_model:
-                    raise ValueError(
-                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
-                    )
-            else:
-                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
-                    raise ValueError(
-                        f"Invalid llama_model '{llama_model}' specified in metadata. "
-                        f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
-                    )
-                self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
-                )
+            if llama_model:
+                existing_llama_model = self.get_llama_model(model.provider_resource_id)
+                if existing_llama_model:
+                    if existing_llama_model != llama_model:
+                        raise ValueError(
+                            f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                        )
+                else:
+                    if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                        raise ValueError(
+                            f"Invalid llama_model '{llama_model}' specified in metadata. "
+                            f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
+                        )
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+
+        self.alias_to_provider_id_map[model.model_id] = supported_model_id

         return model

     async def unregister_model(self, model_id: str) -> None:
-        pass
+        # TODO: should we block unregistering base supported provider model IDs?
+        if model_id not in self.alias_to_provider_id_map:
+            raise ValueError(f"Model id '{model_id}' is not registered.")
+        del self.alias_to_provider_id_map[model_id]

@@ -0,0 +1,163 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+#
+# ModelRegistryHelper provides mixin functionality for registering and
+# unregistering models. It maintains a mapping of model ID / aliases to
+# provider model IDs.
+#
+# Test cases -
+#  - Looking up an alias that does not exist should return None.
+#  - Registering a model + provider ID should add the model to the registry. If
+#    provider ID is known or an alias for a provider ID.
+#  - Registering an existing model should return an error. Unless it's a
+#    duplicate entry.
+#  - Unregistering a model should remove it from the registry.
+#  - Unregistering a model that does not exist should return an error.
+#  - Supported model ID and their aliases are registered during initialization.
+#    Only aliases are added afterwards.
+#
+# Questions -
+#  - Should we be allowed to register models w/o provider model IDs? No.
+#    According to POST /v1/models, required params are
+#      - identifier
+#      - provider_resource_id
+#      - provider_id
+#      - type
+#      - metadata
+#      - model_type
+#
+# TODO: llama_model functionality
+#
+
+import pytest
+
+from llama_stack.apis.models.models import Model
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry
+
+
+@pytest.fixture
+def known_model() -> Model:
+    return Model(
+        provider_id="provider",
+        identifier="known-model",
+        provider_resource_id="known-provider-id",
+    )
+
+
+@pytest.fixture
+def known_model2() -> Model:
+    return Model(
+        provider_id="provider",
+        identifier="known-model2",
+        provider_resource_id="known-provider-id2",
+    )
+
+
+@pytest.fixture
+def known_provider_model(known_model: Model) -> ProviderModelEntry:
+    return ProviderModelEntry(
+        provider_model_id=known_model.provider_resource_id,
+        aliases=[known_model.model_id],
+    )
+
+
+@pytest.fixture
+def known_provider_model2(known_model2: Model) -> ProviderModelEntry:
+    return ProviderModelEntry(
+        provider_model_id=known_model2.provider_resource_id,
+        # aliases=[],
+    )
+
+
+@pytest.fixture
+def unknown_model() -> Model:
+    return Model(
+        provider_id="provider",
+        identifier="unknown-model",
+        provider_resource_id="unknown-provider-id",
+    )
+
+
+@pytest.fixture
+def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
+    return ModelRegistryHelper([known_provider_model, known_provider_model2])
+
+
+@pytest.mark.asyncio
+async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
+    assert helper.get_provider_model_id(unknown_model.model_id) is None
+
+
+@pytest.mark.asyncio
+async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
+    with pytest.raises(ValueError):
+        await helper.register_model(unknown_model)
+
+
+@pytest.mark.asyncio
+async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
+    model = Model(
+        provider_id=known_model.provider_id,
+        identifier="new-model",
+        provider_resource_id=known_model.provider_resource_id,
+    )
+    assert helper.get_provider_model_id(model.model_id) is None
+    await helper.register_model(model)
+    assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
+
+
+@pytest.mark.asyncio
+async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
+    model = Model(
+        provider_id=known_model.provider_id,
+        identifier="new-model",
+        provider_resource_id=known_model.model_id,  # use known model's id as an alias for the supported model id
+    )
+    assert helper.get_provider_model_id(model.model_id) is None
+    await helper.register_model(model)
+    assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
+
+
+@pytest.mark.asyncio
+async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
+    await helper.register_model(known_model)
+    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
+
+
+@pytest.mark.asyncio
+async def test_register_model_existing_different(
+    helper: ModelRegistryHelper, known_model: Model, known_model2: Model
+) -> None:
+    known_model.provider_resource_id = known_model2.provider_resource_id
+    with pytest.raises(ValueError):
+        await helper.register_model(known_model)
+
+
+@pytest.mark.asyncio
+async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
+    await helper.register_model(known_model)  # duplicate entry
+    assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
+    await helper.unregister_model(known_model.model_id)
+    assert helper.get_provider_model_id(known_model.model_id) is None
+
+
+@pytest.mark.asyncio
+async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
+    with pytest.raises(ValueError):
+        await helper.unregister_model(unknown_model.model_id)
+
+
+@pytest.mark.asyncio
+async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
+    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
+
+
+@pytest.mark.asyncio
+async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
+    assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
+    await helper.unregister_model(known_model.provider_resource_id)
+    assert helper.get_provider_model_id(known_model.provider_resource_id) is None