From 81ce39a607b8f1a286a91e3ee2c6df5079433eb0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 22:27:12 -0800 Subject: [PATCH] feat(api): Add options for supporting various embedding models (#1192) We need to support: - asymmetric embedding models (#934) - truncation policies (#933) - varying dimensional output (#932) ## Test Plan ```bash $ cd llama_stack/providers/tests/inference $ pytest -s -v -k fireworks test_embeddings.py \ --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k together test_embeddings.py \ --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k ollama test_embeddings.py \ --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784 ``` --- docs/_static/llama-stack-spec.html | 21 +++++++++++++ docs/_static/llama-stack-spec.yaml | 22 ++++++++++++++ llama_stack/apis/inference/inference.py | 30 +++++++++++++++++++ llama_stack/distribution/routers/routers.py | 14 ++++++++- .../providers/inline/inference/vllm/vllm.py | 11 ++++++- .../remote/inference/bedrock/bedrock.py | 10 ++++++- .../remote/inference/cerebras/cerebras.py | 10 ++++++- .../remote/inference/databricks/databricks.py | 10 ++++++- .../remote/inference/fireworks/fireworks.py | 10 ++++++- .../providers/remote/inference/groq/groq.py | 11 ++++++- .../remote/inference/nvidia/nvidia.py | 11 ++++++- .../remote/inference/ollama/ollama.py | 5 ++++ .../inference/passthrough/passthrough.py | 8 +++++ .../remote/inference/runpod/runpod.py | 3 ++ .../remote/inference/sambanova/sambanova.py | 5 ++++ .../providers/remote/inference/tgi/tgi.py | 10 ++++++- .../remote/inference/together/together.py | 10 ++++++- .../providers/remote/inference/vllm/vllm.py | 5 ++++ .../utils/inference/embedding_mixin.py | 7 ++++- 19 files changed, 202 insertions(+), 11 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 638f7bb7b..fab7c802e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4944,6 +4944,27 @@ } ], "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text." + }, + "text_truncation": { + "type": "string", + "enum": [ + "none", + "start", + "end" + ], + "description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length." + }, + "output_dimension": { + "type": "integer", + "description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models." + }, + "task_type": { + "type": "string", + "enum": [ + "query", + "document" + ], + "description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 08effe7cf..fc57bf258 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3235,6 +3235,28 @@ components: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text. 
+ text_truncation: + type: string + enum: + - none + - start + - end + description: >- + (Optional) Config for how to truncate text for embedding when text is + longer than the model's max sequence length. + output_dimension: + type: integer + description: >- + (Optional) Output dimensionality for the embeddings. Only supported by + Matryoshka models. + task_type: + type: string + enum: + - query + - document + description: >- + (Optional) How is the embedding being used? This is only supported by + asymmetric embedding models. additionalProperties: false required: - model_id diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 2dfe55977..d83506dd4 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -402,6 +402,30 @@ class ModelStore(Protocol): def get_model(self, identifier: str) -> Model: ... +class TextTruncation(Enum): + """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. + + :cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error. + :cvar start: Truncate from the start + :cvar end: Truncate from the end + """ + + none = "none" + start = "start" + end = "end" + + +class EmbeddingTaskType(Enum): + """How is the embedding being used? This is only supported by asymmetric embedding models. + + :cvar query: Used for a query for semantic search. + :cvar document: Used at indexing time when ingesting documents. + """ + + query = "query" + document = "document" + + @runtime_checkable @trace_protocol class Inference(Protocol): @@ -482,11 +506,17 @@ class Inference(Protocol): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: """Generate embeddings for content pieces using the specified model. :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text. + :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models. + :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length. + :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models. :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} """ ... 
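The hunk above is the core API change: `TextTruncation`, `EmbeddingTaskType`, and `output_dimension` now appear in the `Inference.embeddings` signature. For illustration, a client-side call exercising the new parameters might look like the sketch below (the base URL and model ID are placeholders, and whether the SDK takes the plain enum strings from the spec or the enum members is an assumption, not something this patch pins down):

```python
# Sketch only: assumes a running Llama Stack distribution and the llama-stack-client SDK.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # placeholder URL

response = client.inference.embeddings(
    model_id="nomic-ai/nomic-embed-text-v1.5",   # any registered embedding model
    contents=["What is the capital of France?"],
    text_truncation="end",    # "none" (default) errors on overlong input; "start"/"end" truncate
    output_dimension=256,     # only honored by Matryoshka models
    task_type="query",        # "query" vs "document" for asymmetric embedding models
)
print(len(response.embeddings), len(response.embeddings[0]))
```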
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index d885ebc09..df4ed03d3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,11 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( BenchmarkConfig, @@ -17,11 +21,13 @@ from llama_stack.apis.eval import ( ) from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -215,6 +221,9 @@ class InferenceRouter(Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.routing_table.get_model(model_id) if model is None: @@ -224,6 +233,9 @@ class InferenceRouter(Inference): return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, + text_truncation=text_truncation, + output_dimension=output_dimension, + task_type=task_type, ) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index a6f7a78af..d03ea933a 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -22,12 +22,14 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, InterleavedContentItem, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse: + async def embeddings( + self, + model_id: str, + contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, + ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 69fb5dea2..b82a4c752 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, 
ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embeddings = [] diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 71b9155d5..4deeea630 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionRequest, CompletionResponse, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index e3acd4314..75751c8b1 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolDefinition, ToolPromptFormat, @@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 95fe65c39..b9b23584b 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types 
import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, CompletionResponse, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 4b21ae81d..45c15a467 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -17,17 +17,23 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, InterleavedContent, InterleavedContentItem, LogProbConfig, Message, ResponseFormat, + TextTruncation, ToolChoice, ToolConfig, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat +from llama_stack.models.llama.datatypes import ( + SamplingParams, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.models.llama.sku_list import CoreModelId from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.model_registry import ( @@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 6f38230b2..ecd53e91c 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -23,14 +23,20 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, + TextTruncation, ToolChoice, ToolConfig, ) -from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat +from llama_stack.models.llama.datatypes import ( + SamplingParams, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: if any(content_has_media(content) for content in contents): raise NotImplementedError("Media is not supported") diff --git 
a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 0071aaa5d..62c8381a8 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -21,11 +21,13 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index a34c34f69..11da6bb9e 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference): self, model_id: str, contents: List[InterleavedContent], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: client = self._get_client() model = await self.model_store.get_model(model_id) @@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference): return client.inference.embeddings( model_id=model.provider_resource_id, contents=contents, + text_truncation=text_truncation, + output_dimension=output_dimension, + task_type=task_type, ) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index a5acc47f8..bd620aa64 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): self, model: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index b60954abc..57a296258 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionMessage, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, @@ -27,6 +28,7 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, + TextTruncation, ToolCall, ToolChoice, ToolConfig, @@ -140,6 +142,9 @@ 
class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a52abd20d..d09ca241f 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a2c4f1542..1fca54bb3 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union from together import Together -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) assert all(not content_has_media(content) for content in contents), ( diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index bff5da8a7..b9422d85d 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -28,12 +28,14 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -383,6 +385,9 @@ class 
VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 947a62f09..32aa5da3f 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,12 +5,14 @@ # the root directory of this source tree. import logging -from typing import List +from typing import List, Optional from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, InterleavedContentItem, ModelStore, + TextTruncation, ) EMBEDDING_MODELS = {} @@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin: self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
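Note that most adapters in this change only thread the new parameters through their signatures; several continue to `raise NotImplementedError()`, and the sentence-transformers mixin ignores them for now. As a rough sketch of what honoring them inside a provider could eventually look like (the helpers below are hypothetical, use character-level truncation as a stand-in for token-level truncation, and the `search_query:`/`search_document:` prefixes follow the nomic-embed convention rather than anything defined in this patch):

```python
from typing import List, Optional

from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation


def prepare_embedding_inputs(
    texts: List[str],
    max_seq_len: int,
    text_truncation: Optional[TextTruncation] = TextTruncation.none,
    task_type: Optional[EmbeddingTaskType] = None,
) -> List[str]:
    """Hypothetical helper: apply the truncation policy and asymmetric task prefixes."""
    if text_truncation in (None, TextTruncation.none):
        for t in texts:
            if len(t) > max_seq_len:
                raise ValueError("input exceeds the model's max sequence length and truncation is 'none'")
        out = list(texts)
    elif text_truncation == TextTruncation.start:
        out = [t[-max_seq_len:] for t in texts]  # drop the start, keep the tail
    else:  # TextTruncation.end
        out = [t[:max_seq_len] for t in texts]   # keep the head, drop the end
    if task_type == EmbeddingTaskType.query:
        out = [f"search_query: {t}" for t in out]
    elif task_type == EmbeddingTaskType.document:
        out = [f"search_document: {t}" for t in out]
    return out


def shrink_embedding(vector: List[float], output_dimension: Optional[int]) -> List[float]:
    """Hypothetical helper: keep a prefix of a Matryoshka embedding when a smaller dimension is requested."""
    return vector if output_dimension is None else vector[:output_dimension]
```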