refactor: standardize InferenceRouter model handling

* introduces ModelTypeError custom exception class
* introduces _get_model private method in InferenceRouter class
* standardizes inconsistent variable name usage for models in InferenceRouter class
* removes unneeded model type check in ollama provider

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
Author: Nathan Weinberg <nweinber@redhat.com>, 2025-07-30 13:31:46 -04:00
parent 803114180b
commit ff8942bc71
4 changed files with 28 additions and 38 deletions
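
For orientation before the diff, here is a self-contained sketch of the lookup-and-validate pattern this commit consolidates. The registry, model ids, and simplified exception classes below are illustrative stand-ins only; the real helper is InferenceRouter._get_model in the diff that follows.

    # Illustrative stand-in for InferenceRouter._get_model; not the real class.
    import asyncio
    from dataclasses import dataclass


    class ModelNotFoundError(ValueError):   # simplified stand-in for the existing error
        pass


    class ModelTypeError(TypeError):        # simplified stand-in for the exception added below
        pass


    @dataclass
    class Model:
        identifier: str
        model_type: str  # "llm" or "embedding"


    # Hypothetical in-memory registry standing in for routing_table.get_model().
    REGISTRY = {
        "demo-llm": Model("demo-llm", "llm"),
        "demo-embedder": Model("demo-embedder", "embedding"),
    }


    async def get_model(model_id: str, expected_model_type: str) -> Model:
        """One lookup plus one type check, instead of repeating both in every endpoint."""
        model = REGISTRY.get(model_id)
        if model is None:
            raise ModelNotFoundError(model_id)
        if model.model_type != expected_model_type:
            raise ModelTypeError(model_id, model.model_type, expected_model_type)
        return model


    async def main() -> None:
        print(await get_model("demo-llm", "llm"))        # ok: correct type
        try:
            await get_model("demo-embedder", "llm")      # wrong type for a chat endpoint
        except ModelTypeError as exc:
            print("rejected:", exc.args)


    asyncio.run(main())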

Changes to the common errors module (llama_stack.apis.common.errors):

@@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError):
     def __init__(self, session_name: str) -> None:
         message = f"Session '{session_name}' not found or access denied."
         super().__init__(message)
+
+
+class ModelTypeError(TypeError):
+    """raised when a model is present but not the correct type"""
+
+    def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
+        message = (
+            f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
+        )
+        super().__init__(message)
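
For illustration only (not part of the diff), the new exception can be exercised directly; the model id below is made up, and this assumes a checkout that already includes this commit:

    # Hedged sketch: exercising the new exception directly.
    # "my-embedder" is a hypothetical model id used only for illustration.
    from llama_stack.apis.common.errors import ModelTypeError

    err = ModelTypeError("my-embedder", "embedding", "llm")
    print(isinstance(err, TypeError))   # True -- it subclasses TypeError, not ValueError
    print(err)
    # Model 'my-embedder' is of type 'embedding' rather than the expected type 'llm'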

Changes to the InferenceRouter:

@@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import ModelNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -177,6 +177,15 @@ class InferenceRouter(Inference):
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
+    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
+        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ModelNotFoundError(model_id)
+        if model.model_type != expected_model_type:
+            raise ModelTypeError(model_id, model.model_type, expected_model_type)
+        return model
+
     async def chat_completion(
         self,
         model_id: str,
@@ -195,11 +204,7 @@
             )
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
@@ -301,11 +306,7 @@
         logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -355,11 +356,7 @@
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+        await self._get_model(model_id, ModelType.embedding)
         provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.embeddings(
             model_id=model_id,
@@ -395,12 +392,7 @@
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
-
+        model_obj = await self._get_model(model, ModelType.llm)
         params = dict(
             model=model_obj.identifier,
             prompt=prompt,
@@ -476,11 +468,7 @@
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
+        model_obj = await self._get_model(model, ModelType.llm)
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -567,12 +555,7 @@
         logger.debug(
             f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model '{model}' is not an embedding model")
-
+        model_obj = await self._get_model(model, ModelType.embedding)
         params = dict(
             model=model_obj.identifier,
             input=input,
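
A caller-visible consequence of the router changes above, sketched outside the diff: a model-type mismatch now surfaces as ModelTypeError, a TypeError subclass, where these endpoints previously raised ValueError, so exception handlers written against the old behaviour may need updating. The handler and model id below are hypothetical illustrations, not a real API:

    # Hedged sketch of the caller-visible change; "describe_failure" is not a real API.
    from llama_stack.apis.common.errors import ModelTypeError

    def describe_failure(exc: Exception) -> str:
        if isinstance(exc, ModelTypeError):
            return "wrong model type"                   # new, precise handling
        if isinstance(exc, ValueError):
            return "legacy handler (no longer reached for type mismatches)"
        raise exc

    print(describe_failure(ModelTypeError("my-embedder", "embedding", "llm")))
    # wrong model type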

Changes to the VectorDBsRoutingTable:

@@ -8,7 +8,7 @@ from typing import Any
 
 from pydantic import TypeAdapter
 
-from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         if model is None:
             raise ModelNotFoundError(embedding_model)
         if model.model_type != ModelType.embedding:
-            raise ValueError(f"Model {embedding_model} is not an embedding model")
+            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
         if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {

Changes to the OllamaInferenceAdapter:

@@ -457,9 +457,6 @@
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
         model_obj = await self._get_model(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model {model} is not an embedding model")
-
         if model_obj.provider_resource_id is None:
             raise ValueError(f"Model {model} has no provider_resource_id set")