fix: allow use of models registered at runtime (#1980)

# What does this PR do?

Fix a bug where models registered at runtime could not be used for inference.

```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct

$ curl http://localhost:8321/v1/openai/v1/chat/completions \                                                        
-H "Content-Type: application/json" \
-d '{
  "model": "test-model",
  "messages": [{"role": "user", "content": "What is the weather like in Boston today?"}]
}'

=(client)=> {"detail":"Internal server error: An unexpected error occurred."}
=(server)=> TypeError: Missing required arguments; Expected either ('messages' and 'model') or ('messages', 'model' and 'stream') arguments to be given
```

*root cause:* `test-model` is never added to `ModelRegistryHelper`'s
`alias_to_provider_id_map`, so the runtime-registered alias cannot be resolved
to a provider model id when an inference request arrives.
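
in simplified terms (a minimal sketch of the lookup, not the actual llama-stack internals):

```python
# Minimal sketch of the failing lookup; the real map lives on ModelRegistryHelper.
alias_to_provider_id_map = {
    # statically known models are pre-populated when the provider starts...
    "meta/llama-3.1-70b-instruct": "meta/llama-3.1-70b-instruct",
}

# ...but before this fix, register_model() never added the runtime alias, so
# resolving the user's model id came up empty and inference failed downstream:
assert alias_to_provider_id_map.get("test-model") is None
```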

as part of the fix, this adds tests for `ModelRegistryHelper` and defines
its expected behavior.

user-visible behavior changes -

| action | existing behavior | new behavior |
| -- | -- | -- |
| double register | success (but no change) | error |
| register unknown | success (fail when used) | error |

existing behavior for registering an unknown model and double registering -
```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct-unknown
Successfully registered model test-model

$ llama-stack-client models list | grep test-model
│ llm │ test-model                               │ meta/llama-3.1-70b-instruct-unknown │     │ nv… │

$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct       
Successfully registered model test-model

$ llama-stack-client models list | grep test-model
│ llm │ test-model                               │ meta/llama-3.1-70b-instruct-unknown │     │ nv… │
```

new behavior for register unknown -
```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct-unknown
╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Failed to register model                                                                         │
│                                                                                                  │
│ Error Type: BadRequestError                                                                      │
│ Details: Error code: 400 - {'detail': "Invalid value: Model id                                   │
│ 'meta/llama-3.1-70b-instruct-unknown' is not supported. Supported ids are:                       │
│ meta/llama-3.1-70b-instruct, snowflake/arctic-embed-l, meta/llama-3.2-1b-instruct,               │
│ nvidia/nv-embedqa-mistral-7b-v2, meta/llama-3.2-90b-vision-instruct, meta/llama-3.2-3b-instruct, │
│ meta/llama-3.2-11b-vision-instruct, meta/llama-3.1-405b-instruct, meta/llama3-8b-instruct,       │
│ meta/llama3-70b-instruct, nvidia/llama-3.2-nv-embedqa-1b-v2, meta/llama-3.1-8b-instruct,         │
│ nvidia/nv-embedqa-e5-v5"}                                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
```

new behavior for double register -
```
$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.1-70b-instruct
Successfully registered model test-model

$ llama-stack-client models register test-model --provider-id nvidia --provider-model-id meta/llama-3.2-1b-instruct 
╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│ Failed to register model                                                                         │
│                                                                                                  │
│ Error Type: BadRequestError                                                                      │
│ Details: Error code: 400 - {'detail': "Invalid value: Model id 'test-model' is already           │
│ registered. Please use a different id or unregister it first."}                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
```


## Test Plan

```
uv run pytest -v tests/unit/providers/utils/test_model_registry.py
```
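
the tests pin down the contract along these lines (illustrative sketch only - the fixture names and model construction here are placeholders, not the actual contents of the test file):

```python
# Sketch of the expected ModelRegistryHelper contract; fixtures are hypothetical.
# (assumes pytest-asyncio for the async test functions)
import pytest

@pytest.mark.asyncio
async def test_register_unknown_provider_model_id_fails(helper, unknown_model):
    with pytest.raises(ValueError):  # register unknown -> error
        await helper.register_model(unknown_model)

@pytest.mark.asyncio
async def test_double_register_conflicting_id_fails(helper, model, conflicting_model):
    await helper.register_model(model)  # first registration succeeds
    await helper.register_model(model)  # identical re-registration is idempotent
    with pytest.raises(ValueError):  # same model_id, different provider model id -> error
        await helper.register_model(conflicting_model)
```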

## Changes

`ModelsProtocolPrivate` gains a docstring spelling out the registration contract:

```diff
@@ -22,6 +22,27 @@ from llama_stack.schema_utils import json_schema_type
 
 
 class ModelsProtocolPrivate(Protocol):
+    """
+    Protocol for model management.
+    This allows users to register their preferred model identifiers.
+
+    Model registration requires -
+     - a provider, used to route the registration request
+     - a model identifier, user's intended name for the model during inference
+     - a provider model identifier, a model identifier supported by the provider
+
+    Providers will only accept registration for provider model ids they support.
+
+    Example,
+      register: provider x my-model-id x provider-model-id
+      -> Error if provider does not support provider-model-id
+      -> Error if my-model-id is already registered
+      -> Success if provider supports provider-model-id
+
+    inference: my-model-id x ...
+      -> Provider uses provider-model-id for inference
+    """
+
     async def register_model(self, model: Model) -> Model: ...
 
     async def unregister_model(self, model_id: str) -> None: ...
```
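
read as code, a conforming provider behaves roughly like this (a sketch against the contract above, using a simplified in-memory mapping rather than any real adapter):

```python
from llama_stack.apis.models import Model

class SketchProvider:
    """Illustrative ModelsProtocolPrivate implementer; not a real adapter."""

    def __init__(self, supported_provider_model_ids: set[str]) -> None:
        self.supported = supported_provider_model_ids
        self.aliases: dict[str, str] = {}  # my-model-id -> provider-model-id

    async def register_model(self, model: Model) -> Model:
        if model.provider_resource_id not in self.supported:
            raise ValueError(f"Model '{model.provider_resource_id}' is not supported")
        existing = self.aliases.get(model.model_id)
        if existing is not None and existing != model.provider_resource_id:
            raise ValueError(f"Model id '{model.model_id}' is already registered")
        self.aliases[model.model_id] = model.provider_resource_id  # success path
        return model

    async def unregister_model(self, model_id: str) -> None:
        if model_id not in self.aliases:
            raise ValueError(f"Model id '{model_id}' is not registered.")
        del self.aliases[model_id]
```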

in the Ollama adapter, a model the static registry doesn't know no longer fails registration outright - the adapter falls back to Ollama's live model listing:

```diff
@@ -8,7 +8,7 @@
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 import httpx
-from ollama import AsyncClient
+from ollama import AsyncClient  # type: ignore[attr-defined]
 from openai import AsyncOpenAI
 
 from llama_stack.apis.common.content_types import (
@@ -333,7 +333,10 @@ class OllamaInferenceAdapter(
         return EmbeddingsResponse(embeddings=embeddings)
 
     async def register_model(self, model: Model) -> Model:
-        model = await self.register_helper.register_model(model)
+        try:
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            pass  # Ignore statically unknown model, will check live listing
         if model.model_type == ModelType.embedding:
             logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...")
             await self.client.pull(model.provider_resource_id)
@@ -342,9 +345,12 @@ class OllamaInferenceAdapter(
         # - models not currently running are run by the ollama server as needed
         response = await self.client.list()
         available_models = [m["model"] for m in response["models"]]
-        if model.provider_resource_id not in available_models:
+        provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
+        if provider_resource_id is None:
+            provider_resource_id = model.provider_resource_id
+        if provider_resource_id not in available_models:
             available_models_latest = [m["model"].split(":latest")[0] for m in response["models"]]
-            if model.provider_resource_id in available_models_latest:
+            if provider_resource_id in available_models_latest:
                 logger.warning(
                     f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
                 )
@@ -352,6 +358,7 @@ class OllamaInferenceAdapter(
             raise ValueError(
                 f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}"
             )
+        model.provider_resource_id = provider_resource_id
         return model
```
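
with that in place, resolution falls back gracefully (a stubbed sketch with hypothetical model ids; `StubRegistryHelper` stands in for `ModelRegistryHelper`):

```python
from typing import Optional

class StubRegistryHelper:
    """Stand-in for ModelRegistryHelper's alias map (illustrative only)."""

    def __init__(self) -> None:
        self.alias_to_provider_id_map = {"llama3.1:70b": "llama3.1:70b"}

    def get_provider_model_id(self, identifier: str) -> Optional[str]:
        return self.alias_to_provider_id_map.get(identifier)

helper = StubRegistryHelper()

# known id: resolved through the registry as before
assert helper.get_provider_model_id("llama3.1:70b") == "llama3.1:70b"

# statically unknown id: falls back to the raw value, which is then vetted
# against the live `ollama list` output instead of being rejected up front
resolved = helper.get_provider_model_id("my-local-model") or "my-local-model"
assert resolved == "my-local-model"
```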

the vLLM adapter gets the same treatment:

```diff
@@ -372,7 +372,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
         # Changing this may lead to unpredictable behavior.
         client = self._create_client() if self.client is None else self.client
-        model = await self.register_helper.register_model(model)
+        try:
+            model = await self.register_helper.register_model(model)
+        except ValueError:
+            pass  # Ignore statically unknown model, will check live listing
         res = await client.models.list()
         available_models = [m.id async for m in res]
         if model.provider_resource_id not in available_models:
```

finally, `ModelRegistryHelper.register_model` now records every registration in `alias_to_provider_id_map` (the root-cause fix), rejects unknown and conflicting registrations, and `unregister_model` actually removes the alias:

```diff
@@ -75,40 +75,50 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def get_provider_model_id(self, identifier: str) -> Optional[str]:
         return self.alias_to_provider_id_map.get(identifier, None)
 
+    # TODO: why keep a separate llama model mapping?
     def get_llama_model(self, provider_model_id: str) -> Optional[str]:
         return self.provider_id_to_llama_model_map.get(provider_model_id, None)
 
     async def register_model(self, model: Model) -> Model:
+        if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
+            raise ValueError(
+                f"Model '{model.provider_resource_id}' is not supported. Supported models are: {', '.join(self.alias_to_provider_id_map.keys())}"
+            )
+        provider_resource_id = self.get_provider_model_id(model.model_id)
         if model.model_type == ModelType.embedding:
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
             provider_resource_id = model.provider_resource_id
-        else:
-            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
         if provider_resource_id:
-            model.provider_resource_id = provider_resource_id
+            if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
+                raise ValueError(
+                    f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
+                )
         else:
             llama_model = model.metadata.get("llama_model")
-            if llama_model is None:
-                return model
-            existing_llama_model = self.get_llama_model(model.provider_resource_id)
-            if existing_llama_model:
-                if existing_llama_model != llama_model:
-                    raise ValueError(
-                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
-                    )
-            else:
-                if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
-                    raise ValueError(
-                        f"Invalid llama_model '{llama_model}' specified in metadata. "
-                        f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
-                    )
-                self.provider_id_to_llama_model_map[model.provider_resource_id] = (
-                    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
-                )
+            if llama_model:
+                existing_llama_model = self.get_llama_model(model.provider_resource_id)
+                if existing_llama_model:
+                    if existing_llama_model != llama_model:
+                        raise ValueError(
+                            f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                        )
+                else:
+                    if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                        raise ValueError(
+                            f"Invalid llama_model '{llama_model}' specified in metadata. "
+                            f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
+                        )
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+        self.alias_to_provider_id_map[model.model_id] = supported_model_id
         return model
 
     async def unregister_model(self, model_id: str) -> None:
-        pass
+        # TODO: should we block unregistering base supported provider model IDs?
+        if model_id not in self.alias_to_provider_id_map:
+            raise ValueError(f"Model id '{model_id}' is not registered.")
+        del self.alias_to_provider_id_map[model_id]
```
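
net effect on the helper (a hypothetical walkthrough - the model objects are elided since constructing them is provider-specific):

```python
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

async def walkthrough(helper: ModelRegistryHelper, model, conflicting, unknown) -> None:
    await helper.register_model(model)       # "test-model" -> known provider id: success
    await helper.register_model(model)       # identical re-registration: idempotent success

    try:
        await helper.register_model(conflicting)  # same model_id, different provider id
    except ValueError:
        pass  # new behavior: rejected

    try:
        await helper.register_model(unknown)      # provider model id not in the registry
    except ValueError:
        pass  # new behavior: rejected up front instead of failing at inference time

    await helper.unregister_model("test-model")   # alias is actually removed now
```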