forked from phoenix-oss/llama-stack-mirror
fix(api): update embeddings signature so inputs and outputs list align (#1161)
See Issue #922 The change is slightly backwards incompatible but no callsite (in our client codebases or stack-apps) every passes a depth-2 `List[List[InterleavedContentItem]]` (which is now disallowed.) ## Test Plan ```bash $ cd llama_stack/providers/tests/inference $ pytest -s -v -k fireworks test_embeddings.py \ --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k together test_embeddings.py \ --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k ollama test_embeddings.py \ --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784 ``` Also ran `tests/client-sdk/inference/test_embeddings.py`
This commit is contained in:
parent
cfa752fc92
commit
6f9d622340
17 changed files with 85 additions and 41 deletions
20
docs/_static/llama-stack-spec.html
vendored
20
docs/_static/llama-stack-spec.html
vendored
|
@ -4929,11 +4929,21 @@
|
|||
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
|
||||
},
|
||||
"contents": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContent"
|
||||
},
|
||||
"description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text."
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
|
16
docs/_static/llama-stack-spec.yaml
vendored
16
docs/_static/llama-stack-spec.yaml
vendored
|
@ -3224,13 +3224,17 @@ components:
|
|||
The identifier of the model to use. The model must be an embedding model
|
||||
registered with Llama Stack and available via the /models endpoint.
|
||||
contents:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
oneOf:
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
description: >-
|
||||
List of contents to generate embeddings for. Note that content can be
|
||||
multimodal. The behavior depends on the model and provider. Some models
|
||||
may only support text.
|
||||
List of contents to generate embeddings for. Each content can be a string
|
||||
or an InterleavedContentItem (and hence can be multimodal). The behavior
|
||||
depends on the model and provider. Some models may only support text.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model_id
|
||||
|
|
|
@ -20,7 +20,7 @@ from typing import (
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
|
@ -481,12 +481,12 @@ class Inference(Protocol):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
"""Generate embeddings for content pieces using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text.
|
||||
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
|
||||
"""
|
||||
...
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent
|
||||
from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
|
||||
from llama_stack.apis.eval import (
|
||||
BenchmarkConfig,
|
||||
|
@ -214,7 +214,7 @@ class InferenceRouter(Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
|
|
|
@ -23,6 +23,7 @@ from llama_stack.apis.inference import (
|
|||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
InterleavedContentItem,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
|
@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse:
|
||||
async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
|||
|
||||
from botocore.client import BaseClient
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from cerebras.cloud.sdk import AsyncCerebras
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
|
@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional
|
|||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedContent],
|
||||
model_id: str,
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from fireworks.client import Fireworks
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from llama_stack.apis.inference import (
|
|||
EmbeddingsResponse,
|
||||
Inference,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
|
@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from ollama import AsyncClient
|
|||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
|
|
@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
tool_config: Optional[ToolConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
request = ChatCompletionRequest(
|
||||
model=model,
|
||||
|
@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
|
|
@ -5,16 +5,36 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
TopKSamplingStrategy,
|
||||
|
@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional
|
|||
|
||||
from huggingface_hub import AsyncInferenceClient, HfApi
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
|
||||
from together import Together
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
|
|
|
@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union
|
|||
from llama_models.datatypes import StopReason, ToolCall
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
|
@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from typing import List
|
|||
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
ModelStore,
|
||||
)
|
||||
|
||||
|
@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin:
|
|||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedContent],
|
||||
contents: List[str] | List[InterleavedContentItem],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue