diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index d885ebc09..df4ed03d3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,11 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( BenchmarkConfig, @@ -17,11 +21,13 @@ from llama_stack.apis.eval import ( ) from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -215,6 +221,9 @@ class InferenceRouter(Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.routing_table.get_model(model_id) if model is None: @@ -224,6 +233,9 @@ class InferenceRouter(Inference): return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, + text_truncation=text_truncation, + output_dimension=output_dimension, + task_type=task_type, )