fix(api): update embeddings signature so input and output lists align (#1161)

See Issue #922 

The change is slightly backwards incompatible, but no call site (in our
client codebases or stack-apps) ever passes a depth-2
`List[List[InterleavedContentItem]]`, which is now disallowed.
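
For illustration, a minimal sketch of what the new `contents` parameter does and does not accept. The model id and the `demo` wrapper are hypothetical; `inference` stands for any `Inference` implementation or router:

```python
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import Inference


async def demo(inference: Inference) -> None:
    # Still allowed: a flat list of strings.
    await inference.embeddings(
        model_id="my-embedding-model",  # hypothetical registered embedding model
        contents=["first document", "second document"],
    )

    # Still allowed: a flat (depth-1) list of InterleavedContentItem objects.
    await inference.embeddings(
        model_id="my-embedding-model",
        contents=[TextContentItem(text="first document"), TextContentItem(text="second document")],
    )

    # Now disallowed: a depth-2 list that wraps each content in its own list.
    # await inference.embeddings(
    #     model_id="my-embedding-model",
    #     contents=[[TextContentItem(text="first document")], [TextContentItem(text="second document")]],
    # )
```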

## Test Plan

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
   --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k together test_embeddings.py \
   --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
   --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

Also ran `tests/client-sdk/inference/test_embeddings.py`
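
The property those tests exercise is the one the title describes: the output list lines up one-to-one with the input list. A rough client-side sketch of that check, assuming the `llama-stack-client` Python SDK exposes `inference.embeddings(model_id=..., contents=...)` and a stack is running locally (both assumptions, not part of this change):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local endpoint

contents = ["hello world", "goodbye world"]
response = client.inference.embeddings(
    model_id="all-minilm:latest",  # any registered embedding model
    contents=contents,
)

# One embedding per input content, each a list of floats of model-specific dimensionality.
assert len(response.embeddings) == len(contents)
assert all(isinstance(value, float) for embedding in response.embeddings for value in embedding)
```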
Authored by Ashwin Bharambe on 2025-02-20 21:43:13 -08:00, committed via GitHub.
Commit 6f9d622340, parent cfa752fc92.
17 changed files with 85 additions and 41 deletions.

```diff
@@ -4929,11 +4929,21 @@
           "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
         },
         "contents": {
-          "type": "array",
-          "items": {
-            "$ref": "#/components/schemas/InterleavedContent"
-          },
-          "description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
+          "oneOf": [
+            {
+              "type": "array",
+              "items": {
+                "type": "string"
+              }
+            },
+            {
+              "type": "array",
+              "items": {
+                "$ref": "#/components/schemas/InterleavedContentItem"
+              }
+            }
+          ],
+          "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
         }
       },
       "additionalProperties": false,
```

```diff
@@ -3224,13 +3224,17 @@ components:
         The identifier of the model to use. The model must be an embedding model
         registered with Llama Stack and available via the /models endpoint.
       contents:
-        type: array
-        items:
-          $ref: '#/components/schemas/InterleavedContent'
+        oneOf:
+          - type: array
+            items:
+              type: string
+          - type: array
+            items:
+              $ref: '#/components/schemas/InterleavedContentItem'
         description: >-
-          List of contents to generate embeddings for. Note that content can be
-          multimodal. The behavior depends on the model and provider. Some models
-          may only support text.
+          List of contents to generate embeddings for. Each content can be a string
+          or an InterleavedContentItem (and hence can be multimodal). The behavior
+          depends on the model and provider. Some models may only support text.
       additionalProperties: false
       required:
         - model_id
```

```diff
@@ -20,7 +20,7 @@ from typing import (
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Annotated

-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ class Inference(Protocol):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         """Generate embeddings for content pieces using the specified model.

         :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
-        :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
+        :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
         :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
         """
         ...
```

```diff
@@ -6,7 +6,7 @@
 from typing import Any, AsyncGenerator, Dict, List, Optional

-from llama_stack.apis.common.content_types import URL, InterleavedContent
+from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
 from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
 from llama_stack.apis.eval import (
     BenchmarkConfig,
@@ -214,7 +214,7 @@ class InferenceRouter(Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.routing_table.get_model(model_id)
         if model is None:
```

```diff
@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
     CompletionResponseStreamChunk,
     EmbeddingsResponse,
     Inference,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         async for chunk in process_chat_completion_stream_response(stream, request):
             yield chunk

-    async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
+    async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 from botocore.client import BaseClient

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embeddings = []
```

```diff
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from cerebras.cloud.sdk import AsyncCerebras

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     CompletionRequest,
@@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional
 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
-        model: str,
-        contents: List[InterleavedContent],
+        model_id: str,
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
```

```diff
@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     Inference,
     InterleavedContent,
+    InterleavedContentItem,
     LogProbConfig,
     Message,
     ResponseFormat,
@@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -13,6 +13,7 @@ from ollama import AsyncClient
 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
 from llama_stack.apis.inference import (
@@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
```

```diff
@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
         response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
+        tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> AsyncGenerator:
         request = ChatCompletionRequest(
             model=model,
@@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -5,16 +5,36 @@
 # the root directory of this source tree.
 import json
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional

 from openai import OpenAI

 from llama_stack.apis.common.content_types import (
     ImageContentItem,
     InterleavedContent,
+    InterleavedContentItem,
     TextContentItem,
 )
-from llama_stack.apis.inference import * # noqa: F403
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionMessage,
+    EmbeddingsResponse,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    StopReason,
+    SystemMessage,
+    ToolCall,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+    ToolResponseMessage,
+    UserMessage,
+)
 from llama_stack.models.llama.datatypes import (
     GreedySamplingStrategy,
     TopKSamplingStrategy,
@@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional
 from huggingface_hub import AsyncInferenceClient, HfApi

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         raise NotImplementedError()
```

```diff
@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
 from together import Together

-from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         assert all(not content_has_media(content) for content in contents), (
```

```diff
@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
 from llama_models.datatypes import StopReason, ToolCall
 from openai import OpenAI

-from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
+from llama_stack.apis.common.content_types import (
+    InterleavedContent,
+    InterleavedContentItem,
+    TextDelta,
+    ToolCallDelta,
+    ToolCallParseStatus,
+)
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
```

```diff
@@ -9,7 +9,7 @@ from typing import List
 from llama_stack.apis.inference import (
     EmbeddingsResponse,
-    InterleavedContent,
+    InterleavedContentItem,
     ModelStore,
 )
@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
     async def embeddings(
         self,
         model_id: str,
-        contents: List[InterleavedContent],
+        contents: List[str] | List[InterleavedContentItem],
     ) -> EmbeddingsResponse:
         model = await self.model_store.get_model(model_id)
         embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
```