Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 04:50:39 +00:00
chore: standardize model not found error (#2964)
# What does this PR do?
1. Creates a new `ModelNotFoundError` class
2. Implements the new class where appropriate

Relates to #2379

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent 266e2afb9c
commit c5622c79de

6 changed files with 23 additions and 10 deletions
@@ -11,3 +11,11 @@ class UnsupportedModelError(ValueError):
     def __init__(self, model_name: str, supported_models_list: list[str]):
         message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
         super().__init__(message)
+
+
+class ModelNotFoundError(ValueError):
+    """raised when Llama Stack cannot find a referenced model"""
+
+    def __init__(self, model_name: str) -> None:
+        message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
+        super().__init__(message)
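As a rough illustration of what callers gain from the new class, the sketch below reuses the `ModelNotFoundError` definition from the hunk above; the in-memory `registry`, the `lookup` helper, and the model id are hypothetical stand-ins for the routing-table lookups changed in this PR, not llama-stack code.

```python
class ModelNotFoundError(ValueError):
    """raised when Llama Stack cannot find a referenced model"""

    def __init__(self, model_name: str) -> None:
        message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
        super().__init__(message)


def lookup(registry: dict[str, str], model_id: str) -> str:
    # Hypothetical lookup mirroring the pattern used in the routers below:
    # return the model if registered, otherwise raise the new error.
    model = registry.get(model_id)
    if model is None:
        raise ModelNotFoundError(model_id)
    return model


try:
    lookup({}, "meta-llama/Llama-3.2-3B-Instruct")
except ValueError as err:
    # ModelNotFoundError subclasses ValueError, so existing ValueError
    # handlers keep working after the change.
    print(err)
```

Because the class derives from `ValueError`, the swaps in the hunks below stay backward compatible with callers that already catch `ValueError`.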
@@ -17,6 +17,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -188,7 +189,7 @@ class InferenceRouter(Inference):
             sampling_params = SamplingParams()
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
@@ -317,7 +318,7 @@ class InferenceRouter(Inference):
         )
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = await self.routing_table.get_provider_impl(model_id)
@@ -390,7 +391,7 @@ class InferenceRouter(Inference):
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.llm:
             raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         provider = await self.routing_table.get_provider_impl(model_id)
@@ -430,7 +431,7 @@ class InferenceRouter(Inference):
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
 
@@ -491,7 +492,7 @@ class InferenceRouter(Inference):
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
 
@@ -562,7 +563,7 @@ class InferenceRouter(Inference):
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type != ModelType.embedding:
             raise ValueError(f"Model '{model}' is not an embedding model")
 
@@ -6,6 +6,7 @@
 
 from typing import Any
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import Model
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import ScoringFn
@@ -257,7 +258,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
     models = await routing_table.get_all_with_type("model")
     matching_models = [m for m in models if m.provider_resource_id == model_id]
     if len(matching_models) == 0:
-        raise ValueError(f"Model '{model_id}' not found")
+        raise ModelNotFoundError(model_id)
 
     if len(matching_models) > 1:
         raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
@@ -7,6 +7,7 @@
 import time
 from typing import Any
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
@@ -111,7 +112,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
     async def unregister_model(self, model_id: str) -> None:
         existing_model = await self.get_model(model_id)
         if existing_model is None:
-            raise ValueError(f"Model {model_id} not found")
+            raise ModelNotFoundError(model_id)
         await self.unregister_object(existing_model)
 
     async def update_registered_models(
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import TypeAdapter
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@@ -63,7 +64,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await lookup_model(self, embedding_model)
         if model is None:
-            raise ValueError(f"Model {embedding_model} not found")
+            raise ModelNotFoundError(embedding_model)
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
@@ -15,6 +15,7 @@ from pathlib import Path
 
 import fire
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
@@ -34,7 +35,7 @@ def run_main(
 
     llama_model = resolve_model(model_id)
     if not llama_model:
-        raise ValueError(f"Model {model_id} not found")
+        raise ModelNotFoundError(model_id)
 
     cls = Llama4 if llama4 else Llama3
     generator = cls.build(