From e9fd8371a8cfb7fd1879a745fa590abe55e49f71 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 22:24:59 -0800 Subject: [PATCH] Update the router --- llama_stack/distribution/routers/routers.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index d885ebc09..df4ed03d3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,11 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( BenchmarkConfig, @@ -17,11 +21,13 @@ from llama_stack.apis.eval import ( ) from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -215,6 +221,9 @@ class InferenceRouter(Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.routing_table.get_model(model_id) if model is None: @@ -224,6 +233,9 @@ class InferenceRouter(Inference): return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, + text_truncation=text_truncation, + output_dimension=output_dimension, + task_type=task_type, )