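Update the inference router to forward text_truncation, output_dimension, and task_type to the provider's embeddings call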

This commit is contained in:
Ashwin Bharambe 2025-02-20 22:24:59 -08:00
parent 2c1e8b5956
commit e9fd8371a8

View file

@ -6,7 +6,11 @@
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import ( from llama_stack.apis.eval import (
BenchmarkConfig, BenchmarkConfig,
@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
) )
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType,
Inference, Inference,
LogProbConfig, LogProbConfig,
Message, Message,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
TextTruncation,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
ToolDefinition, ToolDefinition,
@ -215,6 +221,9 @@ class InferenceRouter(Inference):
self, self,
model_id: str, model_id: str,
contents: List[str] | List[InterleavedContentItem], contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id) model = await self.routing_table.get_model(model_id)
if model is None: if model is None:
@ -224,6 +233,9 @@ class InferenceRouter(Inference):
return await self.routing_table.get_provider_impl(model_id).embeddings( return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id, model_id=model_id,
contents=contents, contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
) )