refactoring some code into openai_compat

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
Francisco Javier Arceo committed 2025-06-13 12:06:58 -04:00
parent 89d4a05303
commit e0f1788e9e
4 changed files with 95 additions and 40 deletions

View file

@@ -45,7 +45,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -64,12 +63,15 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
+    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
+    process_embedding_b64_encoded_input,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -392,28 +394,22 @@ class OllamaInferenceAdapter(
         if model_obj.model_type != ModelType.embedding:
             raise ValueError(f"Model {model} is not an embedding model")
-        params: dict[str, Any] = {
-            "model": model_obj.provider_resource_id,
-            "input": input,
-        }
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {model} has no provider_resource_id set")
+
+        params = prepare_openai_embeddings_params(
+            model=model_obj.provider_resource_id,
+            input=input,
+            encoding_format=encoding_format,
+            dimensions=dimensions,
+            user=user,
+        )
         # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        if encoding_format is not None:
-            params["encoding_format"] = encoding_format
-        if dimensions is not None:
-            params["dimensions"] = dimensions
-        if user is not None:
-            params["user"] = user
+        # but we implement the encoding here
+        params = process_embedding_b64_encoded_input(params)
 
         response = await self.openai_client.embeddings.create(**params)
-        data = []
-        for i, embedding_data in enumerate(response.data):
-            data.append(
-                OpenAIEmbeddingData(
-                    embedding=embedding_data.embedding,
-                    index=i,
-                )
-            )
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response.usage.prompt_tokens,
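
The net effect of the Ollama change above: parameter assembly and base64 handling now live in openai_compat, so a caller can request base64-encoded embeddings from an Ollama-backed stack even though Ollama's compatible API does not accept encoding_format itself. A minimal client-side sketch; the base URL, API key, and model id are assumptions, not part of this commit:

    # Assumed local Llama Stack endpoint and model id; only the /v1/openai/v1 path comes from this commit.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

    response = client.embeddings.create(
        model="all-MiniLM-L6-v2",  # hypothetical embedding model served via Ollama
        input=["first text", "second text"],
        encoding_format="base64",
    )
    for item in response.data:
        # The adapter performs the encoding server-side, so each embedding arrives base64-encoded.
        print(item.index, item.embedding)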

View file

@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import base64
-import struct
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
@@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
@@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
         )
 
         # Convert response to OpenAI format
-        data = []
-        for i, embedding_data in enumerate(response["data"]):
-            # we encode to base64 if the encoding format is base64 in the request
-            if encoding_format == "base64":
-                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
-                embedding = base64.b64encode(byte_data).decode("utf-8")
-            else:
-                embedding = embedding_data["embedding"]
-            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response["usage"]["prompt_tokens"],
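
Both the Ollama adapter and the LiteLLM mixin now funnel their responses through the same helper. A self-contained sketch of that helper's contract, using SimpleNamespace to stand in for response items; the vectors are illustrative only:

    from types import SimpleNamespace

    from llama_stack.providers.utils.inference.openai_compat import (
        b64_encode_openai_embeddings_response,
    )

    # Each item only needs an .embedding attribute, mirroring OpenAI-style response data.
    mock_data = [
        SimpleNamespace(embedding=[0.1, 0.2, 0.3]),
        SimpleNamespace(embedding=[0.4, 0.5, 0.6]),
    ]

    as_floats = b64_encode_openai_embeddings_response(mock_data, encoding_format="float")
    as_b64 = b64_encode_openai_embeddings_response(mock_data, encoding_format="base64")

    print(as_floats[0].embedding)  # the original float list, passed through untouched
    print(as_b64[0].embedding)     # a base64 string of the packed float32 bytes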

View file

@@ -3,8 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
 import json
 import logging
+import struct
 import time
 import uuid
 import warnings
@@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAICompletionChoice,
+    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ToolConfig,
@@ -1483,3 +1486,73 @@ class OpenAIChatCompletionToLlamaStackMixin:
             model=model,
             object="chat.completion",
         )
+
+
+def prepare_openai_embeddings_params(
+    model: str,
+    input: str | list[str],
+    encoding_format: str | None = "float",
+    dimensions: int | None = None,
+    user: str | None = None,
+):
+    if model is None:
+        raise ValueError("Model must be provided for embeddings")
+
+    input_list = [input] if isinstance(input, str) else input
+
+    params: dict[str, Any] = {
+        "model": model,
+        "input": input_list,
+    }
+
+    if encoding_format is not None:
+        params["encoding_format"] = encoding_format
+    if dimensions is not None:
+        params["dimensions"] = dimensions
+    if user is not None:
+        params["user"] = user
+
+    return params
+
+
+def process_embedding_b64_encoded_input(params: dict[str, Any]) -> dict[str, Any]:
+    """
+    Process the embeddings parameters to encode the input in base64 format if specified.
+    Currently implemented for ollama as base64 is not yet supported by their compatible API.
+    """
+    if params.get("encoding_format") == "base64":
+        processed_params = params.copy()
+        input = params.get("input")
+        if isinstance(input, str):
+            processed_params["input"] = base64.b64encode(input.encode()).decode()
+        elif isinstance(input, list):
+            processed_params["input"] = [base64.b64encode(i.encode()).decode() for i in input]
+    else:
+        return params
+    return processed_params
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: dict, encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data.embedding:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data.embedding
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
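
For reference, the byte layout b64_encode_openai_embeddings_response produces is plain float32 packing, so a consumer can recover the vector with struct.unpack. A small round-trip sketch, with values chosen to be exactly representable in float32:

    import base64
    import struct

    values = [0.25, -1.5, 3.0]

    # Pack each value as a 4-byte float32 and base64-encode the buffer, as the helper above does.
    packed = bytearray()
    for v in values:
        packed.extend(struct.pack("f", float(v)))
    encoded = base64.b64encode(packed).decode("utf-8")

    # A consumer reverses the process to get the floats back.
    raw = base64.b64decode(encoded)
    recovered = list(struct.unpack(f"{len(raw) // 4}f", raw))
    assert recovered == values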

View file

@@ -34,7 +34,11 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id):
         pytest.skip("{model_id} does not support variable output embedding dimensions")
 
 
-@pytest.fixture(params=["openai_client", "llama_stack_client"])
+@pytest.fixture(
+    params=[
+        "openai_client",
+    ]
+)
 def compat_client(request, client_with_models):
     if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
         pytest.skip("OpenAI client tests not supported with library client")
@@ -55,12 +59,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
 
 
-def skip_if_client_doesnt_support_base64_encoding(client, model_id):
-    provider = provider_from_model(client, model_id)
-    if provider.provider_type in ("remote::ollama",):
-        pytest.skip(f"Client {client} doesn't support base64 encoding for embeddings.")
-
-
 @pytest.fixture
 def openai_client(client_with_models):
     base_url = f"{client_with_models.base_url}/v1/openai/v1"
@@ -253,7 +251,6 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_wit
 
 def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
-    skip_if_client_doesnt_support_base64_encoding(client_with_models, embedding_model_id)
 
     input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
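
With the skip helper removed, the base64 tests above now run against every provider, including remote::ollama. The assertion they ultimately depend on boils down to decoding the payload and checking its dimensionality; a hypothetical sketch (the helper name below is not from the test suite):

    import base64
    import struct

    def assert_base64_embedding(embedding_b64: str, expected_dim: int) -> None:
        # Decode the base64 payload and verify it unpacks into the expected number of float32 values.
        raw = base64.b64decode(embedding_b64)
        assert len(raw) % 4 == 0
        floats = struct.unpack(f"{len(raw) // 4}f", raw)
        assert len(floats) == expected_dim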