mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
chore(api): remove deprecated embeddings impls
This commit is contained in:
parent
478b4ff1e6
commit
30998fd1ff
20 changed files with 3 additions and 927 deletions
117
docs/_static/llama-stack-spec.html
vendored
117
docs/_static/llama-stack-spec.html
vendored
|
@ -901,49 +901,6 @@
|
|||
]
|
||||
}
|
||||
},
|
||||
"/v1/inference/embeddings": {
|
||||
"post": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/EmbeddingsResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
},
|
||||
"429": {
|
||||
"$ref": "#/components/responses/TooManyRequests429"
|
||||
},
|
||||
"500": {
|
||||
"$ref": "#/components/responses/InternalServerError500"
|
||||
},
|
||||
"default": {
|
||||
"$ref": "#/components/responses/DefaultError"
|
||||
}
|
||||
},
|
||||
"tags": [
|
||||
"Inference"
|
||||
],
|
||||
"description": "Generate embeddings for content pieces using the specified model.",
|
||||
"parameters": [],
|
||||
"requestBody": {
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/EmbeddingsRequest"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/eval/benchmarks/{benchmark_id}/evaluations": {
|
||||
"post": {
|
||||
"responses": {
|
||||
|
@ -9698,80 +9655,6 @@
|
|||
"title": "OpenAIDeleteResponseObject",
|
||||
"description": "Response object confirming deletion of an OpenAI response."
|
||||
},
|
||||
"EmbeddingsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"model_id": {
|
||||
"type": "string",
|
||||
"description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint."
|
||||
},
|
||||
"contents": {
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/InterleavedContentItem"
|
||||
}
|
||||
}
|
||||
],
|
||||
"description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text."
|
||||
},
|
||||
"text_truncation": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"none",
|
||||
"start",
|
||||
"end"
|
||||
],
|
||||
"description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length."
|
||||
},
|
||||
"output_dimension": {
|
||||
"type": "integer",
|
||||
"description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models."
|
||||
},
|
||||
"task_type": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"query",
|
||||
"document"
|
||||
],
|
||||
"description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"model_id",
|
||||
"contents"
|
||||
],
|
||||
"title": "EmbeddingsRequest"
|
||||
},
|
||||
"EmbeddingsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"embeddings": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"description": "List of embedding vectors, one per input content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"embeddings"
|
||||
],
|
||||
"title": "EmbeddingsResponse",
|
||||
"description": "Response containing generated embeddings."
|
||||
},
|
||||
"AgentCandidate": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
99
docs/_static/llama-stack-spec.yaml
vendored
99
docs/_static/llama-stack-spec.yaml
vendored
|
@ -616,39 +616,6 @@ paths:
|
|||
required: true
|
||||
schema:
|
||||
type: string
|
||||
/v1/inference/embeddings:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
An array of embeddings, one for each content. Each embedding is a list
|
||||
of floats. The dimensionality of the embedding is model-specific; you
|
||||
can check model metadata using /models/{model_id}.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/EmbeddingsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Inference
|
||||
description: >-
|
||||
Generate embeddings for content pieces using the specified model.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/EmbeddingsRequest'
|
||||
required: true
|
||||
/v1/eval/benchmarks/{benchmark_id}/evaluations:
|
||||
post:
|
||||
responses:
|
||||
|
@ -7173,72 +7140,6 @@ components:
|
|||
title: OpenAIDeleteResponseObject
|
||||
description: >-
|
||||
Response object confirming deletion of an OpenAI response.
|
||||
EmbeddingsRequest:
|
||||
type: object
|
||||
properties:
|
||||
model_id:
|
||||
type: string
|
||||
description: >-
|
||||
The identifier of the model to use. The model must be an embedding model
|
||||
registered with Llama Stack and available via the /models endpoint.
|
||||
contents:
|
||||
oneOf:
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
- type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/InterleavedContentItem'
|
||||
description: >-
|
||||
List of contents to generate embeddings for. Each content can be a string
|
||||
or an InterleavedContentItem (and hence can be multimodal). The behavior
|
||||
depends on the model and provider. Some models may only support text.
|
||||
text_truncation:
|
||||
type: string
|
||||
enum:
|
||||
- none
|
||||
- start
|
||||
- end
|
||||
description: >-
|
||||
(Optional) Config for how to truncate text for embedding when text is
|
||||
longer than the model's max sequence length.
|
||||
output_dimension:
|
||||
type: integer
|
||||
description: >-
|
||||
(Optional) Output dimensionality for the embeddings. Only supported by
|
||||
Matryoshka models.
|
||||
task_type:
|
||||
type: string
|
||||
enum:
|
||||
- query
|
||||
- document
|
||||
description: >-
|
||||
(Optional) How is the embedding being used? This is only supported by
|
||||
asymmetric embedding models.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- model_id
|
||||
- contents
|
||||
title: EmbeddingsRequest
|
||||
EmbeddingsResponse:
|
||||
type: object
|
||||
properties:
|
||||
embeddings:
|
||||
type: array
|
||||
items:
|
||||
type: array
|
||||
items:
|
||||
type: number
|
||||
description: >-
|
||||
List of embedding vectors, one per input content. Each embedding is a
|
||||
list of floats. The dimensionality of the embedding is model-specific;
|
||||
you can check model metadata using /models/{model_id}
|
||||
additionalProperties: false
|
||||
required:
|
||||
- embeddings
|
||||
title: EmbeddingsResponse
|
||||
description: >-
|
||||
Response containing generated embeddings.
|
||||
AgentCandidate:
|
||||
type: object
|
||||
properties:
|
||||
|
|
|
@ -17,7 +17,7 @@ from typing import (
|
|||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.telemetry import MetricResponseMixin
|
||||
|
@ -1135,26 +1135,6 @@ class InferenceProvider(Protocol):
|
|||
raise NotImplementedError("Batch chat completion is not implemented")
|
||||
return # this is so mypy's safe-super rule will consider the method concrete
|
||||
|
||||
@webmethod(route="/inference/embeddings", method="POST")
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
"""Generate embeddings for content pieces using the specified model.
|
||||
|
||||
:param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint.
|
||||
:param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text.
|
||||
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
|
||||
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
|
||||
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
|
||||
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/inference/rerank", method="POST", experimental=True)
|
||||
async def rerank(
|
||||
self,
|
||||
|
|
|
@ -16,7 +16,6 @@ from pydantic import Field, TypeAdapter
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
|
@ -28,8 +27,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
LogProbConfig,
|
||||
|
@ -50,7 +47,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -347,25 +343,6 @@ class InferenceRouter(Inference):
|
|||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
logger.debug(f"InferenceRouter.embeddings: {model_id}")
|
||||
await self._get_model(model_id, ModelType.embedding)
|
||||
provider = await self.routing_table.get_provider_impl(model_id)
|
||||
return await provider.embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,21 +11,17 @@ from botocore.client import BaseClient
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -47,8 +43,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
@ -176,31 +170,6 @@ class BedrockInferenceAdapter(
|
|||
),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embeddings = []
|
||||
for content in contents:
|
||||
assert not content_has_media(content), "Bedrock does not support media for embeddings"
|
||||
input_text = interleaved_content_as_str(content)
|
||||
input_body = {"inputText": input_text}
|
||||
body = json.dumps(input_body)
|
||||
response = self.client.invoke_model(
|
||||
body=body,
|
||||
modelId=model.provider_resource_id,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
embeddings.append(response_body.get("embedding"))
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -10,21 +10,17 @@ from cerebras.cloud.sdk import AsyncCerebras
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -187,16 +183,6 @@ class CerebrasInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -10,20 +10,16 @@ from openai import OpenAI
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
OpenAIEmbeddingsResponse,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -147,16 +143,6 @@ class DatabricksInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -12,15 +12,12 @@ from openai import AsyncOpenAI
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -33,7 +30,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -57,8 +53,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -261,31 +255,6 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
|||
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
if model.metadata.get("embedding_dimension"):
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Fireworks does not support media for embeddings"
|
||||
)
|
||||
response = self._get_client().embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,8 +11,6 @@ from openai import NOT_GIVEN, APIConnectionError
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
|
@ -21,8 +19,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -31,7 +27,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIEmbeddingUsage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
|
@ -155,60 +150,6 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
|||
# we pass n=1 to get only one completion
|
||||
return convert_openai_completion_choice(response.choices[0])
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
if any(content_has_media(content) for content in contents):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
#
|
||||
# Llama Stack: contents = list[str] | list[InterleavedContentItem]
|
||||
# ->
|
||||
# OpenAI: input = str | list[str]
|
||||
#
|
||||
# we can ignore str and always pass list[str] to OpenAI
|
||||
#
|
||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||
provider_model_id = await self._get_provider_model_id(model_id)
|
||||
|
||||
extra_body = {}
|
||||
|
||||
if text_truncation is not None:
|
||||
text_truncation_options = {
|
||||
TextTruncation.none: "NONE",
|
||||
TextTruncation.end: "END",
|
||||
TextTruncation.start: "START",
|
||||
}
|
||||
extra_body["truncate"] = text_truncation_options[text_truncation]
|
||||
|
||||
if output_dimension is not None:
|
||||
extra_body["dimensions"] = output_dimension
|
||||
|
||||
if task_type is not None:
|
||||
task_type_options = {
|
||||
EmbeddingTaskType.document: "passage",
|
||||
EmbeddingTaskType.query: "query",
|
||||
}
|
||||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
response = await self.client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
#
|
||||
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
|
||||
# ->
|
||||
# Llama Stack: EmbeddingsResponse(embeddings=list[list[float]])
|
||||
#
|
||||
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -17,7 +17,6 @@ from openai import AsyncOpenAI
|
|||
from llama_stack.apis.common.content_types import (
|
||||
ImageContentItem,
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextContentItem,
|
||||
)
|
||||
from llama_stack.apis.common.errors import UnsupportedModelError
|
||||
|
@ -28,8 +27,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
|
@ -44,7 +41,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -76,9 +72,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
convert_image_content_to_url,
|
||||
interleaved_content_as_str,
|
||||
localize_image_content,
|
||||
request_has_media,
|
||||
)
|
||||
|
@ -394,27 +388,6 @@ class OllamaInferenceAdapter(
|
|||
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Ollama does not support media for embeddings"
|
||||
)
|
||||
response = await self.client.embed(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = response["embeddings"]
|
||||
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
try:
|
||||
model = await self.register_helper.register_model(model)
|
||||
|
|
|
@ -14,8 +14,6 @@ from llama_stack.apis.inference import (
|
|||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -27,7 +25,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -190,25 +187,6 @@ class PassthroughInferenceAdapter(Inference):
|
|||
chunk = convert_to_pydantic(ChatCompletionResponseStreamChunk, chunk)
|
||||
yield chunk
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[InterleavedContent],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
client = self._get_client()
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
return await client.inference.embeddings(
|
||||
model_id=model.provider_resource_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
output_dimension=output_dimension,
|
||||
task_type=task_type,
|
||||
)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -136,16 +136,6 @@ class RunpodInferenceAdapter(
|
|||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,14 +11,11 @@ from huggingface_hub import AsyncInferenceClient, HfApi
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -26,7 +23,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -282,16 +278,6 @@ class _HfAdapter(
|
|||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -12,14 +12,11 @@ from together import AsyncTogether
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
|
@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
|
|||
ResponseFormat,
|
||||
ResponseFormatType,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -53,8 +49,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_prompt,
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -235,26 +229,6 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
|||
logger.debug(f"params to together: {params}")
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
assert all(not content_has_media(content) for content in contents), (
|
||||
"Together does not support media for embeddings"
|
||||
)
|
||||
client = self._get_client()
|
||||
r = await client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
embeddings = [item.embedding for item in r.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -15,7 +15,6 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
ToolCallParseStatus,
|
||||
|
@ -30,8 +29,6 @@ from llama_stack.apis.inference import (
|
|||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GrammarResponseFormat,
|
||||
Inference,
|
||||
JsonSchemaResponseFormat,
|
||||
|
@ -47,7 +44,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -78,8 +74,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
completion_request_to_prompt,
|
||||
content_has_media,
|
||||
interleaved_content_as_str,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
|
@ -535,32 +529,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
**options,
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None
|
||||
model = await self._get_model(model_id)
|
||||
|
||||
kwargs = {}
|
||||
assert model.model_type == ModelType.embedding
|
||||
assert model.metadata.get("embedding_dimension")
|
||||
kwargs["dimensions"] = model.metadata.get("embedding_dimension")
|
||||
assert all(not content_has_media(content) for content in contents), "VLLM does not support media for embeddings"
|
||||
response = await self.client.embeddings.create(
|
||||
model=model.provider_resource_id,
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [data.embedding for data in response.data]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,13 +11,11 @@ from ibm_watson_machine_learning.foundation_models import Model
|
|||
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
from llama_stack.apis.common.content_types import InterleavedContent
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
GreedySamplingStrategy,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
|
@ -30,7 +28,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -249,16 +246,6 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
}
|
||||
return params
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError("embedding is not supported for watsonx")
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -14,16 +14,11 @@ if TYPE_CHECKING:
|
|||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
InterleavedContentItem,
|
||||
ModelStore,
|
||||
OpenAIEmbeddingData,
|
||||
OpenAIEmbeddingsResponse,
|
||||
OpenAIEmbeddingUsage,
|
||||
TextTruncation,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||
|
||||
EMBEDDING_MODELS = {}
|
||||
|
||||
|
@ -34,21 +29,6 @@ log = get_logger(name=__name__, category="providers::utils")
|
|||
class SentenceTransformerEmbeddingMixin:
|
||||
model_store: ModelStore
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
embeddings = embedding_model.encode(
|
||||
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
|
||||
)
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -11,14 +11,11 @@ import litellm
|
|||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
InterleavedContentItem,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingTaskType,
|
||||
InferenceProvider,
|
||||
JsonSchemaResponseFormat,
|
||||
LogProbConfig,
|
||||
|
@ -32,7 +29,6 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatParam,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
TextTruncation,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolDefinition,
|
||||
|
@ -50,9 +46,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
|
|||
get_sampling_options,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="providers::utils")
|
||||
|
||||
|
@ -269,24 +262,6 @@ class LiteLLMOpenAIMixin(
|
|||
)
|
||||
return api_key
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: list[str] | list[InterleavedContentItem],
|
||||
text_truncation: TextTruncation | None = TextTruncation.none,
|
||||
output_dimension: int | None = None,
|
||||
task_type: EmbeddingTaskType | None = None,
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.model_store.get_model(model_id)
|
||||
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model.provider_resource_id),
|
||||
input=[interleaved_content_as_str(content) for content in contents],
|
||||
)
|
||||
|
||||
embeddings = [data["embedding"] for data in response["data"]]
|
||||
return EmbeddingsResponse(embeddings=embeddings)
|
||||
|
||||
async def openai_embeddings(
|
||||
self,
|
||||
model: str,
|
||||
|
|
|
@ -1,303 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
#
|
||||
# Test plan:
|
||||
#
|
||||
# Types of input:
|
||||
# - array of a string
|
||||
# - array of a image (ImageContentItem, either URL or base64 string)
|
||||
# - array of a text (TextContentItem)
|
||||
# Types of output:
|
||||
# - list of list of floats
|
||||
# Params:
|
||||
# - text_truncation
|
||||
# - absent w/ long text -> error
|
||||
# - none w/ long text -> error
|
||||
# - absent w/ short text -> ok
|
||||
# - none w/ short text -> ok
|
||||
# - end w/ long text -> ok
|
||||
# - end w/ short text -> ok
|
||||
# - start w/ long text -> ok
|
||||
# - start w/ short text -> ok
|
||||
# - output_dimension
|
||||
# - response dimension matches
|
||||
# - task_type, only for asymmetric models
|
||||
# - query embedding != passage embedding
|
||||
# Negative:
|
||||
# - long string
|
||||
# - long text
|
||||
#
|
||||
# Todo:
|
||||
# - negative tests
|
||||
# - empty
|
||||
# - empty list
|
||||
# - empty string
|
||||
# - empty text
|
||||
# - empty image
|
||||
# - long
|
||||
# - large image
|
||||
# - appropriate combinations
|
||||
# - batch size
|
||||
# - many inputs
|
||||
# - invalid
|
||||
# - invalid URL
|
||||
# - invalid base64
|
||||
#
|
||||
# Notes:
|
||||
# - use llama_stack_client fixture
|
||||
# - use pytest.mark.parametrize when possible
|
||||
# - no accuracy tests: only check the type of output, not the content
|
||||
#
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types import EmbeddingsResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
from openai import BadRequestError as OpenAIBadRequestError
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
DUMMY_STRING = "hello"
|
||||
DUMMY_STRING2 = "world"
|
||||
DUMMY_LONG_STRING = "NVDA " * 10240
|
||||
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||
DUMMY_LONG_TEXT = TextContentItem(text=DUMMY_LONG_STRING, type="text")
|
||||
# TODO(mf): add a real image URL and base64 string
|
||||
DUMMY_IMAGE_URL = ImageContentItem(
|
||||
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||
)
|
||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||
SUPPORTED_PROVIDERS = {"remote::nvidia"}
|
||||
MODELS_SUPPORTING_MEDIA = {}
|
||||
MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
|
||||
MODELS_REQUIRING_TASK_TYPE = {
|
||||
"nvidia/llama-3.2-nv-embedqa-1b-v2",
|
||||
"nvidia/nv-embedqa-e5-v5",
|
||||
"nvidia/nv-embedqa-mistral-7b-v2",
|
||||
"snowflake/arctic-embed-l",
|
||||
}
|
||||
MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
|
||||
|
||||
|
||||
def default_task_type(model_id):
|
||||
"""
|
||||
Some models require a task type parameter. This provides a default value for
|
||||
testing those models.
|
||||
"""
|
||||
if model_id in MODELS_REQUIRING_TASK_TYPE:
|
||||
return {"task_type": "query"}
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"contents",
|
||||
[
|
||||
[DUMMY_STRING, DUMMY_STRING2],
|
||||
[DUMMY_TEXT, DUMMY_TEXT2],
|
||||
],
|
||||
ids=[
|
||||
"list[string]",
|
||||
"list[text]",
|
||||
],
|
||||
)
|
||||
def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
|
||||
)
|
||||
assert isinstance(response, EmbeddingsResponse)
|
||||
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
assert isinstance(response.embeddings[0][0], float)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"contents",
|
||||
[
|
||||
[DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64],
|
||||
[DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT],
|
||||
],
|
||||
ids=[
|
||||
"list[url,base64]",
|
||||
"list[url,string,base64,text]",
|
||||
],
|
||||
)
|
||||
def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
|
||||
pytest.xfail(f"{embedding_model_id} doesn't support media")
|
||||
response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
|
||||
)
|
||||
assert isinstance(response, EmbeddingsResponse)
|
||||
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
assert isinstance(response.embeddings[0][0], float)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_truncation",
|
||||
[
|
||||
"end",
|
||||
"start",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"contents",
|
||||
[
|
||||
[DUMMY_LONG_TEXT],
|
||||
[DUMMY_STRING],
|
||||
],
|
||||
ids=[
|
||||
"long",
|
||||
"short",
|
||||
],
|
||||
)
|
||||
def test_embedding_truncation(
|
||||
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=contents,
|
||||
text_truncation=text_truncation,
|
||||
**default_task_type(embedding_model_id),
|
||||
)
|
||||
assert isinstance(response, EmbeddingsResponse)
|
||||
assert len(response.embeddings) == 1
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
assert isinstance(response.embeddings[0][0], float)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_truncation",
|
||||
[
|
||||
None,
|
||||
"none",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"contents",
|
||||
[
|
||||
[DUMMY_LONG_TEXT],
|
||||
[DUMMY_LONG_STRING],
|
||||
],
|
||||
ids=[
|
||||
"long-text",
|
||||
"long-str",
|
||||
],
|
||||
)
|
||||
def test_embedding_truncation_error(
|
||||
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
# Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
|
||||
# While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
|
||||
error_type = (
|
||||
OpenAIBadRequestError
|
||||
if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
|
||||
else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_LONG_TEXT],
|
||||
text_truncation=text_truncation,
|
||||
**default_task_type(embedding_model_id),
|
||||
)
|
||||
|
||||
|
||||
def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
|
||||
pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
|
||||
base_response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
|
||||
)
|
||||
test_response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_STRING],
|
||||
**default_task_type(embedding_model_id),
|
||||
output_dimension=32,
|
||||
)
|
||||
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
|
||||
assert len(test_response.embeddings[0]) == 32
|
||||
|
||||
|
||||
def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
|
||||
pytest.xfail(f"{embedding_model_id} doesn't support task_type")
|
||||
query_embedding = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
|
||||
)
|
||||
document_embedding = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="document"
|
||||
)
|
||||
assert query_embedding.embeddings != document_embedding.embeddings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_truncation",
|
||||
[
|
||||
None,
|
||||
"none",
|
||||
"end",
|
||||
"start",
|
||||
],
|
||||
)
|
||||
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation, inference_provider_type):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
response = llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_STRING],
|
||||
text_truncation=text_truncation,
|
||||
**default_task_type(embedding_model_id),
|
||||
)
|
||||
assert isinstance(response, EmbeddingsResponse)
|
||||
assert len(response.embeddings) == 1
|
||||
assert isinstance(response.embeddings[0], list)
|
||||
assert isinstance(response.embeddings[0][0], float)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_truncation",
|
||||
[
|
||||
"NONE",
|
||||
"END",
|
||||
"START",
|
||||
"left",
|
||||
"right",
|
||||
],
|
||||
)
|
||||
def test_embedding_text_truncation_error(
|
||||
llama_stack_client, embedding_model_id, text_truncation, inference_provider_type
|
||||
):
|
||||
if inference_provider_type not in SUPPORTED_PROVIDERS:
|
||||
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
|
||||
error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
with pytest.raises(error_type):
|
||||
llama_stack_client.inference.embeddings(
|
||||
model_id=embedding_model_id,
|
||||
contents=[DUMMY_STRING],
|
||||
text_truncation=text_truncation,
|
||||
**default_task_type(embedding_model_id),
|
||||
)
|
|
@ -5,13 +5,12 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import EmbeddingsResponse, Inference
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
@ -70,13 +69,6 @@ def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
|
|||
return mock_vector_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api(sample_embeddings):
|
||||
mock_api = MagicMock(spec=Inference)
|
||||
mock_api.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings))
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_files_api():
|
||||
mock_api = MagicMock(spec=Files)
|
||||
|
@ -96,22 +88,6 @@ async def faiss_index(embedding_dimension):
|
|||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
|
||||
# Create the adapter
|
||||
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
|
||||
|
||||
# Create a mock KVStore
|
||||
mock_kvstore = MagicMock()
|
||||
mock_kvstore.values_in_range = AsyncMock(return_value=[])
|
||||
|
||||
# Patch the initialize method to avoid the kvstore_impl call
|
||||
with patch.object(FaissVectorIOAdapter, "initialize"):
|
||||
# Set the kvstore directly
|
||||
adapter.kvstore = mock_kvstore
|
||||
yield adapter
|
||||
|
||||
|
||||
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
|
||||
faiss_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||
):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue