llama-stack-mirror/llama_stack/providers/utils/inference/model_registry.py
Rohan Awhad 7cb5d3c60f
chore: standardize unsupported model error #2517 (#2518)
# What does this PR do?

- llama_stack/exceptions.py: Add `UnsupportedModelError` class
- remote inference ollama.py and utils/inference/model_registry.py:
Replace `ValueError` with `UnsupportedModelError`
- utils/inference/litellm_openai_mixin.py: Remove the `register_model`
implementation from `LiteLLMOpenAIMixin`; it now uses the parent class
`ModelRegistryHelper`'s implementation

Closes #2517
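
For reference, `register_model` constructs the new error with the unsupported
id and the collection of supported ids (see the `raise` site in
`model_registry.py` below). A minimal sketch of what such an error class might
look like, with an illustrative message format:

```python
from collections.abc import Iterable


class UnsupportedModelError(ValueError):
    """Raised when a requested model is not in a provider's supported list."""

    def __init__(self, model_name: str, supported_models_list: Iterable[str]):
        # Illustrative message; the real class may word this differently.
        message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
        super().__init__(message)
```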


## Test Plan


1. Create a new `test_run_openai.yaml` and paste the following config in
it:

```yaml
version: '2'
image_name: test-image
apis:
- inference
providers:
  inference:
  - provider_id: openai
    provider_type: remote::openai
    config:
      max_tokens: 8192
models:
- metadata: {}
  model_id: "non-existent-model"
  provider_id: openai
  model_type: llm
server:
  port: 8321
```

2. Run the server with:
```bash
uv run llama stack run test_run_openai.yaml
```

You should now get a `llama_stack.exceptions.UnsupportedModelError` with
the list of supported models in the error message.

---

Tested with the following remote inference providers; they all raise
`UnsupportedModelError`:
- Anthropic
- Cerebras
- Fireworks
- Gemini
- Groq
- Ollama
- OpenAI
- SambaNova
- Together
- Watsonx

---------

Co-authored-by: Rohan Awhad <rawhad@redhat.com>
2025-06-27 14:26:58 -04:00


# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType
from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import (
    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)


# TODO: this class is more confusing than useful right now. We need to make it
# closer to the Model class.
class ProviderModelEntry(BaseModel):
    provider_model_id: str
    aliases: list[str] = Field(default_factory=list)
    llama_model: str | None = None
    model_type: ModelType = ModelType.llm
    metadata: dict[str, Any] = Field(default_factory=dict)
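
# Illustrative only (ids made up): an entry that exposes one Llama model under
# the provider's own id plus a short alias.
#
#   entry = ProviderModelEntry(
#       provider_model_id="acme/llama-3.1-8b-instruct",
#       aliases=["llama3.1-8b"],
#       llama_model="Llama3.1-8B-Instruct",
#   )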


def get_huggingface_repo(model_descriptor: str) -> str | None:
    for model in all_registered_models():
        if model.descriptor() == model_descriptor:
            return model.huggingface_repo
    return None


def build_hf_repo_model_entry(
    provider_model_id: str,
    model_descriptor: str,
    additional_aliases: list[str] | None = None,
) -> ProviderModelEntry:
    aliases = [
        get_huggingface_repo(model_descriptor),
    ]
    if additional_aliases:
        aliases.extend(additional_aliases)
    # get_huggingface_repo() can return None; drop it so the entry only
    # carries real aliases (aliases is typed list[str]).
    aliases = [alias for alias in aliases if alias is not None]
    return ProviderModelEntry(
        provider_model_id=provider_model_id,
        aliases=aliases,
        llama_model=model_descriptor,
    )


def build_model_entry(provider_model_id: str, model_descriptor: str) -> ProviderModelEntry:
    return ProviderModelEntry(
        provider_model_id=provider_model_id,
        aliases=[],
        llama_model=model_descriptor,
        model_type=ModelType.llm,
    )
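
# Illustrative only (provider-side ids made up): providers typically declare
# their catalog as a list of entries built with the helpers above.
#
#   MODEL_ENTRIES = [
#       build_hf_repo_model_entry("acme/llama-3.1-8b-instruct", "Llama3.1-8B-Instruct"),
#       build_model_entry("acme/llama-guard-3-8b", "Llama-Guard-3-8B"),
#   ]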


class ModelRegistryHelper(ModelsProtocolPrivate):
    def __init__(self, model_entries: list[ProviderModelEntry]):
        self.alias_to_provider_id_map = {}
        self.provider_id_to_llama_model_map = {}
        for entry in model_entries:
            for alias in entry.aliases:
                self.alias_to_provider_id_map[alias] = entry.provider_model_id
            # also add a mapping from provider model id to itself for easy lookup
            self.alias_to_provider_id_map[entry.provider_model_id] = entry.provider_model_id
            if entry.llama_model:
                self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
                self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
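
    # Illustrative only: for the MODEL_ENTRIES sketch above, construction
    # leaves the maps looking roughly like (ids made up):
    #
    #   alias_to_provider_id_map = {
    #       "Llama3.1-8B-Instruct": "acme/llama-3.1-8b-instruct",              # llama_model
    #       "meta-llama/Llama-3.1-8B-Instruct": "acme/llama-3.1-8b-instruct",  # HF repo alias
    #       "acme/llama-3.1-8b-instruct": "acme/llama-3.1-8b-instruct",        # self-mapping
    #       ...
    #   }
    #   provider_id_to_llama_model_map = {
    #       "acme/llama-3.1-8b-instruct": "Llama3.1-8B-Instruct",
    #       ...
    #   }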

    def get_provider_model_id(self, identifier: str) -> str | None:
        return self.alias_to_provider_id_map.get(identifier, None)

    # TODO: why keep a separate llama model mapping?
    def get_llama_model(self, provider_model_id: str) -> str | None:
        return self.provider_id_to_llama_model_map.get(provider_model_id, None)
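
    # Illustrative only: any registered alias resolves to the provider-side id,
    # e.g. for the sketches above, get_provider_model_id("Llama3.1-8B-Instruct")
    # and get_provider_model_id("meta-llama/Llama-3.1-8B-Instruct") would both
    # return "acme/llama-3.1-8b-instruct".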

    async def register_model(self, model: Model) -> Model:
        if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
        provider_resource_id = self.get_provider_model_id(model.model_id)
        if model.model_type == ModelType.embedding:
            # embedding models are always registered by their provider model id and do not need to be mapped to a llama model
            provider_resource_id = model.provider_resource_id
        if provider_resource_id:
            if provider_resource_id != supported_model_id:  # be idempotent, only reject differences
                raise ValueError(
                    f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
                )
        else:
            llama_model = model.metadata.get("llama_model")
            if llama_model:
                existing_llama_model = self.get_llama_model(model.provider_resource_id)
                if existing_llama_model:
                    if existing_llama_model != llama_model:
                        raise ValueError(
                            f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
                        )
                else:
                    if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
                        raise ValueError(
                            f"Invalid llama_model '{llama_model}' specified in metadata. "
                            f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}"
                        )
                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
                    )
            self.alias_to_provider_id_map[model.model_id] = supported_model_id
        return model
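
    # Illustrative only: registering an id the provider does not know raises
    # the new UnsupportedModelError (e.g. the "non-existent-model" entry from
    # the test plan above); re-registering an existing, unchanged mapping is a
    # no-op, while remapping an existing model id raises ValueError.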

    async def unregister_model(self, model_id: str) -> None:
        # TODO: should we block unregistering base supported provider model IDs?
        if model_id not in self.alias_to_provider_id_map:
            raise ValueError(f"Model id '{model_id}' is not registered.")
        del self.alias_to_provider_id_map[model_id]