chore: standardize model not found error (#2964)

# What does this PR do?
1. Creates a new `ModelNotFoundError` class
2. Implements the new class where appropriate 

Relates to #2379

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
Nathan Weinberg 2025-07-30 15:19:53 -04:00 committed by GitHub
parent 266e2afb9c
commit c5622c79de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 23 additions and 10 deletions

View file

@ -11,3 +11,11 @@ class UnsupportedModelError(ValueError):
def __init__(self, model_name: str, supported_models_list: list[str]):
message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
super().__init__(message)
class ModelNotFoundError(ValueError):
"""raised when Llama Stack cannot find a referenced model"""
def __init__(self, model_name: str) -> None:
message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
super().__init__(message)

View file

@ -17,6 +17,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.inference import (
BatchChatCompletionResponse,
BatchCompletionResponse,
@ -188,7 +189,7 @@ class InferenceRouter(Inference):
sampling_params = SamplingParams()
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
if tool_config:
@ -317,7 +318,7 @@ class InferenceRouter(Inference):
)
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
provider = await self.routing_table.get_provider_impl(model_id)
@ -390,7 +391,7 @@ class InferenceRouter(Inference):
logger.debug(f"InferenceRouter.embeddings: {model_id}")
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
raise ModelNotFoundError(model_id)
if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
provider = await self.routing_table.get_provider_impl(model_id)
@ -430,7 +431,7 @@ class InferenceRouter(Inference):
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
@ -491,7 +492,7 @@ class InferenceRouter(Inference):
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
raise ModelNotFoundError(model)
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
@ -562,7 +563,7 @@ class InferenceRouter(Inference):
)
model_obj = await self.routing_table.get_model(model)
if model_obj is None:
raise ValueError(f"Model '{model}' not found")
raise ModelNotFoundError(model)
if model_obj.model_type != ModelType.embedding:
raise ValueError(f"Model '{model}' is not an embedding model")

View file

@ -6,6 +6,7 @@
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import Model
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
@ -257,7 +258,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
models = await routing_table.get_all_with_type("model")
matching_models = [m for m in models if m.provider_resource_id == model_id]
if len(matching_models) == 0:
raise ValueError(f"Model '{model_id}' not found")
raise ModelNotFoundError(model_id)
if len(matching_models) > 1:
raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")

View file

@ -7,6 +7,7 @@
import time
from typing import Any
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithOwner,
@ -111,7 +112,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
raise ModelNotFoundError(model_id)
await self.unregister_object(existing_model)
async def update_registered_models(

View file

@ -8,6 +8,7 @@ from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@ -63,7 +64,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await lookup_model(self, embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
raise ModelNotFoundError(embedding_model)
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:

View file

@ -15,6 +15,7 @@ from pathlib import Path
import fire
from llama_stack.apis.common.errors import ModelNotFoundError
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.sku_list import resolve_model
@ -34,7 +35,7 @@ def run_main(
llama_model = resolve_model(model_id)
if not llama_model:
raise ValueError(f"Model {model_id} not found")
raise ModelNotFoundError(model_id)
cls = Llama4 if llama4 else Llama3
generator = cls.build(