chore: Add OpenAI compatibility for Ollama embeddings (#2440)

# What does this PR do?
This PR adds OpenAI compatibility for Ollama embeddings. Closes
https://github.com/meta-llama/llama-stack/issues/2428

Summary of changes:
- `llama_stack/providers/remote/inference/ollama/ollama.py`
  - Implements the OpenAI embeddings endpoint for Ollama, replacing the
    `NotImplementedError` with a full implementation that validates the model,
    prepares parameters, calls the client, encodes embedding data
    (optionally in base64), and returns a correctly structured response
    (a usage sketch follows this list).
  - Updates import statements to include the new embedding response
    utilities.

- `llama_stack/providers/utils/inference/litellm_openai_mixin.py`
  - Refactors the embedding data encoding logic to use a new shared
    utility (`b64_encode_openai_embeddings_response`) instead of inline
    base64 encoding and packing logic.
  - Cleans up imports accordingly.

- `llama_stack/providers/utils/inference/openai_compat.py`
  - Adds `b64_encode_openai_embeddings_response` to handle encoding OpenAI
    embedding outputs (including base64 support) in a reusable way.
  - Adds `prepare_openai_embeddings_params`, a utility for standardizing
    embedding parameter preparation.
  - Updates imports to include the new embedding data class.

- `tests/integration/inference/test_openai_embeddings.py`
  - Removes `"remote::ollama"` from the list of providers that skip OpenAI
    embeddings tests, since support is now implemented.
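
For reference, here is a minimal usage sketch of the new endpoint. The server URL, API key, and embedding model id below are assumptions about a local setup, not part of this PR:

```python
# Hedged sketch: assumes a locally running Llama Stack server with the Ollama
# provider and an `all-minilm:l6-v2` embedding model registered.
import base64
import struct

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Float embeddings: one vector of floats per input string.
resp = client.embeddings.create(
    model="all-minilm:l6-v2",
    input=["Hello, world!", "Llama Stack"],
    encoding_format="float",
)
print(len(resp.data), len(resp.data[0].embedding))

# base64 embeddings: each vector is a base64 string of packed float32 values,
# mirroring the server-side struct.pack("f", ...) encoding.
resp_b64 = client.embeddings.create(
    model="all-minilm:l6-v2",
    input="Hello, world!",
    encoding_format="base64",
)
raw = base64.b64decode(resp_b64.data[0].embedding)
vector = struct.unpack(f"{len(raw) // 4}f", raw)
print(len(vector))
```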

## Note

There was one minor issue: the `model` field of `OpenAIEmbeddingsResponse` has to be
overridden with `self._get_model(model).identifier` rather than the model name returned
by the provider, which is unsatisfying.

## Test Plan
Unit tests and integration tests.
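
As a rough illustration of the coverage involved (not the literal contents of `tests/integration/inference/test_openai_embeddings.py`), an integration-style check could look like the following; the `openai_client` and `embedding_model_id` fixtures are assumptions:

```python
# Hedged sketch of an integration-style test; fixture names are assumed.
import base64
import struct


def test_openai_embeddings_float_and_base64(openai_client, embedding_model_id):
    # Float encoding: a non-empty vector of numbers per input.
    resp = openai_client.embeddings.create(
        model=embedding_model_id,
        input=["integration test sentence"],
        encoding_format="float",
    )
    assert resp.model  # populated from the registered model identifier
    assert len(resp.data) == 1
    assert len(resp.data[0].embedding) > 0

    # base64 encoding: decodes to the same number of float32 values.
    resp_b64 = openai_client.embeddings.create(
        model=embedding_model_id,
        input=["integration test sentence"],
        encoding_format="base64",
    )
    raw = base64.b64decode(resp_b64.data[0].embedding)
    decoded = struct.unpack(f"{len(raw) // 4}f", raw)
    assert len(decoded) == len(resp.data[0].embedding)
```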

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Arceo authored on 2025-06-13 12:28:51 -06:00, committed by GitHub
commit 554ada57b0 · parent e2e15ebb6c
4 changed files with 90 additions and 16 deletions

llama_stack/providers/remote/inference/ollama/ollama.py

@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     LogProbConfig,
     Message,
-    OpenAIEmbeddingsResponse,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -46,6 +45,8 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
+    OpenAIEmbeddingsResponse,
+    OpenAIEmbeddingUsage,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
@@ -62,8 +63,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
+    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
@@ -386,7 +389,35 @@ class OllamaInferenceAdapter(
         dimensions: int | None = None,
         user: str | None = None,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        model_obj = await self._get_model(model)
+        if model_obj.model_type != ModelType.embedding:
+            raise ValueError(f"Model {model} is not an embedding model")
+
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {model} has no provider_resource_id set")
+
+        # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
+        params = prepare_openai_embeddings_params(
+            model=model_obj.provider_resource_id,
+            input=input,
+            encoding_format=encoding_format,
+            dimensions=dimensions,
+            user=user,
+        )
+
+        response = await self.openai_client.embeddings.create(**params)
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
+
+        usage = OpenAIEmbeddingUsage(
+            prompt_tokens=response.usage.prompt_tokens,
+            total_tokens=response.usage.total_tokens,
+        )
+        # TODO: Investigate why model_obj.identifier is used instead of response.model
+        return OpenAIEmbeddingsResponse(
+            data=data,
+            model=model_obj.identifier,
+            usage=usage,
+        )
 
     async def openai_completion(
         self,

llama_stack/providers/utils/inference/litellm_openai_mixin.py

@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import struct
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
@@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
         )
 
         # Convert response to OpenAI format
-        data = []
-        for i, embedding_data in enumerate(response["data"]):
-            # we encode to base64 if the encoding format is base64 in the request
-            if encoding_format == "base64":
-                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
-                embedding = base64.b64encode(byte_data).decode("utf-8")
-            else:
-                embedding = embedding_data["embedding"]
-            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response["usage"]["prompt_tokens"],

llama_stack/providers/utils/inference/openai_compat.py

@@ -3,8 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
 import json
 import logging
+import struct
 import time
 import uuid
 import warnings
@@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAICompletionChoice,
+    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ToolConfig,
@@ -1483,3 +1486,55 @@ class OpenAIChatCompletionToLlamaStackMixin:
             model=model,
             object="chat.completion",
         )
+
+
+def prepare_openai_embeddings_params(
+    model: str,
+    input: str | list[str],
+    encoding_format: str | None = "float",
+    dimensions: int | None = None,
+    user: str | None = None,
+):
+    if model is None:
+        raise ValueError("Model must be provided for embeddings")
+
+    input_list = [input] if isinstance(input, str) else input
+
+    params: dict[str, Any] = {
+        "model": model,
+        "input": input_list,
+    }
+
+    if encoding_format is not None:
+        params["encoding_format"] = encoding_format
+    if dimensions is not None:
+        params["dimensions"] = dimensions
+    if user is not None:
+        params["user"] = user
+
+    return params
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: dict, encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data.embedding:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data.embedding
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data

tests/integration/inference/test_openai_embeddings.py

@@ -51,7 +51,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
         "remote::runpod",
         "remote::sambanova",
         "remote::tgi",
-        "remote::ollama",
     ):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")