diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index 3979fe074..d4bf692f8 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -45,7 +45,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -64,12 +63,15 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
+    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
+    process_embedding_b64_encoded_input,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -392,28 +394,22 @@ class OllamaInferenceAdapter(
         if model_obj.model_type != ModelType.embedding:
             raise ValueError(f"Model {model} is not an embedding model")
 
-        params: dict[str, Any] = {
-            "model": model_obj.provider_resource_id,
-            "input": input,
-        }
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {model} has no provider_resource_id set")
+        params = prepare_openai_embeddings_params(
+            model=model_obj.provider_resource_id,
+            input=input,
+            encoding_format=encoding_format,
+            dimensions=dimensions,
+            user=user,
+        )
         # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        if encoding_format is not None:
-            params["encoding_format"] = encoding_format
-        if dimensions is not None:
-            params["dimensions"] = dimensions
-        if user is not None:
-            params["user"] = user
+        # but we implement the encoding here
+        params = process_embedding_b64_encoded_input(params)
 
         response = await self.openai_client.embeddings.create(**params)
-        data = []
-        for i, embedding_data in enumerate(response.data):
-            data.append(
-                OpenAIEmbeddingData(
-                    embedding=embedding_data.embedding,
-                    index=i,
-                )
-            )
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response.usage.prompt_tokens,
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index dab10bc55..13381f3c9 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import base64
-import struct
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
@@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
@@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
         )
 
         # Convert response to OpenAI format
-        data = []
-        for i, embedding_data in enumerate(response["data"]):
-            # we encode to base64 if the encoding format is base64 in the request
-            if encoding_format == "base64":
-                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
-                embedding = base64.b64encode(byte_data).decode("utf-8")
-            else:
-                embedding = embedding_data["embedding"]
-
-            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response["usage"]["prompt_tokens"],
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 049f06fdb..54a8d4de6 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -3,8 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
 import json
 import logging
+import struct
 import time
 import uuid
 import warnings
@@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAICompletionChoice,
+    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ToolConfig,
@@ -1483,3 +1486,73 @@ class OpenAIChatCompletionToLlamaStackMixin:
             model=model,
             object="chat.completion",
         )
+
+
+def prepare_openai_embeddings_params(
+    model: str,
+    input: str | list[str],
+    encoding_format: str | None = "float",
+    dimensions: int | None = None,
+    user: str | None = None,
+):
+    if model is None:
+        raise ValueError("Model must be provided for embeddings")
+
+    input_list = [input] if isinstance(input, str) else input
+
+    params: dict[str, Any] = {
+        "model": model,
+        "input": input_list,
+    }
+
+    if encoding_format is not None:
+        params["encoding_format"] = encoding_format
+    if dimensions is not None:
+        params["dimensions"] = dimensions
+    if user is not None:
+        params["user"] = user
+
+    return params
+
+
+def process_embedding_b64_encoded_input(params: dict[str, Any]) -> dict[str, Any]:
+    """
+    Process the embeddings parameters to encode the input in base64 format if specified.
+    Currently implemented for Ollama, as base64 is not yet supported by its OpenAI-compatible API.
+ """ + if params.get("encoding_format") == "base64": + processed_params = params.copy() + input = params.get("input") + if isinstance(input, str): + processed_params["input"] = base64.b64encode(input.encode()).decode() + elif isinstance(input, list): + processed_params["input"] = [base64.b64encode(i.encode()).decode() for i in input] + else: + return params + + return processed_params + + +def b64_encode_openai_embeddings_response( + response_data: dict, encoding_format: str | None = "float" +) -> list[OpenAIEmbeddingData]: + """ + Process the OpenAI embeddings response to encode the embeddings in base64 format if specified. + """ + data = [] + for i, embedding_data in enumerate(response_data): + if encoding_format == "base64": + byte_array = bytearray() + for embedding_value in embedding_data.embedding: + byte_array.extend(struct.pack("f", float(embedding_value))) + + response_embedding = base64.b64encode(byte_array).decode("utf-8") + else: + response_embedding = embedding_data.embedding + data.append( + OpenAIEmbeddingData( + embedding=response_embedding, + index=i, + ) + ) + return data diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 40a0984da..6d5068a34 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -34,7 +34,11 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id): pytest.skip("{model_id} does not support variable output embedding dimensions") -@pytest.fixture(params=["openai_client", "llama_stack_client"]) +@pytest.fixture( + params=[ + "openai_client", + ] +) def compat_client(request, client_with_models): if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient): pytest.skip("OpenAI client tests not supported with library client") @@ -55,12 +59,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.") -def skip_if_client_doesnt_support_base64_encoding(client, model_id): - provider = provider_from_model(client, model_id) - if provider.provider_type in ("remote::ollama",): - pytest.skip(f"Client {client} doesn't support base64 encoding for embeddings.") - - @pytest.fixture def openai_client(client_with_models): base_url = f"{client_with_models.base_url}/v1/openai/v1" @@ -253,7 +251,6 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_wit def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with base64 encoding for batch processing.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) - skip_if_client_doesnt_support_base64_encoding(client_with_models, embedding_model_id) input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]