mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-15 14:08:00 +00:00)
refactor: standardize InferenceRouter model handling
* introduces ModelTypeError custom exception class
* introduces _get_model private method in InferenceRouter class
* standardizes inconsistent variable name usage for models in InferenceRouter class
* removes unneeded model type check in ollama provider

Signed-off-by: Nathan Weinberg <nweinber@redhat.com>
This commit is contained in:
parent 803114180b
commit ff8942bc71
4 changed files with 28 additions and 38 deletions
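Before the hunks below, a quick self-contained sketch of the pattern this commit consolidates: every router entry point used to fetch the model and run its own None/type checks, and after this change they delegate to a single _get_model helper that raises ModelNotFoundError or the new ModelTypeError. The helper and exception names mirror the diff; the dataclass, registry, and asyncio scaffolding here are hypothetical and not part of the commit.

    # Self-contained sketch of the consolidated lookup-and-validate pattern.
    # The helper and exception names mirror the diff; everything else here
    # (FakeRoutingTable, Model dataclass, plain string model types) is made up.
    import asyncio
    from dataclasses import dataclass


    class ModelNotFoundError(ValueError):
        pass


    class ModelTypeError(TypeError):
        pass


    @dataclass
    class Model:
        identifier: str
        model_type: str  # "llm" or "embedding"


    class FakeRoutingTable:
        def __init__(self, models: dict[str, Model]) -> None:
            self._models = models

        async def get_model(self, model_id: str) -> Model | None:
            return self._models.get(model_id)


    class Router:
        def __init__(self, routing_table: FakeRoutingTable) -> None:
            self.routing_table = routing_table

        async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
            # Fetch once, then validate existence and type in one place.
            model = await self.routing_table.get_model(model_id)
            if model is None:
                raise ModelNotFoundError(model_id)
            if model.model_type != expected_model_type:
                raise ModelTypeError(model_id, model.model_type, expected_model_type)
            return model

        async def chat_completion(self, model_id: str) -> str:
            # One line replaces the previous fetch + None check + type check.
            model = await self._get_model(model_id, "llm")
            return f"completion from {model.identifier}"


    async def main() -> None:
        router = Router(FakeRoutingTable({"my-llm": Model("my-llm", "llm")}))
        print(await router.chat_completion("my-llm"))


    asyncio.run(main())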
@@ -62,3 +62,13 @@ class SessionNotFoundError(ValueError):
     def __init__(self, session_name: str) -> None:
         message = f"Session '{session_name}' not found or access denied."
         super().__init__(message)
+
+
+class ModelTypeError(TypeError):
+    """raised when a model is present but not the correct type"""
+
+    def __init__(self, model_name: str, model_type: str, expected_model_type: str) -> None:
+        message = (
+            f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
+        )
+        super().__init__(message)
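As a quick illustration (not part of the commit, and with a hypothetical model name), this is the message the new exception carries, built from the f-string in the hunk above:

    # Illustrative only: what the new exception's message looks like when raised.
    from llama_stack.apis.common.errors import ModelTypeError

    try:
        raise ModelTypeError("all-MiniLM-L6-v2", "embedding", "llm")
    except ModelTypeError as err:
        # Prints: Model 'all-MiniLM-L6-v2' is of type 'embedding' rather than the expected type 'llm'
        print(err)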
@@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import ModelNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -177,6 +177,15 @@ class InferenceRouter(Inference):
         encoded = self.formatter.encode_content(messages)
         return len(encoded.tokens) if encoded and encoded.tokens else 0
 
+    async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
+        """takes a model id and gets model after ensuring that it is accessible and of the correct type"""
+        model = await self.routing_table.get_model(model_id)
+        if model is None:
+            raise ModelNotFoundError(model_id)
+        if model.model_type != expected_model_type:
+            raise ModelTypeError(model_id, model.model_type, expected_model_type)
+        return model
+
     async def chat_completion(
         self,
         model_id: str,
@@ -195,11 +204,7 @@ class InferenceRouter(Inference):
         )
         if sampling_params is None:
             sampling_params = SamplingParams()
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
@@ -301,11 +306,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
         )
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+        model = await self._get_model(model_id, ModelType.llm)
         provider = await self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -355,11 +356,7 @@ class InferenceRouter(Inference):
         task_type: EmbeddingTaskType | None = None,
     ) -> EmbeddingsResponse:
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
-        model = await self.routing_table.get_model(model_id)
-        if model is None:
-            raise ModelNotFoundError(model_id)
-        if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+        await self._get_model(model_id, ModelType.embedding)
         provider = await self.routing_table.get_provider_impl(model_id)
         return await provider.embeddings(
             model_id=model_id,
@@ -395,12 +392,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
-
+        model_obj = await self._get_model(model, ModelType.llm)
         params = dict(
             model=model_obj.identifier,
             prompt=prompt,
@@ -476,11 +468,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
+        model_obj = await self._get_model(model, ModelType.llm)
 
         # Use the OpenAI client for a bit of extra input validation without
         # exposing the OpenAI client itself as part of our API surface
@@ -567,12 +555,7 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
         )
-        model_obj = await self.routing_table.get_model(model)
-        if model_obj is None:
-            raise ModelNotFoundError(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model '{model}' is not an embedding model")
-
+        model_obj = await self._get_model(model, ModelType.embedding)
         params = dict(
             model=model_obj.identifier,
             input=input,
@@ -8,7 +8,7 @@ from typing import Any
 
 from pydantic import TypeAdapter
 
-from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
+from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         if model is None:
             raise ModelNotFoundError(embedding_model)
         if model.model_type != ModelType.embedding:
-            raise ValueError(f"Model {embedding_model} is not an embedding model")
+            raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
         if "embedding_dimension" not in model.metadata:
             raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
         vector_db_data = {
@@ -457,9 +457,6 @@ class OllamaInferenceAdapter(
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
         model_obj = await self._get_model(model)
-        if model_obj.model_type != ModelType.embedding:
-            raise ValueError(f"Model {model} is not an embedding model")
-
         if model_obj.provider_resource_id is None:
             raise ValueError(f"Model {model} has no provider_resource_id set")
 