diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 7cb2a73f3..dcf3812e0 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -901,49 +901,6 @@
]
}
},
- "/v1/inference/embeddings": {
- "post": {
- "responses": {
- "200": {
- "description": "An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.",
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/EmbeddingsResponse"
- }
- }
- }
- },
- "400": {
- "$ref": "#/components/responses/BadRequest400"
- },
- "429": {
- "$ref": "#/components/responses/TooManyRequests429"
- },
- "500": {
- "$ref": "#/components/responses/InternalServerError500"
- },
- "default": {
- "$ref": "#/components/responses/DefaultError"
- }
- },
- "tags": [
- "Inference"
- ],
- "description": "Generate embeddings for content pieces using the specified model.",
- "parameters": [],
- "requestBody": {
- "content": {
- "application/json": {
- "schema": {
- "$ref": "#/components/schemas/EmbeddingsRequest"
- }
- }
- },
- "required": true
- }
- }
- },
"/v1/eval/benchmarks/{benchmark_id}/evaluations": {
"post": {
"responses": {
@@ -9698,80 +9655,6 @@
"title": "OpenAIDeleteResponseObject",
"description": "Response object confirming deletion of an OpenAI response."
},
- "EmbeddingsRequest": {
- "type": "object",
- "properties": {
- "model_id": {
- "type": "string",
- "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
- },
- "contents": {
- "oneOf": [
- {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- {
- "type": "array",
- "items": {
- "$ref": "#/components/schemas/InterleavedContentItem"
- }
- }
- ],
- "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
- },
- "text_truncation": {
- "type": "string",
- "enum": [
- "none",
- "start",
- "end"
- ],
- "description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
- },
- "output_dimension": {
- "type": "integer",
- "description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
- },
- "task_type": {
- "type": "string",
- "enum": [
- "query",
- "document"
- ],
- "description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
- }
- },
- "additionalProperties": false,
- "required": [
- "model_id",
- "contents"
- ],
- "title": "EmbeddingsRequest"
- },
- "EmbeddingsResponse": {
- "type": "object",
- "properties": {
- "embeddings": {
- "type": "array",
- "items": {
- "type": "array",
- "items": {
- "type": "number"
- }
- },
- "description": "List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}"
- }
- },
- "additionalProperties": false,
- "required": [
- "embeddings"
- ],
- "title": "EmbeddingsResponse",
- "description": "Response containing generated embeddings."
- },
"AgentCandidate": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 25089868c..473c9de45 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -616,39 +616,6 @@ paths:
required: true
schema:
type: string
- /v1/inference/embeddings:
- post:
- responses:
- '200':
- description: >-
- An array of embeddings, one for each content. Each embedding is a list
- of floats. The dimensionality of the embedding is model-specific; you
- can check model metadata using /models/{model_id}.
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/EmbeddingsResponse'
- '400':
- $ref: '#/components/responses/BadRequest400'
- '429':
- $ref: >-
- #/components/responses/TooManyRequests429
- '500':
- $ref: >-
- #/components/responses/InternalServerError500
- default:
- $ref: '#/components/responses/DefaultError'
- tags:
- - Inference
- description: >-
- Generate embeddings for content pieces using the specified model.
- parameters: []
- requestBody:
- content:
- application/json:
- schema:
- $ref: '#/components/schemas/EmbeddingsRequest'
- required: true
/v1/eval/benchmarks/{benchmark_id}/evaluations:
post:
responses:
@@ -7173,72 +7140,6 @@ components:
title: OpenAIDeleteResponseObject
description: >-
Response object confirming deletion of an OpenAI response.
- EmbeddingsRequest:
- type: object
- properties:
- model_id:
- type: string
- description: >-
- The identifier of the model to use. The model must be an embedding model
- registered with Llama Stack and available via the /models endpoint.
- contents:
- oneOf:
- - type: array
- items:
- type: string
- - type: array
- items:
- $ref: '#/components/schemas/InterleavedContentItem'
- description: >-
- List of contents to generate embeddings for. Each content can be a string
- or an InterleavedContentItem (and hence can be multimodal). The behavior
- depends on the model and provider. Some models may only support text.
- text_truncation:
- type: string
- enum:
- - none
- - start
- - end
- description: >-
- (Optional) Config for how to truncate text for embedding when text is
- longer than the model's max sequence length.
- output_dimension:
- type: integer
- description: >-
- (Optional) Output dimensionality for the embeddings. Only supported by
- Matryoshka models.
- task_type:
- type: string
- enum:
- - query
- - document
- description: >-
- (Optional) How is the embedding being used? This is only supported by
- asymmetric embedding models.
- additionalProperties: false
- required:
- - model_id
- - contents
- title: EmbeddingsRequest
- EmbeddingsResponse:
- type: object
- properties:
- embeddings:
- type: array
- items:
- type: array
- items:
- type: number
- description: >-
- List of embedding vectors, one per input content. Each embedding is a
- list of floats. The dimensionality of the embedding is model-specific;
- you can check model metadata using /models/{model_id}
- additionalProperties: false
- required:
- - embeddings
- title: EmbeddingsResponse
- description: >-
- Response containing generated embeddings.
AgentCandidate:
type: object
properties:
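Note on the surviving API surface: with `/v1/inference/embeddings` and its `EmbeddingsRequest`/`EmbeddingsResponse` schemas gone from both spec files, embeddings are served only through the OpenAI-compatible route that remains in the spec. A minimal client sketch follows; the base URL, port, and model id are assumptions for a locally running stack, so verify the exact path against the remaining llama-stack-spec.yaml.

```python
# Hedged sketch: calling the OpenAI-compatible embeddings route that remains in the spec.
# The base URL (default llama-stack port 8321 is assumed) and model id are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
resp = client.embeddings.create(
    model="all-MiniLM-L6-v2",   # any embedding model registered with the stack
    input=["hello", "world"],   # str | list[str]
)
vectors = [d.embedding for d in resp.data]  # list[list[float]], one vector per input
```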
diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py
index bd4737ca7..513862cd4 100644
--- a/llama_stack/apis/inference/inference.py
+++ b/llama_stack/apis/inference/inference.py
@@ -17,7 +17,7 @@ from typing import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
-from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry import MetricResponseMixin
@@ -1135,26 +1135,6 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Batch chat completion is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
- @webmethod(route="/inference/embeddings", method="POST")
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- """Generate embeddings for content pieces using the specified model.
-
- :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
- :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
- :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
- :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
- :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
- :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
- """
- ...
-
@webmethod(route="/inference/rerank", method="POST", experimental=True)
async def rerank(
self,
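For in-process callers of the `InferenceProvider` protocol, the retained `openai_embeddings` method is the replacement for the removed `embeddings` method. A hedged sketch, assuming the `model`/`input` parameter names and the `OpenAIEmbeddingsResponse.data` shape implied by the imports kept elsewhere in this diff (`OpenAIEmbeddingsResponse`, `OpenAIEmbeddingData`):

```python
# Hedged sketch, not the canonical signature: adapt to the actual openai_embeddings
# definition in inference.py. One embedding vector is expected per input string.
async def embed_texts(inference, model_id: str, texts: list[str]) -> list[list[float]]:
    response = await inference.openai_embeddings(model=model_id, input=texts)
    return [item.embedding for item in response.data]
```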
diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py
index 4b66601bb..2ee49a027 100644
--- a/llama_stack/core/routers/inference.py
+++ b/llama_stack/core/routers/inference.py
@@ -16,7 +16,6 @@ from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
from llama_stack.apis.inference import (
@@ -28,8 +27,6 @@ from llama_stack.apis.inference import (
CompletionMessage,
CompletionResponse,
CompletionResponseStreamChunk,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig,
@@ -50,7 +47,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
SamplingParams,
StopReason,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -347,25 +343,6 @@ class InferenceRouter(Inference):
provider = await self.routing_table.get_provider_impl(model_id)
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- logger.debug(f"InferenceRouter.embeddings: {model_id}")
- await self._get_model(model_id, ModelType.embedding)
- provider = await self.routing_table.get_provider_impl(model_id)
- return await provider.embeddings(
- model_id=model_id,
- contents=contents,
- text_truncation=text_truncation,
- output_dimension=output_dimension,
- task_type=task_type,
- )
-
async def openai_completion(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 63ea196f6..6947941f9 100644
--- a/llama_stack/providers/remote/inference/bedrock/bedrock.py
+++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -11,21 +11,17 @@ from botocore.client import BaseClient
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -47,8 +43,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
- content_has_media,
- interleaved_content_as_str,
)
from .models import MODEL_ENTRIES
@@ -176,31 +170,6 @@ class BedrockInferenceAdapter(
),
}
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- model = await self.model_store.get_model(model_id)
- embeddings = []
- for content in contents:
- assert not content_has_media(content), "Bedrock does not support media for embeddings"
- input_text = interleaved_content_as_str(content)
- input_body = {"inputText": input_text}
- body = json.dumps(input_body)
- response = self.client.invoke_model(
- body=body,
- modelId=model.provider_resource_id,
- accept="application/json",
- contentType="application/json",
- )
- response_body = json.loads(response.get("body").read())
- embeddings.append(response_body.get("embedding"))
- return EmbeddingsResponse(embeddings=embeddings)
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py
index 5e07c49ee..e22bfc444 100644
--- a/llama_stack/providers/remote/inference/cerebras/cerebras.py
+++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py
@@ -10,21 +10,17 @@ from cerebras.cloud.sdk import AsyncCerebras
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionRequest,
CompletionResponse,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- raise NotImplementedError()
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index 34ee59212..57acae293 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -10,20 +10,16 @@ from openai import OpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
OpenAIEmbeddingsResponse,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -147,16 +143,6 @@ class DatabricksInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- raise NotImplementedError()
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py
index e907e8ec6..7d00fd337 100644
--- a/llama_stack/providers/remote/inference/fireworks/fireworks.py
+++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py
@@ -12,15 +12,12 @@ from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -33,7 +30,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -57,8 +53,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
- content_has_media,
- interleaved_content_as_str,
request_has_media,
)
@@ -261,31 +255,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
return params
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- model = await self.model_store.get_model(model_id)
-
- kwargs = {}
- if model.metadata.get("embedding_dimension"):
- kwargs["dimensions"] = model.metadata.get("embedding_dimension")
- assert all(not content_has_media(content) for content in contents), (
- "Fireworks does not support media for embeddings"
- )
- response = self._get_client().embeddings.create(
- model=model.provider_resource_id,
- input=[interleaved_content_as_str(content) for content in contents],
- **kwargs,
- )
-
- embeddings = [data.embedding for data in response.data]
- return EmbeddingsResponse(embeddings=embeddings)
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index a5475bc92..c83c0afd2 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -11,8 +11,6 @@ from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
- TextContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
@@ -21,8 +19,6 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -31,7 +27,6 @@ from llama_stack.apis.inference import (
OpenAIEmbeddingUsage,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
)
@@ -155,60 +150,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
# we pass n=1 to get only one completion
return convert_openai_completion_choice(response.choices[0])
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- if any(content_has_media(content) for content in contents):
- raise NotImplementedError("Media is not supported")
-
- #
- # Llama Stack: contents = list[str] | list[InterleavedContentItem]
- # ->
- # OpenAI: input = str | list[str]
- #
- # we can ignore str and always pass list[str] to OpenAI
- #
- flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
- input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
- provider_model_id = await self._get_provider_model_id(model_id)
-
- extra_body = {}
-
- if text_truncation is not None:
- text_truncation_options = {
- TextTruncation.none: "NONE",
- TextTruncation.end: "END",
- TextTruncation.start: "START",
- }
- extra_body["truncate"] = text_truncation_options[text_truncation]
-
- if output_dimension is not None:
- extra_body["dimensions"] = output_dimension
-
- if task_type is not None:
- task_type_options = {
- EmbeddingTaskType.document: "passage",
- EmbeddingTaskType.query: "query",
- }
- extra_body["input_type"] = task_type_options[task_type]
-
- response = await self.client.embeddings.create(
- model=provider_model_id,
- input=input,
- extra_body=extra_body,
- )
- #
- # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
- # ->
- # Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
- #
- return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index fcaf5ee92..187cbc758 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -17,7 +17,6 @@ from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
ImageContentItem,
InterleavedContent,
- InterleavedContentItem,
TextContentItem,
)
from llama_stack.apis.common.errors import UnsupportedModelError
@@ -28,8 +27,6 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
- EmbeddingsResponse,
- EmbeddingTaskType,
GrammarResponseFormat,
InferenceProvider,
JsonSchemaResponseFormat,
@@ -44,7 +41,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -76,9 +72,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
- content_has_media,
convert_image_content_to_url,
- interleaved_content_as_str,
localize_image_content,
request_has_media,
)
@@ -394,27 +388,6 @@ class OllamaInferenceAdapter(
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- model = await self._get_model(model_id)
-
- assert all(not content_has_media(content) for content in contents), (
- "Ollama does not support media for embeddings"
- )
- response = await self.client.embed(
- model=model.provider_resource_id,
- input=[interleaved_content_as_str(content) for content in contents],
- )
- embeddings = response["embeddings"]
-
- return EmbeddingsResponse(embeddings=embeddings)
-
async def register_model(self, model: Model) -> Model:
try:
model = await self.register_helper.register_model(model)
diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py
index 2f1cd40f2..7b32cc948 100644
--- a/llama_stack/providers/remote/inference/passthrough/passthrough.py
+++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py
@@ -14,8 +14,6 @@ from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -27,7 +25,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -190,25 +187,6 @@ class PassthroughInferenceAdapter(Inference):
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
yield chunk
- async def embeddings(
- self,
- model_id: str,
- contents: list[InterleavedContent],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- client = self._get_client()
- model = await self.model_store.get_model(model_id)
-
- return await client.inference.embeddings(
- model_id=model.provider_resource_id,
- contents=contents,
- text_truncation=text_truncation,
- output_dimension=output_dimension,
- task_type=task_type,
- )
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py
index ff2fe6401..82252b04d 100644
--- a/llama_stack/providers/remote/inference/runpod/runpod.py
+++ b/llama_stack/providers/remote/inference/runpod/runpod.py
@@ -136,16 +136,6 @@ class RunpodInferenceAdapter(
**get_sampling_options(request.sampling_params),
}
- async def embeddings(
- self,
- model: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- raise NotImplementedError()
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py
index 97c72d14c..430353440 100644
--- a/llama_stack/providers/remote/inference/tgi/tgi.py
+++ b/llama_stack/providers/remote/inference/tgi/tgi.py
@@ -11,14 +11,11 @@ from huggingface_hub import AsyncInferenceClient, HfApi
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -26,7 +23,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -282,16 +278,6 @@ class _HfAdapter(
**self._build_options(request.sampling_params, request.response_format),
)
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- raise NotImplementedError()
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index 54c76607f..a64a25725 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -12,14 +12,11 @@ from together import AsyncTogether
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
- EmbeddingsResponse,
- EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
@@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
ResponseFormat,
ResponseFormatType,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -53,8 +49,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
- content_has_media,
- interleaved_content_as_str,
request_has_media,
)
@@ -235,26 +229,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
logger.debug(f"params to together: {params}")
return params
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- model = await self.model_store.get_model(model_id)
- assert all(not content_has_media(content) for content in contents), (
- "Together does not support media for embeddings"
- )
- client = self._get_client()
- r = await client.embeddings.create(
- model=model.provider_resource_id,
- input=[interleaved_content_as_str(content) for content in contents],
- )
- embeddings = [item.embedding for item in r.data]
- return EmbeddingsResponse(embeddings=embeddings)
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 9e9a80ca5..868f7ce0f 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -15,7 +15,6 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
TextDelta,
ToolCallDelta,
ToolCallParseStatus,
@@ -30,8 +29,6 @@ from llama_stack.apis.inference import (
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
- EmbeddingsResponse,
- EmbeddingTaskType,
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
@@ -47,7 +44,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -78,8 +74,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
)
from llama_stack.providers.utils.inference.prompt_adapter import (
completion_request_to_prompt,
- content_has_media,
- interleaved_content_as_str,
request_has_media,
)
@@ -535,32 +529,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
**options,
}
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- self._lazy_initialize_client()
- assert self.client is not None
- model = await self._get_model(model_id)
-
- kwargs = {}
- assert model.model_type == ModelType.embedding
- assert model.metadata.get("embedding_dimension")
- kwargs["dimensions"] = model.metadata.get("embedding_dimension")
- assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
- response = await self.client.embeddings.create(
- model=model.provider_resource_id,
- input=[interleaved_content_as_str(content) for content in contents],
- **kwargs,
- )
-
- embeddings = [data.embedding for data in response.data]
- return EmbeddingsResponse(embeddings=embeddings)
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index 78161d1cb..522e504fb 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -11,13 +11,11 @@ from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
-from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
+from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
- EmbeddingsResponse,
- EmbeddingTaskType,
GreedySamplingStrategy,
Inference,
LogProbConfig,
@@ -30,7 +28,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -249,16 +246,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
}
return params
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- raise NotImplementedError("embedding is not supported for watsonx")
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 65ba2854b..d1b2be332 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -14,16 +14,11 @@ if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
from llama_stack.apis.inference import (
- EmbeddingsResponse,
- EmbeddingTaskType,
- InterleavedContentItem,
ModelStore,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
- TextTruncation,
)
-from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
EMBEDDING_MODELS = {}
@@ -34,21 +29,6 @@ log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- model = await self.model_store.get_model(model_id)
- embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
- embeddings = embedding_model.encode(
- [interleaved_content_as_str(content) for content in contents], show_progress_bar=False
- )
- return EmbeddingsResponse(embeddings=embeddings)
-
async def openai_embeddings(
self,
model: str,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 9bd43e4c9..26a1bc3a2 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -11,14 +11,11 @@ import litellm
from llama_stack.apis.common.content_types import (
InterleavedContent,
- InterleavedContentItem,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
- EmbeddingsResponse,
- EmbeddingTaskType,
InferenceProvider,
JsonSchemaResponseFormat,
LogProbConfig,
@@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatParam,
ResponseFormat,
SamplingParams,
- TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
@@ -50,9 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
prepare_openai_completion_params,
)
-from llama_stack.providers.utils.inference.prompt_adapter import (
- interleaved_content_as_str,
-)
logger = get_logger(name=__name__, category="providers::utils")
@@ -269,24 +262,6 @@ class LiteLLMOpenAIMixin(
)
return api_key
- async def embeddings(
- self,
- model_id: str,
- contents: list[str] | list[InterleavedContentItem],
- text_truncation: TextTruncation | None = TextTruncation.none,
- output_dimension: int | None = None,
- task_type: EmbeddingTaskType | None = None,
- ) -> EmbeddingsResponse:
- model = await self.model_store.get_model(model_id)
-
- response = litellm.embedding(
- model=self.get_litellm_model_name(model.provider_resource_id),
- input=[interleaved_content_as_str(content) for content in contents],
- )
-
- embeddings = [data["embedding"] for data in response["data"]]
- return EmbeddingsResponse(embeddings=embeddings)
-
async def openai_embeddings(
self,
model: str,
diff --git a/tests/integration/inference/test_embedding.py b/tests/integration/inference/test_embedding.py
deleted file mode 100644
index e592a6b14..000000000
--- a/tests/integration/inference/test_embedding.py
+++ /dev/null
@@ -1,303 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-
-#
-# Test plan:
-#
-# Types of input:
-# - array of a string
-# - array of a image (ImageContentItem, either URL or base64 string)
-# - array of a text (TextContentItem)
-# Types of output:
-# - list of list of floats
-# Params:
-# - text_truncation
-# - absent w/ long text -> error
-# - none w/ long text -> error
-# - absent w/ short text -> ok
-# - none w/ short text -> ok
-# - end w/ long text -> ok
-# - end w/ short text -> ok
-# - start w/ long text -> ok
-# - start w/ short text -> ok
-# - output_dimension
-# - response dimension matches
-# - task_type, only for asymmetric models
-# - query embedding != passage embedding
-# Negative:
-# - long string
-# - long text
-#
-# Todo:
-# - negative tests
-# - empty
-# - empty list
-# - empty string
-# - empty text
-# - empty image
-# - long
-# - large image
-# - appropriate combinations
-# - batch size
-# - many inputs
-# - invalid
-# - invalid URL
-# - invalid base64
-#
-# Notes:
-# - use llama_stack_client fixture
-# - use pytest.mark.parametrize when possible
-# - no accuracy tests: only check the type of output, not the content
-#
-
-import pytest
-from llama_stack_client import BadRequestError as LlamaStackBadRequestError
-from llama_stack_client.types import EmbeddingsResponse
-from llama_stack_client.types.shared.interleaved_content import (
- ImageContentItem,
- ImageContentItemImage,
- ImageContentItemImageURL,
- TextContentItem,
-)
-from openai import BadRequestError as OpenAIBadRequestError
-
-from llama_stack.core.library_client import LlamaStackAsLibraryClient
-
-DUMMY_STRING = "hello"
-DUMMY_STRING2 = "world"
-DUMMY_LONG_STRING = "NVDA " * 10240
-DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
-DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
-DUMMY_LONG_TEXT = TextContentItem(text=DUMMY_LONG_STRING, type="text")
-# TODO(mf): add a real image URL and base64 string
-DUMMY_IMAGE_URL = ImageContentItem(
- image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
-)
-DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
-SUPPORTED_PROVIDERS = {"remote::nvidia"}
-MODELS_SUPPORTING_MEDIA = {}
-MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
-MODELS_REQUIRING_TASK_TYPE = {
- "nvidia/llama-3.2-nv-embedqa-1b-v2",
- "nvidia/nv-embedqa-e5-v5",
- "nvidia/nv-embedqa-mistral-7b-v2",
- "snowflake/arctic-embed-l",
-}
-MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
-
-
-def default_task_type(model_id):
- """
- Some models require a task type parameter. This provides a default value for
- testing those models.
- """
- if model_id in MODELS_REQUIRING_TASK_TYPE:
- return {"task_type": "query"}
- return {}
-
-
-@pytest.mark.parametrize(
- "contents",
- [
- [DUMMY_STRING, DUMMY_STRING2],
- [DUMMY_TEXT, DUMMY_TEXT2],
- ],
- ids=[
- "list[string]",
- "list[text]",
- ],
-)
-def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
- )
- assert isinstance(response, EmbeddingsResponse)
- assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
- assert isinstance(response.embeddings[0], list)
- assert isinstance(response.embeddings[0][0], float)
-
-
-@pytest.mark.parametrize(
- "contents",
- [
- [DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64],
- [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT],
- ],
- ids=[
- "list[url,base64]",
- "list[url,string,base64,text]",
- ],
-)
-def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
- pytest.xfail(f"{embedding_model_id} doesn't support media")
- response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
- )
- assert isinstance(response, EmbeddingsResponse)
- assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
- assert isinstance(response.embeddings[0], list)
- assert isinstance(response.embeddings[0][0], float)
-
-
-@pytest.mark.parametrize(
- "text_truncation",
- [
- "end",
- "start",
- ],
-)
-@pytest.mark.parametrize(
- "contents",
- [
- [DUMMY_LONG_TEXT],
- [DUMMY_STRING],
- ],
- ids=[
- "long",
- "short",
- ],
-)
-def test_embedding_truncation(
- llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
-):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id,
- contents=contents,
- text_truncation=text_truncation,
- **default_task_type(embedding_model_id),
- )
- assert isinstance(response, EmbeddingsResponse)
- assert len(response.embeddings) == 1
- assert isinstance(response.embeddings[0], list)
- assert isinstance(response.embeddings[0][0], float)
-
-
-@pytest.mark.parametrize(
- "text_truncation",
- [
- None,
- "none",
- ],
-)
-@pytest.mark.parametrize(
- "contents",
- [
- [DUMMY_LONG_TEXT],
- [DUMMY_LONG_STRING],
- ],
- ids=[
- "long-text",
- "long-str",
- ],
-)
-def test_embedding_truncation_error(
- llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
-):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- # Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
- # While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
- error_type = (
- OpenAIBadRequestError
- if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
- else LlamaStackBadRequestError
- )
- with pytest.raises(error_type):
- llama_stack_client.inference.embeddings(
- model_id=embedding_model_id,
- contents=[DUMMY_LONG_TEXT],
- text_truncation=text_truncation,
- **default_task_type(embedding_model_id),
- )
-
-
-def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
- pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
- base_response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
- )
- test_response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id,
- contents=[DUMMY_STRING],
- **default_task_type(embedding_model_id),
- output_dimension=32,
- )
- assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
- assert len(test_response.embeddings[0]) == 32
-
-
-def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
- pytest.xfail(f"{embedding_model_id} doesn't support task_type")
- query_embedding = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
- )
- document_embedding = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
- )
- assert query_embedding.embeddings != document_embedding.embeddings
-
-
-@pytest.mark.parametrize(
- "text_truncation",
- [
- None,
- "none",
- "end",
- "start",
- ],
-)
-def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation, inference_provider_type):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id,
- contents=[DUMMY_STRING],
- text_truncation=text_truncation,
- **default_task_type(embedding_model_id),
- )
- assert isinstance(response, EmbeddingsResponse)
- assert len(response.embeddings) == 1
- assert isinstance(response.embeddings[0], list)
- assert isinstance(response.embeddings[0][0], float)
-
-
-@pytest.mark.parametrize(
- "text_truncation",
- [
- "NONE",
- "END",
- "START",
- "left",
- "right",
- ],
-)
-def test_embedding_text_truncation_error(
- llama_stack_client, embedding_model_id, text_truncation, inference_provider_type
-):
- if inference_provider_type not in SUPPORTED_PROVIDERS:
- pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
- with pytest.raises(error_type):
- llama_stack_client.inference.embeddings(
- model_id=embedding_model_id,
- contents=[DUMMY_STRING],
- text_truncation=text_truncation,
- **default_task_type(embedding_model_id),
- )
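If equivalent integration coverage is wanted after deleting test_embedding.py, the removed output-dimension check can be approximated against the OpenAI-compatible route. A hedged sketch; the fixture, base URL, and model id are assumptions, and `dimensions=` only applies to models that support reduced output dimensionality (Matryoshka-style models):

```python
# Hedged sketch of replacement coverage; not part of this PR.
import pytest
from openai import OpenAI

EMBEDDING_MODEL = "nvidia/llama-3.2-nv-embedqa-1b-v2"  # assumed example model id

@pytest.fixture
def openai_client():
    # Assumes a locally running stack exposing the OpenAI-compatible API.
    return OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

def test_openai_embedding_output_dimension(openai_client):
    base = openai_client.embeddings.create(model=EMBEDDING_MODEL, input=["hello"])
    small = openai_client.embeddings.create(model=EMBEDDING_MODEL, input=["hello"], dimensions=32)
    assert len(small.data[0].embedding) == 32
    assert len(base.data[0].embedding) != len(small.data[0].embedding)
```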
diff --git a/tests/unit/providers/vector_io/test_faiss.py b/tests/unit/providers/vector_io/test_faiss.py
index 90108d7a0..9ee5c82f4 100644
--- a/tests/unit/providers/vector_io/test_faiss.py
+++ b/tests/unit/providers/vector_io/test_faiss.py
@@ -5,13 +5,12 @@
# the root directory of this source tree.
import asyncio
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.files import Files
-from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
@@ -70,13 +69,6 @@ def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
return mock_vector_db
-@pytest.fixture
-def mock_inference_api(sample_embeddings):
- mock_api = MagicMock(spec=Inference)
- mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
- return mock_api
-
-
@pytest.fixture
def mock_files_api():
mock_api = MagicMock(spec=Files)
@@ -96,22 +88,6 @@ async def faiss_index(embedding_dimension):
yield index
-@pytest.fixture
-async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
- # Create the adapter
- adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
-
- # Create a mock KVStore
- mock_kvstore = MagicMock()
- mock_kvstore.values_in_range = AsyncMock(return_value=[])
-
- # Patch the initialize method to avoid the kvstore_impl call
- with patch.object(FaissVectorIOAdapter, "initialize"):
- # Set the kvstore directly
- adapter.kvstore = mock_kvstore
- yield adapter
-
-
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
):