chore(api): remove deprecated embeddings impls

Matthew Farrellee 2025-09-02 02:02:02 -04:00
parent 478b4ff1e6
commit 30998fd1ff
20 changed files with 3 additions and 927 deletions
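
The removed `embeddings` method and its `/v1/inference/embeddings` route were already superseded by the OpenAI-compatible `openai_embeddings` path, which every adapter touched here still implements. As a rough sketch of the replacement call (the base URL, API key, and model id below are illustrative assumptions for a local deployment, not values taken from this change), an OpenAI client pointed at the stack's OpenAI-compatible endpoint covers the same use case:

    # Sketch of the replacement path, assuming a locally running Llama Stack
    # server exposing its OpenAI-compatible routes; base_url, api_key, and the
    # model id are assumptions for illustration, not part of this commit.
    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8321/v1/openai/v1",  # assumed server address/prefix
        api_key="not-needed-for-local",                 # placeholder key
    )

    response = client.embeddings.create(
        model="all-MiniLM-L6-v2",                       # any registered embedding model
        input=["first document", "second document"],
    )

    vectors = [item.embedding for item in response.data]
    print(len(vectors), len(vectors[0]))  # one vector per input, model-specific dimension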


@@ -901,49 +901,6 @@
]
}
},
"/v1/inference/embeddings": {
"post": {
"responses": {
"200": {
"description": "An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/EmbeddingsResponse"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Inference"
],
"description": "Generate embeddings for content pieces using the specified model.",
"parameters": [],
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/EmbeddingsRequest"
}
}
},
"required": true
}
}
},
"/v1/eval/benchmarks/{benchmark_id}/evaluations": { "/v1/eval/benchmarks/{benchmark_id}/evaluations": {
"post": { "post": {
"responses": { "responses": {
@ -9698,80 +9655,6 @@
"title": "OpenAIDeleteResponseObject", "title": "OpenAIDeleteResponseObject",
"description": "Response object confirming deletion of an OpenAI response." "description": "Response object confirming deletion of an OpenAI response."
}, },
"EmbeddingsRequest": {
"type": "object",
"properties": {
"model_id": {
"type": "string",
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
},
"contents": {
"oneOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "array",
"items": {
"$ref": "#/components/schemas/InterleavedContentItem"
}
}
],
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
},
"text_truncation": {
"type": "string",
"enum": [
"none",
"start",
"end"
],
"description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
},
"output_dimension": {
"type": "integer",
"description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
},
"task_type": {
"type": "string",
"enum": [
"query",
"document"
],
"description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
}
},
"additionalProperties": false,
"required": [
"model_id",
"contents"
],
"title": "EmbeddingsRequest"
},
"EmbeddingsResponse": {
"type": "object",
"properties": {
"embeddings": {
"type": "array",
"items": {
"type": "array",
"items": {
"type": "number"
}
},
"description": "List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}"
}
},
"additionalProperties": false,
"required": [
"embeddings"
],
"title": "EmbeddingsResponse",
"description": "Response containing generated embeddings."
},
"AgentCandidate": { "AgentCandidate": {
"type": "object", "type": "object",
"properties": { "properties": {


@@ -616,39 +616,6 @@ paths:
required: true
schema:
type: string
/v1/inference/embeddings:
post:
responses:
'200':
description: >-
An array of embeddings, one for each content. Each embedding is a list
of floats. The dimensionality of the embedding is model-specific; you
can check model metadata using /models/{model_id}.
content:
application/json:
schema:
$ref: '#/components/schemas/EmbeddingsResponse'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Inference
description: >-
Generate embeddings for content pieces using the specified model.
parameters: []
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/EmbeddingsRequest'
required: true
/v1/eval/benchmarks/{benchmark_id}/evaluations:
post:
responses:
@@ -7173,72 +7140,6 @@ components:
title: OpenAIDeleteResponseObject
description: >-
Response object confirming deletion of an OpenAI response.
EmbeddingsRequest:
type: object
properties:
model_id:
type: string
description: >-
The identifier of the model to use. The model must be an embedding model
registered with Llama Stack and available via the /models endpoint.
contents:
oneOf:
- type: array
items:
type: string
- type: array
items:
$ref: '#/components/schemas/InterleavedContentItem'
description: >-
List of contents to generate embeddings for. Each content can be a string
or an InterleavedContentItem (and hence can be multimodal). The behavior
depends on the model and provider. Some models may only support text.
text_truncation:
type: string
enum:
- none
- start
- end
description: >-
(Optional) Config for how to truncate text for embedding when text is
longer than the model's max sequence length.
output_dimension:
type: integer
description: >-
(Optional) Output dimensionality for the embeddings. Only supported by
Matryoshka models.
task_type:
type: string
enum:
- query
- document
description: >-
(Optional) How is the embedding being used? This is only supported by
asymmetric embedding models.
additionalProperties: false
required:
- model_id
- contents
title: EmbeddingsRequest
EmbeddingsResponse:
type: object
properties:
embeddings:
type: array
items:
type: array
items:
type: number
description: >-
List of embedding vectors, one per input content. Each embedding is a
list of floats. The dimensionality of the embedding is model-specific;
you can check model metadata using /models/{model_id}
additionalProperties: false
required:
- embeddings
title: EmbeddingsResponse
description: >-
Response containing generated embeddings.
AgentCandidate:
type: object
properties:


@@ -17,7 +17,7 @@ from typing import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry import MetricResponseMixin
@@ -1135,26 +1135,6 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/inference/embeddings", method="POST")
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
"""Generate embeddings for content pieces using the specified model.
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
"""
...
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,


@@ -16,7 +16,6 @@ from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
@@ -28,8 +27,6 @@ from llama_stack.apis.inference import (
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
@@ -50,7 +47,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
StopReason,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -347,25 +343,6 @@ class InferenceRouter(Inference):
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
logger.debug(f"InferenceRouter.embeddings: {model_id}")
await self._get_model(model_id, ModelType.embedding)
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.embeddings(
model_id=model_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
async def openai_completion(
self,
model: str,


@@ -11,21 +11,17 @@ from botocore.client import BaseClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -47,8 +43,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
)
from .models import MODEL_ENTRIES
@@ -176,31 +170,6 @@ class BedrockInferenceAdapter(
),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embeddings = []
for content in contents:
assert not content_has_media(content), "Bedrock does not support media for embeddings"
input_text = interleaved_content_as_str(content)
input_body = {"inputText": input_text}
body = json.dumps(input_body)
response = self.client.invoke_model(
body=body,
modelId=model.provider_resource_id,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
embeddings.append(response_body.get("embedding"))
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,


@@ -10,21 +10,17 @@ from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,


@@ -10,20 +10,16 @@ from openai import OpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -147,16 +143,6 @@ class DatabricksInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,


@@ -12,15 +12,12 @@ from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -33,7 +30,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -57,8 +53,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
@@ -261,31 +255,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
return params
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
kwargs = {}
if model.metadata.get("embedding_dimension"):
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), (
"Fireworks does not support media for embeddings"
)
response = self._get_client().embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,


@@ -11,8 +11,6 @@ from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
@@ -21,8 +19,6 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -31,7 +27,6 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
)
@@ -155,60 +150,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
# we pass n=1 to get only one completion
return convert_openai_completion_choice(response.choices[0])
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
if any(content_has_media(content) for content in contents):
raise NotImplementedError("Media is not supported")
#
# Llama Stack: contents = list[str] | list[InterleavedContentItem]
# ->
# OpenAI: input = str | list[str]
#
# we can ignore str and always pass list[str] to OpenAI
#
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
provider_model_id = await self._get_provider_model_id(model_id)
extra_body = {}
if text_truncation is not None:
text_truncation_options = {
TextTruncation.none: "NONE",
TextTruncation.end: "END",
TextTruncation.start: "START",
}
extra_body["truncate"] = text_truncation_options[text_truncation]
if output_dimension is not None:
extra_body["dimensions"] = output_dimension
if task_type is not None:
task_type_options = {
EmbeddingTaskType.document: "passage",
EmbeddingTaskType.query: "query",
}
extra_body["input_type"] = task_type_options[task_type]
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->
# Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
#
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,


@@ -17,7 +17,6 @@ from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
@@ -28,8 +27,6 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
@@ -44,7 +41,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -76,9 +72,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
convert_image_content_to_url,
interleaved_content_as_str,
localize_image_content,
request_has_media,
)
@@ -394,27 +388,6 @@ class OllamaInferenceAdapter(
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self._get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
)
response = await self.client.embed(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = response["embeddings"]
return EmbeddingsResponse(embeddings=embeddings)
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)


@@ -14,8 +14,6 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -27,7 +25,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -190,25 +187,6 @@ class PassthroughInferenceAdapter(Inference):
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
async def embeddings(
self,
model_id: str,
contents: list[InterleavedContent],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
client = self._get_client()
model = await self.model_store.get_model(model_id)
return await client.inference.embeddings(
model_id=model.provider_resource_id,
contents=contents,
text_truncation=text_truncation,
output_dimension=output_dimension,
task_type=task_type,
)
async def openai_embeddings(
self,
model: str,


@@ -136,16 +136,6 @@ class RunpodInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,


@@ -11,14 +11,11 @@ from huggingface_hub import AsyncInferenceClient, HfApi
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -26,7 +23,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -282,16 +278,6 @@ class _HfAdapter(
**self._build_options(request.sampling_params, request.response_format),
)
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def openai_embeddings(
self,
model: str,


@@ -12,14 +12,11 @@ from together import AsyncTogether
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -53,8 +49,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
@@ -235,26 +229,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
logger.debug(f"params to together: {params}")
return params
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
assert all(not content_has_media(content) for content in contents), (
"Together does not support media for embeddings"
)
client = self._get_client()
r = await client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = [item.embedding for item in r.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,


@@ -15,7 +15,6 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
@@ -30,8 +29,6 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
@@ -47,7 +44,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -78,8 +74,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
content_has_media,
interleaved_content_as_str,
request_has_media,
)
@@ -535,32 +529,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
**options,
}
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)
kwargs = {}
assert model.model_type == ModelType.embedding
assert model.metadata.get("embedding_dimension")
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
response = await self.client.embeddings.create(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
**kwargs,
)
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,


@@ -11,13 +11,11 @@ from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
@@ -30,7 +28,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -249,16 +246,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
}
return params
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
raise NotImplementedError("embedding is not supported for watsonx")
async def openai_embeddings(
self,
model: str,


@@ -14,16 +14,11 @@ if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
from llama_stack.apis.inference import (
EmbeddingsResponse,
EmbeddingTaskType,
InterleavedContentItem,
ModelStore,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
TextTruncation,
)
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
EMBEDDING_MODELS = {}
@@ -34,21 +29,6 @@ log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
)
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,


@@ -11,14 +11,11 @@ import litellm
from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
@@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -50,9 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
prepare_openai_completion_params,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="providers::utils")
@@ -269,24 +262,6 @@
)
return api_key
async def embeddings(
self,
model_id: str,
contents: list[str] | list[InterleavedContentItem],
text_truncation: TextTruncation | None = TextTruncation.none,
output_dimension: int | None = None,
task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
response = litellm.embedding(
model=self.get_litellm_model_name(model.provider_resource_id),
input=[interleaved_content_as_str(content) for content in contents],
)
embeddings = [data["embedding"] for data in response["data"]]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,


@@ -1,303 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
#
# Test plan:
#
# Types of input:
# - array of a string
# - array of a image (ImageContentItem, either URL or base64 string)
# - array of a text (TextContentItem)
# Types of output:
# - list of list of floats
# Params:
# - text_truncation
# - absent w/ long text -> error
# - none w/ long text -> error
# - absent w/ short text -> ok
# - none w/ short text -> ok
# - end w/ long text -> ok
# - end w/ short text -> ok
# - start w/ long text -> ok
# - start w/ short text -> ok
# - output_dimension
# - response dimension matches
# - task_type, only for asymmetric models
# - query embedding != passage embedding
# Negative:
# - long string
# - long text
#
# Todo:
# - negative tests
# - empty
# - empty list
# - empty string
# - empty text
# - empty image
# - long
# - large image
# - appropriate combinations
# - batch size
# - many inputs
# - invalid
# - invalid URL
# - invalid base64
#
# Notes:
# - use llama_stack_client fixture
# - use pytest.mark.parametrize when possible
# - no accuracy tests: only check the type of output, not the content
#
import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types import EmbeddingsResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
ImageContentItemImage,
ImageContentItemImageURL,
TextContentItem,
)
from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.core.library_client import LlamaStackAsLibraryClient
DUMMY_STRING = "hello"
DUMMY_STRING2 = "world"
DUMMY_LONG_STRING = "NVDA " * 10240
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_LONG_TEXT = TextContentItem(text=DUMMY_LONG_STRING, type="text")
# TODO(mf): add a real image URL and base64 string
DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
MODELS_SUPPORTING_MEDIA = {}
MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
MODELS_REQUIRING_TASK_TYPE = {
"nvidia/llama-3.2-nv-embedqa-1b-v2",
"nvidia/nv-embedqa-e5-v5",
"nvidia/nv-embedqa-mistral-7b-v2",
"snowflake/arctic-embed-l",
}
MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
def default_task_type(model_id):
"""
Some models require a task type parameter. This provides a default value for
testing those models.
"""
if model_id in MODELS_REQUIRING_TASK_TYPE:
return {"task_type": "query"}
return {}
@pytest.mark.parametrize(
"contents",
[
[DUMMY_STRING, DUMMY_STRING2],
[DUMMY_TEXT, DUMMY_TEXT2],
],
ids=[
"list[string]",
"list[text]",
],
)
def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)
@pytest.mark.parametrize(
"contents",
[
[DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64],
[DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT],
],
ids=[
"list[url,base64]",
"list[url,string,base64,text]",
],
)
def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
pytest.xfail(f"{embedding_model_id} doesn't support media")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)
@pytest.mark.parametrize(
"text_truncation",
[
"end",
"start",
],
)
@pytest.mark.parametrize(
"contents",
[
[DUMMY_LONG_TEXT],
[DUMMY_STRING],
],
ids=[
"long",
"short",
],
)
def test_embedding_truncation(
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=contents,
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)
@pytest.mark.parametrize(
"text_truncation",
[
None,
"none",
],
)
@pytest.mark.parametrize(
"contents",
[
[DUMMY_LONG_TEXT],
[DUMMY_LONG_STRING],
],
ids=[
"long-text",
"long-str",
],
)
def test_embedding_truncation_error(
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
# Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
# While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
error_type = (
OpenAIBadRequestError
if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
else LlamaStackBadRequestError
)
with pytest.raises(error_type):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=[DUMMY_LONG_TEXT],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)
def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
base_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
)
test_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=[DUMMY_STRING],
**default_task_type(embedding_model_id),
output_dimension=32,
)
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
assert len(test_response.embeddings[0]) == 32
def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
pytest.xfail(f"{embedding_model_id} doesn't support task_type")
query_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
)
document_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
)
assert query_embedding.embeddings != document_embedding.embeddings
@pytest.mark.parametrize(
"text_truncation",
[
None,
"none",
"end",
"start",
],
)
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=[DUMMY_STRING],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
assert isinstance(response.embeddings[0], list)
assert isinstance(response.embeddings[0][0], float)
@pytest.mark.parametrize(
"text_truncation",
[
"NONE",
"END",
"START",
"left",
"right",
],
)
def test_embedding_text_truncation_error(
llama_stack_client, embedding_model_id, text_truncation, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
with pytest.raises(error_type):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=[DUMMY_STRING],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)


@@ -5,13 +5,12 @@
# the root directory of this source tree.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
@@ -70,13 +69,6 @@ def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
return mock_vector_db
@pytest.fixture
def mock_inference_api(sample_embeddings):
mock_api = MagicMock(spec=Inference)
mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
return mock_api
@pytest.fixture
def mock_files_api():
mock_api = MagicMock(spec=Files)
@@ -96,22 +88,6 @@ async def faiss_index(embedding_dimension):
yield index
@pytest.fixture
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
# Create the adapter
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
# Create a mock KVStore
mock_kvstore = MagicMock()
mock_kvstore.values_in_range = AsyncMock(return_value=[])
# Patch the initialize method to avoid the kvstore_impl call
with patch.object(FaissVectorIOAdapter, "initialize"):
# Set the kvstore directly
adapter.kvstore = mock_kvstore
yield adapter
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
):