mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: Add OpenAI compatibility for Ollama embeddings (#2440)
# What does this PR do? This PR adds OpenAI compatibility for Ollama embeddings. Closes https://github.com/meta-llama/llama-stack/issues/2428 Summary of changes: - `llama_stack/providers/remote/inference/ollama/ollama.py` - Implements the OpenAI embeddings endpoint for Ollama, replacing the NotImplementedError with a full function that validates the model, prepares parameters, calls the client, encodes embedding data (optionally in base64), and returns a correctly structured response. - Updates import statements to include the new embedding response utilities. - `llama_stack/providers/utils/inference/litellm_openai_mixin.py` - Refactors the embedding data encoding logic to use a new shared utility (`b64_encode_openai_embeddings_response`) instead of inline base64 encoding and packing logic. - Cleans up imports accordingly. - `llama_stack/providers/utils/inference/openai_compat.py` - Adds `b64_encode_openai_embeddings_response` to handle encoding OpenAI embedding outputs (including base64 support) in a reusable way. - Adds `prepare_openai_embeddings_params` utility for standardizing embedding parameter preparation. - Updates imports to include the new embedding data class. - `tests/integration/inference/test_openai_embeddings.py` - Removes `"remote::ollama"` from the list of providers that skip OpenAI embeddings tests, since support is now implemented. ## Note There was one minor issue, which required me to override the `OpenAIEmbeddingsResponse.model` name with `self._get_model(model).identifier` name, which is very unsatisfying. ## Test Plan Unit Tests and integration tests --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
e2e15ebb6c
commit
554ada57b0
4 changed files with 90 additions and 16 deletions
|
@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
|
||||||
JsonSchemaResponseFormat,
|
JsonSchemaResponseFormat,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
OpenAIEmbeddingsResponse,
|
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
|
@ -46,6 +45,8 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAIChatCompletionChunk,
|
OpenAIChatCompletionChunk,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
|
OpenAIEmbeddingsResponse,
|
||||||
|
OpenAIEmbeddingUsage,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
)
|
)
|
||||||
|
@ -62,8 +63,10 @@ from llama_stack.providers.utils.inference.model_registry import (
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
OpenAICompatCompletionChoice,
|
OpenAICompatCompletionChoice,
|
||||||
OpenAICompatCompletionResponse,
|
OpenAICompatCompletionResponse,
|
||||||
|
b64_encode_openai_embeddings_response,
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
prepare_openai_completion_params,
|
prepare_openai_completion_params,
|
||||||
|
prepare_openai_embeddings_params,
|
||||||
process_chat_completion_response,
|
process_chat_completion_response,
|
||||||
process_chat_completion_stream_response,
|
process_chat_completion_stream_response,
|
||||||
process_completion_response,
|
process_completion_response,
|
||||||
|
@ -386,7 +389,35 @@ class OllamaInferenceAdapter(
|
||||||
dimensions: int | None = None,
|
dimensions: int | None = None,
|
||||||
user: str | None = None,
|
user: str | None = None,
|
||||||
) -> OpenAIEmbeddingsResponse:
|
) -> OpenAIEmbeddingsResponse:
|
||||||
raise NotImplementedError()
|
model_obj = await self._get_model(model)
|
||||||
|
if model_obj.model_type != ModelType.embedding:
|
||||||
|
raise ValueError(f"Model {model} is not an embedding model")
|
||||||
|
|
||||||
|
if model_obj.provider_resource_id is None:
|
||||||
|
raise ValueError(f"Model {model} has no provider_resource_id set")
|
||||||
|
|
||||||
|
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
|
||||||
|
params = prepare_openai_embeddings_params(
|
||||||
|
model=model_obj.provider_resource_id,
|
||||||
|
input=input,
|
||||||
|
encoding_format=encoding_format,
|
||||||
|
dimensions=dimensions,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await self.openai_client.embeddings.create(**params)
|
||||||
|
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
|
||||||
|
|
||||||
|
usage = OpenAIEmbeddingUsage(
|
||||||
|
prompt_tokens=response.usage.prompt_tokens,
|
||||||
|
total_tokens=response.usage.total_tokens,
|
||||||
|
)
|
||||||
|
# TODO: Investigate why model_obj.identifier is used instead of response.model
|
||||||
|
return OpenAIEmbeddingsResponse(
|
||||||
|
data=data,
|
||||||
|
model=model_obj.identifier,
|
||||||
|
usage=usage,
|
||||||
|
)
|
||||||
|
|
||||||
async def openai_completion(
|
async def openai_completion(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -4,8 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import base64
|
|
||||||
import struct
|
|
||||||
from collections.abc import AsyncGenerator, AsyncIterator
|
from collections.abc import AsyncGenerator, AsyncIterator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAIChatCompletionChunk,
|
OpenAIChatCompletionChunk,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
OpenAIEmbeddingData,
|
|
||||||
OpenAIEmbeddingsResponse,
|
OpenAIEmbeddingsResponse,
|
||||||
OpenAIEmbeddingUsage,
|
OpenAIEmbeddingUsage,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
|
@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
b64_encode_openai_embeddings_response,
|
||||||
convert_message_to_openai_dict_new,
|
convert_message_to_openai_dict_new,
|
||||||
convert_openai_chat_completion_choice,
|
convert_openai_chat_completion_choice,
|
||||||
convert_openai_chat_completion_stream,
|
convert_openai_chat_completion_stream,
|
||||||
|
@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert response to OpenAI format
|
# Convert response to OpenAI format
|
||||||
data = []
|
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
|
||||||
for i, embedding_data in enumerate(response["data"]):
|
|
||||||
# we encode to base64 if the encoding format is base64 in the request
|
|
||||||
if encoding_format == "base64":
|
|
||||||
byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
|
|
||||||
embedding = base64.b64encode(byte_data).decode("utf-8")
|
|
||||||
else:
|
|
||||||
embedding = embedding_data["embedding"]
|
|
||||||
|
|
||||||
data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
|
|
||||||
|
|
||||||
usage = OpenAIEmbeddingUsage(
|
usage = OpenAIEmbeddingUsage(
|
||||||
prompt_tokens=response["usage"]["prompt_tokens"],
|
prompt_tokens=response["usage"]["prompt_tokens"],
|
||||||
|
|
|
@ -3,8 +3,10 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import struct
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
OpenAICompletionChoice,
|
OpenAICompletionChoice,
|
||||||
|
OpenAIEmbeddingData,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
@ -1483,3 +1486,55 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
model=model,
|
model=model,
|
||||||
object="chat.completion",
|
object="chat.completion",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_openai_embeddings_params(
|
||||||
|
model: str,
|
||||||
|
input: str | list[str],
|
||||||
|
encoding_format: str | None = "float",
|
||||||
|
dimensions: int | None = None,
|
||||||
|
user: str | None = None,
|
||||||
|
):
|
||||||
|
if model is None:
|
||||||
|
raise ValueError("Model must be provided for embeddings")
|
||||||
|
|
||||||
|
input_list = [input] if isinstance(input, str) else input
|
||||||
|
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"model": model,
|
||||||
|
"input": input_list,
|
||||||
|
}
|
||||||
|
|
||||||
|
if encoding_format is not None:
|
||||||
|
params["encoding_format"] = encoding_format
|
||||||
|
if dimensions is not None:
|
||||||
|
params["dimensions"] = dimensions
|
||||||
|
if user is not None:
|
||||||
|
params["user"] = user
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def b64_encode_openai_embeddings_response(
|
||||||
|
response_data: dict, encoding_format: str | None = "float"
|
||||||
|
) -> list[OpenAIEmbeddingData]:
|
||||||
|
"""
|
||||||
|
Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
|
||||||
|
"""
|
||||||
|
data = []
|
||||||
|
for i, embedding_data in enumerate(response_data):
|
||||||
|
if encoding_format == "base64":
|
||||||
|
byte_array = bytearray()
|
||||||
|
for embedding_value in embedding_data.embedding:
|
||||||
|
byte_array.extend(struct.pack("f", float(embedding_value)))
|
||||||
|
|
||||||
|
response_embedding = base64.b64encode(byte_array).decode("utf-8")
|
||||||
|
else:
|
||||||
|
response_embedding = embedding_data.embedding
|
||||||
|
data.append(
|
||||||
|
OpenAIEmbeddingData(
|
||||||
|
embedding=response_embedding,
|
||||||
|
index=i,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
|
@ -51,7 +51,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
|
||||||
"remote::runpod",
|
"remote::runpod",
|
||||||
"remote::sambanova",
|
"remote::sambanova",
|
||||||
"remote::tgi",
|
"remote::tgi",
|
||||||
"remote::ollama",
|
|
||||||
):
|
):
|
||||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
|
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue