diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 638f7bb7b..fab7c802e 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4944,6 +4944,27 @@
}
],
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
+ },
+ "text_truncation": {
+ "type": "string",
+ "enum": [
+ "none",
+ "start",
+ "end"
+ ],
+ "description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
+ },
+ "output_dimension": {
+ "type": "integer",
+ "description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
+ },
+ "task_type": {
+ "type": "string",
+ "enum": [
+ "query",
+ "document"
+ ],
+ "description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 08effe7cf..fc57bf258 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -3235,6 +3235,28 @@ components:
List of contents to generate embeddings for. Each content can be a string
or an InterleavedContentItem (and hence can be multimodal). The behavior
depends on the model and provider. Some models may only support text.
+ text_truncation:
+ type: string
+ enum:
+ - none
+ - start
+ - end
+ description: >-
+ (Optional) Config for how to truncate text for embedding when text is
+ longer than the model's max sequence length.
+ output_dimension:
+ type: integer
+ description: >-
+ (Optional) Output dimensionality for the embeddings. Only supported by
+ Matryoshka models.
+ task_type:
+ type: string
+ enum:
+ - query
+ - document
+ description: >-
+ (Optional) How is the embedding being used? This is only supported by
+ asymmetric embedding models.
additionalProperties: false
required:
- model_id
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index 2dfe55977..d83506dd4 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -402,6 +402,30 @@ class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
+class TextTruncation(Enum):
+ """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
+
+ :cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error.
+ :cvar start: Truncate from the start
+ :cvar end: Truncate from the end
+ """
+
+ none = "none"
+ start = "start"
+ end = "end"
+
+
+class EmbeddingTaskType(Enum):
+ """How is the embedding being used? This is only supported by asymmetric embedding models.
+
+ :cvar query: Used for a query for semantic search.
+ :cvar document: Used at indexing time when ingesting documents.
+ """
+
+ query = "query"
+ document = "document"
+
+
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@@ -482,11 +506,17 @@ class Inference(Protocol):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
+ :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
+ :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
+ :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
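
For context, a minimal caller-side sketch of the extended embeddings signature introduced above; the `inference_impl` object, the model id, and the surrounding async context are assumptions for illustration only, not part of this change:

# Hypothetical usage sketch (assumes an async context and an `inference_impl`
# object implementing the Inference protocol; the model id is illustrative).
from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation

response = await inference_impl.embeddings(
    model_id="all-MiniLM-L6-v2",           # an embedding model registered with the stack
    contents=["What is the capital of France?"],
    text_truncation=TextTruncation.end,     # truncate from the end instead of erroring on long input
    output_dimension=256,                   # only honored by Matryoshka models
    task_type=EmbeddingTaskType.query,      # asymmetric models embed queries and documents differently
)
# response.embeddings holds one vector (a list of floats) per input content.
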
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index d885ebc09..df4ed03d3 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -6,7 +6,11 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
-from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ URL,
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
)
from llama_stack.apis.inference import (
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -215,6 +221,9 @@ class InferenceRouter(Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
@@ -224,6 +233,9 @@ class InferenceRouter(Inference):
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
+ text_truncation=text_truncation,
+ output_dimension=output_dimension,
+ task_type=task_type,
)
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index a6f7a78af..d03ea933a 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -22,12 +22,14 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
- async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
+ async def embeddings(
+ self,
+ model_id: str,
+ contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
+ ) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 69fb5dea2..b82a4c752 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 71b9155d5..4deeea630 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index e3acd4314..75751c8b1 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional
from openai import OpenAI
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
@@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index 95fe65c39..b9b23584b 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py
index 4b21ae81d..45c15a467 100644
--- a/llama_stack/providers/remote/inference/groq/groq.py
+++ b/llama_stack/providers/remote/inference/groq/groq.py
@@ -17,17 +17,23 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
+ TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+ SamplingParams,
+ ToolDefinition,
+ ToolPromptFormat,
+)
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
@@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 6f38230b2..ecd53e91c 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -23,14 +23,20 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
+ TextTruncation,
ToolChoice,
ToolConfig,
)
-from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
+from llama_stack.models.llama.datatypes import (
+ SamplingParams,
+ ToolDefinition,
+ ToolPromptFormat,
+)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 0071aaa5d..62c8381a8 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -21,11 +21,13 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index a34c34f69..11da6bb9e 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
contents: List[InterleavedContent],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
@@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
return client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
+ text_truncation=text_truncation,
+ output_dimension=output_dimension,
+ task_type=task_type,
)
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index a5acc47f8..bd620aa64 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index b60954abc..57a296258 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -27,6 +28,7 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
+ TextTruncation,
ToolCall,
ToolChoice,
ToolConfig,
@@ -140,6 +142,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index a52abd20d..d09ca241f 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index a2c4f1542..1fca54bb3 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import (
+ InterleavedContent,
+ InterleavedContentItem,
+)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index bff5da8a7..b9422d85d 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -28,12 +28,14 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
+ EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
+ TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -383,6 +385,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 947a62f09..32aa5da3f 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -5,12 +5,14 @@
# the root directory of this source tree.
import logging
-from typing import List
+from typing import List, Optional
from llama_stack.apis.inference import (
EmbeddingsResponse,
+ EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
+ TextTruncation,
)
EMBEDDING_MODELS = {}
@@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin:
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
+ text_truncation: Optional[TextTruncation] = TextTruncation.none,
+ output_dimension: Optional[int] = None,
+ task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)