feat(api): Add options for supporting various embedding models (#1192)

We need to support:
- asymmetric embedding models (#934)
- truncation policies (#933)
- varying dimensional output (#932) 

## Test Plan

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
   --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k together test_embeddings.py \
   --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
   --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```
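For context, the new options surface as optional keyword arguments on the embeddings call. Below is a minimal usage sketch; it assumes a locally running stack with the model already registered and a llama-stack-client build that exposes the new keyword arguments (the base URL, model id, and values are illustrative, not prescriptive):

```python
# Minimal usage sketch (assumes a running stack and a client SDK that
# exposes the new keyword arguments; base_url and values are illustrative).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.embeddings(
    model_id="nomic-ai/nomic-embed-text-v1.5",
    contents=["What is the capital of France?"],
    text_truncation="end",     # truncate over-long inputs instead of erroring
    output_dimension=256,      # honored only by Matryoshka-style models
    task_type="query",         # honored only by asymmetric embedding models
)
print(len(response.embeddings[0]))
```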
Author: Ashwin Bharambe, 2025-02-20 22:27:12 -08:00 (committed by GitHub)
Parent: 6f9d622340
Commit: 81ce39a607
19 changed files with 202 additions and 11 deletions

View file

@@ -4944,6 +4944,27 @@
}
],
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
},
"text_truncation": {
"type": "string",
"enum": [
"none",
"start",
"end"
],
"description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
},
"output_dimension": {
"type": "integer",
"description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
},
"task_type": {
"type": "string",
"enum": [
"query",
"document"
],
"description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
}
},
"additionalProperties": false,

View file

@@ -3235,6 +3235,28 @@ components:
List of contents to generate embeddings for. Each content can be a string
or an InterleavedContentItem (and hence can be multimodal). The behavior
depends on the model and provider. Some models may only support text.
text_truncation:
type: string
enum:
- none
- start
- end
description: >-
(Optional) Config for how to truncate text for embedding when text is
longer than the model's max sequence length.
output_dimension:
type: integer
description: >-
(Optional) Output dimensionality for the embeddings. Only supported by
Matryoshka models.
task_type:
type: string
enum:
- query
- document
description: >-
(Optional) How is the embedding being used? This is only supported by
asymmetric embedding models.
additionalProperties: false
required:
- model_id

View file

@@ -402,6 +402,30 @@ class ModelStore(Protocol):
def get_model(self, identifier: str) -> Model: ...
class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
:cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error.
:cvar start: Truncate from the start
:cvar end: Truncate from the end
"""
none = "none"
start = "start"
end = "end"
class EmbeddingTaskType(Enum):
"""How is the embedding being used? This is only supported by asymmetric embedding models.
:cvar query: Used for a query for semantic search.
:cvar document: Used at indexing time when ingesting documents.
"""
query = "query"
document = "document"
@runtime_checkable
@trace_protocol
class Inference(Protocol):
@@ -482,11 +506,17 @@ class Inference(Protocol):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
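The semantics of the new parameters are easiest to see in a small sketch of what a provider could do with them. The prefix strings and character-level truncation below are illustrative assumptions (real adapters would truncate at the token level and follow the model card's query/document convention); they are not part of this diff:

```python
# Hypothetical provider-side helper, for illustration only.
# Real adapters should truncate at the token level and use the prefixes
# documented for the specific asymmetric model.
from typing import List, Optional

from llama_stack.apis.inference import EmbeddingTaskType, TextTruncation


def prepare_inputs(
    texts: List[str],
    max_seq_len: int,
    text_truncation: TextTruncation = TextTruncation.none,
    task_type: Optional[EmbeddingTaskType] = None,
) -> List[str]:
    prepared = []
    for text in texts:
        if task_type == EmbeddingTaskType.query:
            text = "search_query: " + text       # assumed asymmetric-model prefix
        elif task_type == EmbeddingTaskType.document:
            text = "search_document: " + text    # assumed asymmetric-model prefix
        if len(text) > max_seq_len:
            if text_truncation == TextTruncation.none:
                raise ValueError("input exceeds the model's max sequence length")
            # "end" drops the tail of the text, "start" drops the head
            text = text[:max_seq_len] if text_truncation == TextTruncation.end else text[-max_seq_len:]
        prepared.append(text)
    return prepared
```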

View file

@@ -6,7 +6,11 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@@ -17,11 +21,13 @@ from llama_stack.apis.eval import (
)
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -215,6 +221,9 @@ class InferenceRouter(Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
@@ -224,6 +233,9 @@ class InferenceRouter(Inference):
return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)

View file

@@ -22,12 +22,14 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []

View file

@@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional
from openai import OpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
@@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@@ -17,17 +17,23 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.models.llama.sku_list import CoreModelId
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
@@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -23,14 +23,20 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.models.llama.datatypes import (
SamplingParams,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")

View file

@@ -21,11 +21,13 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference):
self,
model_id: str,
contents: List[InterleavedContent],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
@@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference):
return client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)

View file

@@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
self,
model: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -27,6 +28,7 @@ from llama_stack.apis.inference import (
SamplingParams,
StopReason,
SystemMessage,
TextTruncation,
ToolCall,
ToolChoice,
ToolConfig,
@@ -140,6 +142,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (

View file

@@ -28,12 +28,14 @@ from llama_stack.apis.inference import (
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -383,6 +385,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@@ -5,12 +5,14 @@
# the root directory of this source tree.
import logging
from typing import List, Optional
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
TextTruncation,
)
EMBEDDING_MODELS = {}
@@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin:
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
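As the hunk above shows, the sentence-transformers mixin now accepts the new arguments but does not yet act on them. One possible follow-up, sketched here as an assumption rather than part of this change, is to honor output_dimension for Matryoshka-style models by slicing and re-normalizing the computed vectors:

```python
# Illustrative only: apply output_dimension to already-computed embeddings.
# Slicing plus re-normalization is the usual Matryoshka recipe; whether a
# given model supports it must be checked against its model card.
import numpy as np


def apply_output_dimension(embeddings: np.ndarray, output_dimension: int) -> np.ndarray:
    reduced = embeddings[:, :output_dimension]             # keep the leading dims
    norms = np.linalg.norm(reduced, axis=1, keepdims=True)
    return reduced / np.clip(norms, 1e-12, None)           # re-normalize each row
```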