fix(api): update embeddings signature so inputs and outputs list align (#1161)

See Issue #922 

The change is slightly backwards incompatible, but no callsite (in our
client codebases or stack-apps) ever passes a depth-2
`List[List[InterleavedContentItem]]`, which is now disallowed.
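
For illustration, a minimal sketch of the call shapes the new signature accepts and the one it rejects; the `demo` wrapper and the model id are placeholders, not part of this change:

```python
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import Inference


async def demo(inference: Inference, model_id: str) -> None:
    # Accepted: a flat list of plain strings.
    await inference.embeddings(model_id=model_id, contents=["first doc", "second doc"])

    # Accepted: a flat list of InterleavedContentItem objects
    # (multimodal where the model/provider supports it).
    await inference.embeddings(
        model_id=model_id,
        contents=[TextContentItem(text="caption one"), TextContentItem(text="caption two")],
    )

    # No longer accepted: a depth-2 list such as
    #   contents=[[TextContentItem(text="a")], [TextContentItem(text="b")]]
```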

## Test Plan

```bash
$ cd llama_stack/providers/tests/inference
$ pytest -s -v -k fireworks test_embeddings.py \
   --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k together test_embeddings.py \
   --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784
$ pytest -s -v -k ollama test_embeddings.py \
   --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784
```

Also ran `tests/client-sdk/inference/test_embeddings.py`
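
For reference, a rough sketch of the kind of call that client-sdk test exercises; the base URL, model id, and expected count are illustrative assumptions, and the exact client surface may differ across llama-stack-client versions:

```python
from llama_stack_client import LlamaStackClient

# Illustrative values: point at a locally running stack with an embedding model registered.
client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.embeddings(
    model_id="all-minilm:latest",
    contents=["hello world", "goodbye world"],
)

# One embedding (a list of floats) is returned per input content.
assert len(response.embeddings) == 2
```
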
Author: Ashwin Bharambe, 2025-02-20 21:43:13 -08:00 (committed via GitHub)
Parent: cfa752fc92
Commit: 6f9d622340
17 changed files with 85 additions and 41 deletions

View file

@@ -4929,11 +4929,21 @@
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
},
"contents": {
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContent"
},
"description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
"oneOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
}
],
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
}
},
"additionalProperties": false,

View file

@@ -3224,13 +3224,17 @@ components:
The identifier of the model to use. The model must be an embedding model
registered with Llama Stack and available via the /models endpoint.
contents:
type: array
items:
$ref: '#/components/schemas/InterleavedContent'
oneOf:
- type: array
items:
type: string
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
description: >-
List of contents to generate embeddings for. Note that content can be
multimodal. The behavior depends on the model and provider. Some models
may only support text.
List of contents to generate embeddings for. Each content can be a string
or an InterleavedContentItem (and hence can be multimodal). The behavior
depends on the model and provider. Some models may only support text.
additionalProperties: false
required:
- model_id
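
Both spec hunks above (JSON and YAML) describe the same `oneOf`; concretely, the two request payload shapes it accepts look roughly like this, with a placeholder model id and the item branch shown as serialized text items:

```python
# Branch 1: "contents" is an array of plain strings.
payload_strings = {
    "model_id": "my-embedding-model",
    "contents": ["first document", "second document"],
}

# Branch 2: "contents" is an array of InterleavedContentItem objects.
payload_items = {
    "model_id": "my-embedding-model",
    "contents": [
        {"type": "text", "text": "first document"},
        {"type": "text", "text": "second document"},
    ],
}
```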

View file

@@ -20,7 +20,7 @@ from typing import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
@@ -481,12 +481,12 @@ class Inference(Protocol):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
"""
...
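
Since several remote adapters only support text embeddings, they have to flatten either branch of the new union down to plain strings before calling their backend. A minimal sketch of that normalization under the new signature; the helper name is illustrative, not the repo's utility:

```python
from typing import List, Union

from llama_stack.apis.common.content_types import InterleavedContentItem, TextContentItem


def to_text_inputs(contents: Union[List[str], List[InterleavedContentItem]]) -> List[str]:
    """Illustrative helper: accept either branch of the union, reject non-text items."""
    texts: List[str] = []
    for item in contents:
        if isinstance(item, str):
            texts.append(item)
        elif isinstance(item, TextContentItem):
            texts.append(item.text)
        else:
            raise ValueError("This provider only supports text content for embeddings")
    return texts
```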

View file

@@ -6,7 +6,7 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
BenchmarkConfig,
@@ -214,7 +214,7 @@ class InferenceRouter(Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:

View file

@@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from botocore.client import BaseClient
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
@@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
model_id: str,
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from fireworks.client import Fireworks
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
EmbeddingsResponse,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
@@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -13,6 +13,7 @@ from ollama import AsyncClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
@@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
request = ChatCompletionRequest(
model=model,
@@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -5,16 +5,36 @@
# the root directory of this source tree.
import json
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional
from openai import OpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionMessage,
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
SystemMessage,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
ToolResponseMessage,
UserMessage,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
TopKSamplingStrategy,
@@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional
from huggingface_hub import AsyncInferenceClient, HfApi
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
from together import Together
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (

View file

@@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
from llama_models.datatypes import StopReason, ToolCall
from openai import OpenAI
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)

View file

@@ -9,7 +9,7 @@ from typing import List
from llama_stack.apis.inference import (
EmbeddingsResponse,
InterleavedContent,
InterleavedContentItem,
ModelStore,
)
@@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
async def embeddings(
self,
model_id: str,
contents: List[InterleavedContent],
contents: List[str] | List[InterleavedContentItem],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
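
After loading the model, the mixin's job amounts to encoding the flattened inputs and returning one vector per content. A rough, illustrative sketch of that sentence-transformers flow, not the repo's exact code, with a made-up function name:

```python
from typing import List

from sentence_transformers import SentenceTransformer

from llama_stack.apis.inference import EmbeddingsResponse


def embed_with_sentence_transformer(model_name: str, texts: List[str]) -> EmbeddingsResponse:
    # Encode the flattened text inputs; encode() returns one vector per input.
    model = SentenceTransformer(model_name)
    vectors = model.encode(texts)
    return EmbeddingsResponse(embeddings=[v.tolist() for v in vectors])
```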