mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 14:08:00 +00:00
refactor: standardize InferenceRouter model handling (#2965)
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 15s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 19s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 21s
Python Package Build Test / build (3.13) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 23s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 29s
Test External API and Providers / test-external (venv) (push) Failing after 20s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 25s
Unit Tests / unit-tests (3.12) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 17s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 21s
Unit Tests / unit-tests (3.13) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 29s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 22s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 22s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 17s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 24s
Pre-commit / pre-commit (push) Successful in 1m19s
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 3s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.12) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 15s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 19s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 21s
Python Package Build Test / build (3.13) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 23s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 29s
Test External API and Providers / test-external (venv) (push) Failing after 20s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 25s
Unit Tests / unit-tests (3.12) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 17s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 21s
Unit Tests / unit-tests (3.13) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 29s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 22s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 22s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 17s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 24s
Pre-commit / pre-commit (push) Successful in 1m19s
This commit is contained in:
parent
803114180b
commit
19123ca957
4 changed files with 28 additions and 38 deletions
|
@ -18,7 +18,7 @@ from llama_stack.apis.common.content_types import (
|
|||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
BatchChatCompletionResponse,
|
||||
BatchCompletionResponse,
|
||||
|
@ -177,6 +177,15 @@ class InferenceRouter(Inference):
|
|||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def _get_model(self, model_id: str, expected_model_type: str) -> Model:
|
||||
"""takes a model id and gets model after ensuring that it is accessible and of the correct type"""
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type != expected_model_type:
|
||||
raise ModelTypeError(model_id, model.model_type, expected_model_type)
|
||||
return model
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
@ -195,11 +204,7 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
if tool_config:
|
||||
if tool_choice and tool_choice != tool_config.tool_choice:
|
||||
raise ValueError("tool_choice and tool_config.tool_choice must match")
|
||||
|
@ -301,11 +306,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.completion: {model_id=}, {stream=}, {content=}, {sampling_params=}, {response_format=}",
|
||||
)
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
|
||||
model = await self._get_model(model_id, ModelType.llm)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
@ -355,11 +356,7 @@ class InferenceRouter(Inference):
|
|||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ModelNotFoundError(model_id)
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
|
||||
await self._get_model(model_id, ModelType.embedding)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
|
@ -395,12 +392,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
|
||||
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
prompt=prompt,
|
||||
|
@ -476,11 +468,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type == ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
||||
model_obj = await self._get_model(model, ModelType.llm)
|
||||
|
||||
# Use the OpenAI client for a bit of extra input validation without
|
||||
# exposing the OpenAI client itself as part of our API surface
|
||||
|
@ -567,12 +555,7 @@ class InferenceRouter(Inference):
|
|||
logger.debug(
|
||||
f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}",
|
||||
)
|
||||
model_obj = await self.routing_table.get_model(model)
|
||||
if model_obj is None:
|
||||
raise ModelNotFoundError(model)
|
||||
if model_obj.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model '{model}' is not an embedding model")
|
||||
|
||||
model_obj = await self._get_model(model, ModelType.embedding)
|
||||
params = dict(
|
||||
model=model_obj.identifier,
|
||||
input=input,
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError, VectorStoreNotFoundError
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||
|
@ -66,7 +66,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if model is None:
|
||||
raise ModelNotFoundError(embedding_model)
|
||||
if model.model_type != ModelType.embedding:
|
||||
raise ValueError(f"Model {embedding_model} is not an embedding model")
|
||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||
vector_db_data = {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue