From b21050935e4904005446987c0339c9f443f870a3 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Sat, 31 May 2025 22:11:47 -0700 Subject: [PATCH] feat: New OpenAI compat embeddings API (#2314) # What does this PR do? Adds a new endpoint that is compatible with OpenAI for embeddings api. `/openai/v1/embeddings` Added providers for OpenAI, LiteLLM and SentenceTransformer. ## Test Plan ``` LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/inference/test_openai_embeddings.py --embedding-model all-MiniLM-L6-v2,text-embedding-3-small,gemini/text-embedding-004 ``` --- docs/_static/llama-stack-spec.html | 176 +++++++++++ docs/_static/llama-stack-spec.yaml | 144 +++++++++ llama_stack/apis/inference/inference.py | 62 ++++ llama_stack/distribution/routers/inference.py | 29 ++ .../providers/inline/inference/vllm/vllm.py | 11 + .../remote/inference/bedrock/bedrock.py | 11 + .../remote/inference/cerebras/cerebras.py | 11 + .../remote/inference/databricks/databricks.py | 11 + .../remote/inference/fireworks/fireworks.py | 11 + .../remote/inference/nvidia/nvidia.py | 11 + .../remote/inference/ollama/ollama.py | 11 + .../remote/inference/openai/openai.py | 52 ++++ .../inference/passthrough/passthrough.py | 11 + .../remote/inference/runpod/runpod.py | 11 + .../providers/remote/inference/tgi/tgi.py | 11 + .../remote/inference/together/together.py | 11 + .../providers/remote/inference/vllm/vllm.py | 11 + .../remote/inference/watsonx/watsonx.py | 11 + .../utils/inference/embedding_mixin.py | 49 ++++ .../utils/inference/litellm_openai_mixin.py | 51 ++++ .../inference/test_openai_embeddings.py | 275 ++++++++++++++++++ 21 files changed, 981 insertions(+) create mode 100644 tests/integration/inference/test_openai_embeddings.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index dbe251921..d88462909 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3607,6 +3607,49 @@ } } }, + "/v1/openai/v1/embeddings": { + "post": { + "responses": { + "200": { + "description": "An OpenAIEmbeddingsResponse containing the embeddings.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenAIEmbeddingsResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Generate OpenAI-compatible embeddings for the given input using the specified model.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OpenaiEmbeddingsRequest" + } + } + }, + "required": true + } + } + }, "/v1/openai/v1/models": { "get": { "responses": { @@ -11777,6 +11820,139 @@ "title": "OpenAICompletionChoice", "description": "A choice from an OpenAI-compatible completion response." }, + "OpenaiEmbeddingsRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint." + }, + "input": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], + "description": "Input text to embed, encoded as a string or array of strings. 
To embed multiple inputs in a single request, pass an array of strings." + }, + "encoding_format": { + "type": "string", + "description": "(Optional) The format to return the embeddings in. Can be either \"float\" or \"base64\". Defaults to \"float\"." + }, + "dimensions": { + "type": "integer", + "description": "(Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models." + }, + "user": { + "type": "string", + "description": "(Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse." + } + }, + "additionalProperties": false, + "required": [ + "model", + "input" + ], + "title": "OpenaiEmbeddingsRequest" + }, + "OpenAIEmbeddingData": { + "type": "object", + "properties": { + "object": { + "type": "string", + "const": "embedding", + "default": "embedding", + "description": "The object type, which will be \"embedding\"" + }, + "embedding": { + "oneOf": [ + { + "type": "array", + "items": { + "type": "number" + } + }, + { + "type": "string" + } + ], + "description": "The embedding vector as a list of floats (when encoding_format=\"float\") or as a base64-encoded string (when encoding_format=\"base64\")" + }, + "index": { + "type": "integer", + "description": "The index of the embedding in the input list" + } + }, + "additionalProperties": false, + "required": [ + "object", + "embedding", + "index" + ], + "title": "OpenAIEmbeddingData", + "description": "A single embedding data object from an OpenAI-compatible embeddings response." + }, + "OpenAIEmbeddingUsage": { + "type": "object", + "properties": { + "prompt_tokens": { + "type": "integer", + "description": "The number of tokens in the input" + }, + "total_tokens": { + "type": "integer", + "description": "The total number of tokens used" + } + }, + "additionalProperties": false, + "required": [ + "prompt_tokens", + "total_tokens" + ], + "title": "OpenAIEmbeddingUsage", + "description": "Usage information for an OpenAI-compatible embeddings response." + }, + "OpenAIEmbeddingsResponse": { + "type": "object", + "properties": { + "object": { + "type": "string", + "const": "list", + "default": "list", + "description": "The object type, which will be \"list\"" + }, + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/OpenAIEmbeddingData" + }, + "description": "List of embedding data objects" + }, + "model": { + "type": "string", + "description": "The model that was used to generate the embeddings" + }, + "usage": { + "$ref": "#/components/schemas/OpenAIEmbeddingUsage", + "description": "Usage information" + } + }, + "additionalProperties": false, + "required": [ + "object", + "data", + "model", + "usage" + ], + "title": "OpenAIEmbeddingsResponse", + "description": "Response from an OpenAI-compatible embeddings request." + }, "OpenAIModel": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 2dfb037da..7638c3cbd 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2520,6 +2520,38 @@ paths: schema: $ref: '#/components/schemas/OpenaiCompletionRequest' required: true + /v1/openai/v1/embeddings: + post: + responses: + '200': + description: >- + An OpenAIEmbeddingsResponse containing the embeddings. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIEmbeddingsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Generate OpenAI-compatible embeddings for the given input using the specified + model. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/OpenaiEmbeddingsRequest' + required: true /v1/openai/v1/models: get: responses: @@ -8197,6 +8229,118 @@ components: title: OpenAICompletionChoice description: >- A choice from an OpenAI-compatible completion response. + OpenaiEmbeddingsRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the model to use. The model must be an embedding model + registered with Llama Stack and available via the /models endpoint. + input: + oneOf: + - type: string + - type: array + items: + type: string + description: >- + Input text to embed, encoded as a string or array of strings. To embed + multiple inputs in a single request, pass an array of strings. + encoding_format: + type: string + description: >- + (Optional) The format to return the embeddings in. Can be either "float" + or "base64". Defaults to "float". + dimensions: + type: integer + description: >- + (Optional) The number of dimensions the resulting output embeddings should + have. Only supported in text-embedding-3 and later models. + user: + type: string + description: >- + (Optional) A unique identifier representing your end-user, which can help + OpenAI to monitor and detect abuse. + additionalProperties: false + required: + - model + - input + title: OpenaiEmbeddingsRequest + OpenAIEmbeddingData: + type: object + properties: + object: + type: string + const: embedding + default: embedding + description: >- + The object type, which will be "embedding" + embedding: + oneOf: + - type: array + items: + type: number + - type: string + description: >- + The embedding vector as a list of floats (when encoding_format="float") + or as a base64-encoded string (when encoding_format="base64") + index: + type: integer + description: >- + The index of the embedding in the input list + additionalProperties: false + required: + - object + - embedding + - index + title: OpenAIEmbeddingData + description: >- + A single embedding data object from an OpenAI-compatible embeddings response. + OpenAIEmbeddingUsage: + type: object + properties: + prompt_tokens: + type: integer + description: The number of tokens in the input + total_tokens: + type: integer + description: The total number of tokens used + additionalProperties: false + required: + - prompt_tokens + - total_tokens + title: OpenAIEmbeddingUsage + description: >- + Usage information for an OpenAI-compatible embeddings response. 
+ OpenAIEmbeddingsResponse: + type: object + properties: + object: + type: string + const: list + default: list + description: The object type, which will be "list" + data: + type: array + items: + $ref: '#/components/schemas/OpenAIEmbeddingData' + description: List of embedding data objects + model: + type: string + description: >- + The model that was used to generate the embeddings + usage: + $ref: '#/components/schemas/OpenAIEmbeddingUsage' + description: Usage information + additionalProperties: false + required: + - object + - data + - model + - usage + title: OpenAIEmbeddingsResponse + description: >- + Response from an OpenAI-compatible embeddings request. OpenAIModel: type: object properties: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index e79dc6d94..74697dd18 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -783,6 +783,48 @@ class OpenAICompletion(BaseModel): object: Literal["text_completion"] = "text_completion" +@json_schema_type +class OpenAIEmbeddingData(BaseModel): + """A single embedding data object from an OpenAI-compatible embeddings response. + + :param object: The object type, which will be "embedding" + :param embedding: The embedding vector as a list of floats (when encoding_format="float") or as a base64-encoded string (when encoding_format="base64") + :param index: The index of the embedding in the input list + """ + + object: Literal["embedding"] = "embedding" + embedding: list[float] | str + index: int + + +@json_schema_type +class OpenAIEmbeddingUsage(BaseModel): + """Usage information for an OpenAI-compatible embeddings response. + + :param prompt_tokens: The number of tokens in the input + :param total_tokens: The total number of tokens used + """ + + prompt_tokens: int + total_tokens: int + + +@json_schema_type +class OpenAIEmbeddingsResponse(BaseModel): + """Response from an OpenAI-compatible embeddings request. + + :param object: The object type, which will be "list" + :param data: List of embedding data objects + :param model: The model that was used to generate the embeddings + :param usage: Usage information + """ + + object: Literal["list"] = "list" + data: list[OpenAIEmbeddingData] + model: str + usage: OpenAIEmbeddingUsage + + class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... @@ -1076,6 +1118,26 @@ class InferenceProvider(Protocol): """ ... + @webmethod(route="/openai/v1/embeddings", method="POST") + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + """Generate OpenAI-compatible embeddings for the given input using the specified model. + + :param model: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. + :param input: Input text to embed, encoded as a string or array of strings. To embed multiple inputs in a single request, pass an array of strings. + :param encoding_format: (Optional) The format to return the embeddings in. Can be either "float" or "base64". Defaults to "float". + :param dimensions: (Optional) The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + :param user: (Optional) A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. 
+ :returns: An OpenAIEmbeddingsResponse containing the embeddings. + """ + ... + class Inference(InferenceProvider): """Llama Stack Inference API for generating completions, chat completions, and embeddings. diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py index f77b19302..763bd9105 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/inference.py @@ -45,6 +45,7 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -546,6 +547,34 @@ class InferenceRouter(Inference): await self.store.store_chat_completion(response, messages) return response + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + logger.debug( + f"InferenceRouter.openai_embeddings: {model=}, input_type={type(input)}, {encoding_format=}, {dimensions=}", + ) + model_obj = await self.routing_table.get_model(model) + if model_obj is None: + raise ValueError(f"Model '{model}' not found") + if model_obj.model_type != ModelType.embedding: + raise ValueError(f"Model '{model}' is not an embedding model") + + params = dict( + model=model_obj.identifier, + input=input, + encoding_format=encoding_format, + dimensions=dimensions, + user=user, + ) + + provider = self.routing_table.get_provider_impl(model_obj.identifier) + return await provider.openai_embeddings(**params) + async def list_chat_completions( self, after: str | None = None, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 438cb14a0..bf54462b5 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -40,6 +40,7 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -410,6 +411,16 @@ class VLLMInferenceImpl( ) -> EmbeddingsResponse: raise NotImplementedError() + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 0404a578f..952d86f1a 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -22,6 +22,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -197,3 +198,13 @@ class BedrockInferenceAdapter( response_body = json.loads(response.get("body").read()) embeddings.append(response_body.get("embedding")) return EmbeddingsResponse(embeddings=embeddings) + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py 
b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 685375346..952118e24 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -21,6 +21,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -194,3 +195,13 @@ class CerebrasInferenceAdapter( task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 5c36eac3e..1dc18b97f 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -152,3 +153,13 @@ class DatabricksInferenceAdapter( task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: raise NotImplementedError() + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index b6d3984c6..fe21685dd 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -37,6 +37,7 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingsResponse, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -286,6 +287,16 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv embeddings = [data.embedding for data in response.data] return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 333486fe4..4c68322e0 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -29,6 +29,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -238,6 +239,16 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise 
NotImplementedError() + async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 3b4287673..8863e0edc 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -32,6 +32,7 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -370,6 +371,16 @@ class OllamaInferenceAdapter( return model + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index c3c25edd3..6f3a686a8 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -14,6 +14,9 @@ from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, OpenAIChatCompletionChunk, OpenAICompletion, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, ) @@ -38,6 +41,7 @@ logger = logging.getLogger(__name__) # | batch_chat_completion | LiteLLMOpenAIMixin | # | openai_completion | AsyncOpenAI | # | openai_chat_completion | AsyncOpenAI | +# | openai_embeddings | AsyncOpenAI | # class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): def __init__(self, config: OpenAIConfig) -> None: @@ -171,3 +175,51 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): user=user, ) return await self._openai_client.chat.completions.create(**params) + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + model_id = (await self.model_store.get_model(model)).provider_resource_id + if model_id.startswith("openai/"): + model_id = model_id[len("openai/") :] + + # Prepare parameters for OpenAI embeddings API + params = { + "model": model_id, + "input": input, + } + + if encoding_format is not None: + params["encoding_format"] = encoding_format + if dimensions is not None: + params["dimensions"] = dimensions + if user is not None: + params["user"] = user + + # Call OpenAI embeddings API + response = await self._openai_client.embeddings.create(**params) + + data = [] + for i, embedding_data in enumerate(response.data): + data.append( + OpenAIEmbeddingData( + embedding=embedding_data.embedding, + index=i, + ) + ) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=response.model, + usage=usage, + ) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index 78ee52641..6cf4680e2 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -19,6 +19,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, SamplingParams, 
TextTruncation, @@ -210,6 +211,16 @@ class PassthroughInferenceAdapter(Inference): task_type=task_type, ) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 2706aa15e..f8c98893e 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference.inference import OpenAIEmbeddingsResponse # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper @@ -134,3 +135,13 @@ class RunpodInferenceAdapter( task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 8f6666462..292d74ef8 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, ResponseFormatType, SamplingParams, @@ -291,6 +292,16 @@ class _HfAdapter( ) -> EmbeddingsResponse: raise NotImplementedError() + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + class TGIAdapter(_HfAdapter): async def initialize(self, config: TGIImplConfig) -> None: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 562e6e0ff..7305a638d 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingsResponse, ResponseFormat, ResponseFormatType, SamplingParams, @@ -267,6 +268,16 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi embeddings = [item.embedding for item in r.data] return EmbeddingsResponse(embeddings=embeddings) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + raise NotImplementedError() + async def openai_completion( self, model: str, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index fe2d8bec1..9f38d9abf 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -38,6 
+38,7 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -507,6 +508,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         embeddings = [data.embedding for data in response.data]
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py
index c1299e11f..59f5f5562 100644
--- a/llama_stack/providers/remote/inference/watsonx/watsonx.py
+++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py
@@ -21,6 +21,7 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
+    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -260,6 +261,16 @@ class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
     ) -> EmbeddingsResponse:
         raise NotImplementedError("embedding is not supported for watsonx")
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError()
+
     async def openai_completion(
         self,
         model: str,
diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py
index 7c8144c62..97cf87360 100644
--- a/llama_stack/providers/utils/inference/embedding_mixin.py
+++ b/llama_stack/providers/utils/inference/embedding_mixin.py
@@ -4,7 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
 import logging
+import struct
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -15,6 +17,9 @@ from llama_stack.apis.inference import (
     EmbeddingsResponse,
     EmbeddingTaskType,
     InterleavedContentItem,
     ModelStore,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     TextTruncation,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -43,6 +48,50 @@ class SentenceTransformerEmbeddingMixin:
         )
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        # Convert input to list format if it's a single string
+        input_list = [input] if isinstance(input, str) else input
+        if not input_list:
+            raise ValueError("Empty list not supported")
+
+        # Get the model and generate embeddings
+        model_obj = await self.model_store.get_model(model)
+        embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
+        embeddings = embedding_model.encode(input_list, show_progress_bar=False)
+
+        # Convert embeddings to the requested format
+        data = []
+        for i, embedding in enumerate(embeddings):
+            if encoding_format == "base64":
+                # Convert float array to base64 string
+                float_bytes = struct.pack(f"{len(embedding)}f", *embedding)
+                embedding_value = base64.b64encode(float_bytes).decode("ascii")
+            else:
+                # Default to float format
+                embedding_value = embedding.tolist()
+
+            data.append(
+                OpenAIEmbeddingData(
+                    embedding=embedding_value,
+                    index=i,
+                )
+            )
+
+        # Not returning actual token usage
+        usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1)
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )
+
     def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
         global EMBEDDING_MODELS
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 4d17db21e..dab10bc55 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import base64
+import struct
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
@@ -35,6 +37,9 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingData,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -264,6 +269,52 @@ class LiteLLMOpenAIMixin(
         embeddings = [data["embedding"] for data in response["data"]]
         return EmbeddingsResponse(embeddings=embeddings)
 
+    async def openai_embeddings(
+        self,
+        model: str,
+        input: str | list[str],
+        encoding_format: str | None = "float",
+        dimensions: int | None = None,
+        user: str | None = None,
+    ) -> OpenAIEmbeddingsResponse:
+        model_obj = await self.model_store.get_model(model)
+
+        # Convert input to list if it's a string
+        input_list = [input] if isinstance(input, str) else input
+
+        # Call litellm embedding function
+        # litellm.drop_params = True
+        response = litellm.embedding(
+            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            input=input_list,
+            api_key=self.get_api_key(),
+            api_base=self.api_base,
+            dimensions=dimensions,
+        )
+
+        # Convert response to OpenAI format
+        data = []
+        for i, embedding_data in enumerate(response["data"]):
+            # we encode to base64 if the encoding format is base64 in the request
+            if encoding_format == "base64":
+                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
+                embedding = base64.b64encode(byte_data).decode("utf-8")
+            else:
+                embedding = embedding_data["embedding"]
+
+            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response["usage"]["prompt_tokens"],
+            total_tokens=response["usage"]["total_tokens"],
+        )
+
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.provider_resource_id,
+            usage=usage,
+        )
+
     async def openai_completion(
         self,
         model: str,
diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py
new file mode 100644
index 000000000..759556257
--- /dev/null
+++ b/tests/integration/inference/test_openai_embeddings.py
@@ -0,0 +1,275 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import base64
+import struct
+
+import pytest
+from openai import OpenAI
+
+from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
+
+
+def decode_base64_to_floats(base64_string: str) -> list[float]:
+    """Helper function to decode base64 string to list of float32 values."""
+    embedding_bytes = base64.b64decode(base64_string)
+    float_count = len(embedding_bytes) // 4  # 4 bytes per float32
+    embedding_floats = struct.unpack(f"{float_count}f", embedding_bytes)
+    return list(embedding_floats)
+
+
+def provider_from_model(client_with_models, model_id):
+    models = {m.identifier: m for m in client_with_models.models.list()}
+    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
+    provider_id = models[model_id].provider_id
+    providers = {p.provider_id: p for p in client_with_models.providers.list()}
+    return providers[provider_id]
+
+
+def skip_if_model_doesnt_support_variable_dimensions(model_id):
+    if "text-embedding-3" not in model_id:
+        pytest.skip(f"{model_id} does not support variable output embedding dimensions")
+
+
+def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id):
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI embeddings are not supported when testing with library client yet.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type in (
+        "inline::meta-reference",
+        "remote::bedrock",
+        "remote::cerebras",
+        "remote::databricks",
+        "remote::runpod",
+        "remote::sambanova",
+        "remote::tgi",
+        "remote::ollama",
+    ):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
+
+
+@pytest.fixture
+def openai_client(client_with_models):
+    base_url = f"{client_with_models.base_url}/v1/openai/v1"
+    return OpenAI(base_url=base_url, api_key="fake")
+
+
+def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with a single string input."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    input_text = "Hello, world!"
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text,
+        encoding_format="float",
+    )
+
+    assert response.object == "list"
+    assert response.model == embedding_model_id
+    assert len(response.data) == 1
+    assert response.data[0].object == "embedding"
+    assert response.data[0].index == 0
+    assert isinstance(response.data[0].embedding, list)
+    assert len(response.data[0].embedding) > 0
+    assert all(isinstance(x, float) for x in response.data[0].embedding)
+
+
+def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with multiple string inputs."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    input_texts = ["Hello, world!", "How are you today?", "This is a test."]
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_texts,
+    )
+
+    assert response.object == "list"
+    assert response.model == embedding_model_id
+    assert len(response.data) == len(input_texts)
+
+    for i, embedding_data in enumerate(response.data):
+        assert embedding_data.object == "embedding"
+        assert embedding_data.index == i
+        assert isinstance(embedding_data.embedding, list)
+        assert len(embedding_data.embedding) > 0
+        assert all(isinstance(x, float) for x in embedding_data.embedding)
+
+
+def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with float encoding format."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    input_text = "Test encoding format"
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text,
+        encoding_format="float",
+    )
+
+    assert response.object == "list"
+    assert len(response.data) == 1
+    assert isinstance(response.data[0].embedding, list)
+    assert all(isinstance(x, float) for x in response.data[0].embedding)
+
+
+def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with custom dimensions parameter."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+    skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
+
+    input_text = "Test dimensions parameter"
+    dimensions = 16
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text,
+        dimensions=dimensions,
+    )
+
+    assert response.object == "list"
+    assert len(response.data) == 1
+    # Note: Not all models support custom dimensions, so we don't assert the exact dimension
+    assert isinstance(response.data[0].embedding, list)
+    assert len(response.data[0].embedding) > 0
+
+
+def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with user parameter."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    input_text = "Test user parameter"
+    user_id = "test-user-123"
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text,
+        user=user_id,
+    )
+
+    assert response.object == "list"
+    assert len(response.data) == 1
+    assert isinstance(response.data[0].embedding, list)
+    assert len(response.data[0].embedding) > 0
+
+
+def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id):
+    """Test that empty list input raises an appropriate error."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    with pytest.raises(Exception):  # noqa: B017
+        openai_client.embeddings.create(
+            model=embedding_model_id,
+            input=[],
+        )
+
+
+def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id):
+    """Test that invalid model ID raises an appropriate error."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    with pytest.raises(Exception):  # noqa: B017
+        openai_client.embeddings.create(
+            model="invalid-model-id",
+            input="Test text",
+        )
+
+
+def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id):
+    """Test that different inputs produce different embeddings."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    input_text1 = "This is the first text"
+    input_text2 = "This is completely different content"
+
+    response1 = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text1,
+    )
+
+    response2 = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text2,
+    )
+
+    embedding1 = response1.data[0].embedding
+    embedding2 = response2.data[0].embedding
+
+    assert len(embedding1) == len(embedding2)
+    # Embeddings should be different for different inputs
+    assert embedding1 != embedding2
+
+
+def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with base64 encoding format."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+    skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
+
+    input_text = "Test base64 encoding format"
+    dimensions = 12
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_text,
+        encoding_format="base64",
+        dimensions=dimensions,
+    )
+
+    # Validate response structure
+    assert response.object == "list"
+    assert len(response.data) == 1
+
+    # With base64 encoding, embedding should be a string, not a list
+    embedding_data = response.data[0]
+    assert embedding_data.object == "embedding"
+    assert embedding_data.index == 0
+    assert isinstance(embedding_data.embedding, str)
+
+    # Verify it's valid base64 and decode to floats
+    embedding_floats = decode_base64_to_floats(embedding_data.embedding)
+
+    # Verify we got valid floats
+    assert len(embedding_floats) == dimensions, f"Got embedding length {len(embedding_floats)}, expected {dimensions}"
+    assert all(isinstance(x, float) for x in embedding_floats)
+
+
+def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id):
+    """Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
+    skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
+
+    input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
+
+    response = openai_client.embeddings.create(
+        model=embedding_model_id,
+        input=input_texts,
+        encoding_format="base64",
+    )
+
+    # Validate response structure
+    assert response.object == "list"
+    assert response.model == embedding_model_id
+    assert len(response.data) == len(input_texts)
+
+    # Validate each embedding in the batch
+    embedding_dimensions = []
+    for i, embedding_data in enumerate(response.data):
+        assert embedding_data.object == "embedding"
+        assert embedding_data.index == i
+
+        # With base64 encoding, embedding should be a string, not a list
+        assert isinstance(embedding_data.embedding, str)
+        embedding_floats = decode_base64_to_floats(embedding_data.embedding)
+        assert len(embedding_floats) > 0
+        assert all(isinstance(x, float) for x in embedding_floats)
+        embedding_dimensions.append(len(embedding_floats))
+
+    # All embeddings should have the same dimensionality
+    assert all(dim == embedding_dimensions[0] for dim in embedding_dimensions)
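
For reviewers who want to exercise the new route by hand, here is a minimal usage sketch in the spirit of the integration tests above. It assumes a Llama Stack server at `http://localhost:8321` (as in the test plan) with `all-MiniLM-L6-v2` registered as an embedding model; the API key is a placeholder (the tests use `api_key="fake"`), and any OpenAI-compatible client can be pointed at the `/v1/openai/v1` prefix.

```
from openai import OpenAI

# Point a stock OpenAI client at the Llama Stack server's OpenAI-compat prefix.
# base_url and model name are assumptions for illustration; adjust for your deployment.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="fake")

response = client.embeddings.create(
    model="all-MiniLM-L6-v2",
    input=["Hello, world!", "How are you today?"],
    encoding_format="float",
)

for item in response.data:
    # One OpenAIEmbeddingData per input string; index preserves input order.
    print(item.index, len(item.embedding))
```

With `encoding_format="base64"`, the `embedding` field comes back as a base64-encoded string of packed float32 values instead of a list of floats; it can be unpacked with a helper like `decode_base64_to_floats` from the test file above.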