chore: standardize model not found error (#2964)

# What does this PR do?
1. Creates a new `ModelNotFoundError` class
2. Raises the new error where appropriate

Relates to #2379
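
For reference, here is a minimal, self-contained sketch of the new error and the behavior it standardizes. The class body mirrors the one added in this commit; the `try`/`except` demo and the model name are illustrative only.

```python
class ModelNotFoundError(ValueError):
    """raised when Llama Stack cannot find a referenced model"""

    def __init__(self, model_name: str) -> None:
        message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
        super().__init__(message)


# Because ModelNotFoundError subclasses ValueError, call sites that previously did
# `raise ValueError(f"Model '{model_id}' not found")` can raise the new error
# without breaking existing `except ValueError` handlers.
try:
    raise ModelNotFoundError("my-model-id")  # illustrative model name
except ValueError as err:
    print(err)  # Model 'my-model-id' not found. Use client.models.list() to list available models.
```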

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
Author: Nathan Weinberg <nweinber@redhat.com> (committed by GitHub)
Date: 2025-07-30 15:19:53 -04:00
Commit: c5622c79de (parent: 266e2afb9c)
6 changed files with 23 additions and 10 deletions


@@ -11,3 +11,11 @@ class UnsupportedModelError(ValueError):
     def __init__(self, model_name: str, supported_models_list: list[str]):
         message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
         super().__init__(message)
+
+
+class ModelNotFoundError(ValueError):
+    """raised when Llama Stack cannot find a referenced model"""
+
+    def __init__(self, model_name: str) -> None:
+        message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
+        super().__init__(message)

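As a usage note (hypothetical, not part of this change set): downstream code that wants to treat a missing model specially can catch `ModelNotFoundError` directly, while broad `except ValueError` handlers keep working. A pytest-style sanity check of that contract might look like:

```python
# Hypothetical test, not included in this commit.
import pytest

from llama_stack.apis.common.errors import ModelNotFoundError


def test_model_not_found_error_is_a_value_error():
    with pytest.raises(ValueError, match="Use client.models.list"):
        raise ModelNotFoundError("some-model-id")


def test_model_not_found_error_can_be_caught_specifically():
    with pytest.raises(ModelNotFoundError):
        raise ModelNotFoundError("some-model-id")
```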

@@ -17,6 +17,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -188,7 +189,7 @@ class InferenceRouter(Inference):
             sampling_params = SamplingParams()
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
@@ -317,7 +318,7 @@ class InferenceRouter(Inference):
         )
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = await self.routing_table.get_provider_impl(model_id)
@@ -390,7 +391,7 @@ class InferenceRouter(Inference):
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.llm:
             raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         provider = await self.routing_table.get_provider_impl(model_id)
@@ -430,7 +431,7 @@ class InferenceRouter(Inference):
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
@@ -491,7 +492,7 @@ class InferenceRouter(Inference):
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
@@ -562,7 +563,7 @@ class InferenceRouter(Inference):
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type != ModelType.embedding:
             raise ValueError(f"Model '{model}' is not an embedding model")


@@ -6,6 +6,7 @@
 from typing import Any
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import Model
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import ScoringFn
@@ -257,7 +258,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
     models = await routing_table.get_all_with_type("model")
     matching_models = [m for m in models if m.provider_resource_id == model_id]
     if len(matching_models) == 0:
-        raise ValueError(f"Model '{model_id}' not found")
+        raise ModelNotFoundError(model_id)
     if len(matching_models) > 1:
         raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")


@@ -7,6 +7,7 @@
 import time
 from typing import Any
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
@@ -111,7 +112,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
     async def unregister_model(self, model_id: str) -> None:
         existing_model = await self.get_model(model_id)
         if existing_model is None:
-            raise ValueError(f"Model {model_id} not found")
+            raise ModelNotFoundError(model_id)
         await self.unregister_object(existing_model)
 
     async def update_registered_models(


@@ -8,6 +8,7 @@ from typing import Any
 from pydantic import TypeAdapter
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@@ -63,7 +64,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await lookup_model(self, embedding_model)
         if model is None:
-            raise ValueError(f"Model {embedding_model} not found")
+            raise ModelNotFoundError(embedding_model)
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:


@@ -15,6 +15,7 @@ from pathlib import Path
 import fire
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
@@ -34,7 +35,7 @@ def run_main(
     llama_model = resolve_model(model_id)
     if not llama_model:
-        raise ValueError(f"Model {model_id} not found")
+        raise ModelNotFoundError(model_id)
     cls = Llama4 if llama4 else Llama3
     generator = cls.build(