mirror of https://github.com/meta-llama/llama-stack.git

commit e0f1788e9e (parent 89d4a05303)

refactoring some code into openai_compat

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

4 changed files with 95 additions and 40 deletions
llama_stack/providers/remote/inference/ollama/ollama.py

@@ -45,7 +45,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -64,12 +63,15 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     OpenAICompatCompletionChoice,
     OpenAICompatCompletionResponse,
+    b64_encode_openai_embeddings_response,
     get_sampling_options,
     prepare_openai_completion_params,
+    prepare_openai_embeddings_params,
     process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
+    process_embedding_b64_encoded_input,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
@@ -392,28 +394,22 @@ class OllamaInferenceAdapter(
         if model_obj.model_type != ModelType.embedding:
             raise ValueError(f"Model {model} is not an embedding model")
 
-        params: dict[str, Any] = {
-            "model": model_obj.provider_resource_id,
-            "input": input,
-        }
-
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {model} has no provider_resource_id set")
+
+        params = prepare_openai_embeddings_params(
+            model=model_obj.provider_resource_id,
+            input=input,
+            encoding_format=encoding_format,
+            dimensions=dimensions,
+            user=user,
+        )
         # Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
-        if encoding_format is not None:
-            params["encoding_format"] = encoding_format
-        if dimensions is not None:
-            params["dimensions"] = dimensions
-        if user is not None:
-            params["user"] = user
+        # but we implement the encoding here
+        params = process_embedding_b64_encoded_input(params)
 
         response = await self.openai_client.embeddings.create(**params)
-        data = []
-        for i, embedding_data in enumerate(response.data):
-            data.append(
-                OpenAIEmbeddingData(
-                    embedding=embedding_data.embedding,
-                    index=i,
-                )
-            )
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response.usage.prompt_tokens,
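For context, here is a minimal sketch of what the refactored Ollama request path now does for base64 requests, with hypothetical model id and inputs, assuming the helpers are importable from llama_stack.providers.utils.inference.openai_compat:

from llama_stack.providers.utils.inference.openai_compat import (
    prepare_openai_embeddings_params,
    process_embedding_b64_encoded_input,
)

# build the request params the same way the adapter now does
params = prepare_openai_embeddings_params(
    model="all-minilm:l6-v2",  # hypothetical embedding model id
    input=["hello", "world"],
    encoding_format="base64",
)
# Ollama's OpenAI-compatible endpoint does not handle encoding_format itself,
# so the helper base64-encodes the input strings client-side:
params = process_embedding_b64_encoded_input(params)
print(params["input"])  # ['aGVsbG8=', 'd29ybGQ=']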
llama_stack/providers/utils/inference/litellm_openai_mixin.py

@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import base64
-import struct
 from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
 
@@ -37,7 +35,6 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
     OpenAICompletion,
-    OpenAIEmbeddingData,
     OpenAIEmbeddingsResponse,
     OpenAIEmbeddingUsage,
     OpenAIMessageParam,
@@ -48,6 +45,7 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
+    b64_encode_openai_embeddings_response,
     convert_message_to_openai_dict_new,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
@@ -293,16 +291,7 @@ class LiteLLMOpenAIMixin(
         )
 
         # Convert response to OpenAI format
-        data = []
-        for i, embedding_data in enumerate(response["data"]):
-            # we encode to base64 if the encoding format is base64 in the request
-            if encoding_format == "base64":
-                byte_data = b"".join(struct.pack("f", f) for f in embedding_data["embedding"])
-                embedding = base64.b64encode(byte_data).decode("utf-8")
-            else:
-                embedding = embedding_data["embedding"]
-
-            data.append(OpenAIEmbeddingData(embedding=embedding, index=i))
+        data = b64_encode_openai_embeddings_response(response.data, encoding_format)
 
         usage = OpenAIEmbeddingUsage(
             prompt_tokens=response["usage"]["prompt_tokens"],
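The inline loop deleted here is the same float32-pack-then-base64 encoding that now lives in b64_encode_openai_embeddings_response. A stdlib-only round trip shows the invariant the encoding relies on (values survive up to float32 precision); this is an illustrative check, not code from the commit:

import base64
import struct

embedding = [0.1, 0.2, 0.3]
byte_data = b"".join(struct.pack("f", f) for f in embedding)
encoded = base64.b64encode(byte_data).decode("utf-8")

# decoding recovers the values up to float32 precision
decoded = struct.unpack(f"{len(embedding)}f", base64.b64decode(encoded))
assert all(abs(a - b) < 1e-6 for a, b in zip(embedding, decoded))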
llama_stack/providers/utils/inference/openai_compat.py

@@ -3,8 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import base64
 import json
 import logging
+import struct
 import time
 import uuid
 import warnings
@@ -108,6 +110,7 @@ from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
     OpenAICompletion,
     OpenAICompletionChoice,
+    OpenAIEmbeddingData,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
     ToolConfig,
@@ -1483,3 +1486,73 @@ class OpenAIChatCompletionToLlamaStackMixin:
             model=model,
             object="chat.completion",
         )
+
+
+def prepare_openai_embeddings_params(
+    model: str,
+    input: str | list[str],
+    encoding_format: str | None = "float",
+    dimensions: int | None = None,
+    user: str | None = None,
+):
+    if model is None:
+        raise ValueError("Model must be provided for embeddings")
+
+    input_list = [input] if isinstance(input, str) else input
+
+    params: dict[str, Any] = {
+        "model": model,
+        "input": input_list,
+    }
+
+    if encoding_format is not None:
+        params["encoding_format"] = encoding_format
+    if dimensions is not None:
+        params["dimensions"] = dimensions
+    if user is not None:
+        params["user"] = user
+
+    return params
+
+
+def process_embedding_b64_encoded_input(params: dict[str, Any]) -> dict[str, Any]:
+    """
+    Process the embeddings parameters to encode the input in base64 format if specified.
+    Currently implemented for ollama as base64 is not yet supported by their compatible API.
+    """
+    if params.get("encoding_format") == "base64":
+        processed_params = params.copy()
+        input = params.get("input")
+        if isinstance(input, str):
+            processed_params["input"] = base64.b64encode(input.encode()).decode()
+        elif isinstance(input, list):
+            processed_params["input"] = [base64.b64encode(i.encode()).decode() for i in input]
+    else:
+        return params
+
+    return processed_params
+
+
+def b64_encode_openai_embeddings_response(
+    response_data: dict, encoding_format: str | None = "float"
+) -> list[OpenAIEmbeddingData]:
+    """
+    Process the OpenAI embeddings response to encode the embeddings in base64 format if specified.
+    """
+    data = []
+    for i, embedding_data in enumerate(response_data):
+        if encoding_format == "base64":
+            byte_array = bytearray()
+            for embedding_value in embedding_data.embedding:
+                byte_array.extend(struct.pack("f", float(embedding_value)))
+
+            response_embedding = base64.b64encode(byte_array).decode("utf-8")
+        else:
+            response_embedding = embedding_data.embedding
+        data.append(
+            OpenAIEmbeddingData(
+                embedding=response_embedding,
+                index=i,
+            )
+        )
+    return data
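A small usage sketch of the response-side helper, assuming llama_stack is installed. EmbeddingRow is a stand-in for the OpenAI SDK's embedding rows (anything with an .embedding attribute works); it is illustrative, not a llama-stack type:

import base64
import struct
from dataclasses import dataclass

from llama_stack.providers.utils.inference.openai_compat import (
    b64_encode_openai_embeddings_response,
)

# stand-in for the OpenAI SDK's embedding rows
@dataclass
class EmbeddingRow:
    embedding: list[float]

rows = [EmbeddingRow([0.1, 0.2]), EmbeddingRow([0.3, 0.4])]

# with encoding_format="base64", each vector is packed as float32 bytes, then base64'd
data = b64_encode_openai_embeddings_response(rows, encoding_format="base64")
decoded = struct.unpack("2f", base64.b64decode(data[0].embedding))
assert all(abs(a - b) < 1e-6 for a, b in zip([0.1, 0.2], decoded))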
tests/integration/inference/test_openai_embeddings.py

@@ -34,7 +34,11 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id):
     pytest.skip("{model_id} does not support variable output embedding dimensions")
 
 
-@pytest.fixture(params=["openai_client", "llama_stack_client"])
+@pytest.fixture(
+    params=[
+        "openai_client",
+    ]
+)
 def compat_client(request, client_with_models):
     if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient):
         pytest.skip("OpenAI client tests not supported with library client")
@@ -55,12 +59,6 @@ def skip_if_model_doesnt_support_openai_embeddings(client, model_id):
     pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
 
 
-def skip_if_client_doesnt_support_base64_encoding(client, model_id):
-    provider = provider_from_model(client, model_id)
-    if provider.provider_type in ("remote::ollama",):
-        pytest.skip(f"Client {client} doesn't support base64 encoding for embeddings.")
-
-
 @pytest.fixture
 def openai_client(client_with_models):
     base_url = f"{client_with_models.base_url}/v1/openai/v1"
@@ -253,7 +251,6 @@ def test_openai_embeddings_with_encoding_format_base64(compat_client, client_with_models, embedding_model_id):
 def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id):
     """Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
     skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
-    skip_if_client_doesnt_support_base64_encoding(client_with_models, embedding_model_id)
 
     input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
 
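With the Ollama-specific skip removed, base64 responses should decode the same way for every provider. A sketch of the decode step such a test can use; the helper name is ours for illustration, not from the test file:

import base64
import struct

def decode_base64_embedding(b64_string: str) -> list[float]:
    # each float32 is 4 bytes, so the vector length is len(raw) // 4
    raw = base64.b64decode(b64_string)
    return list(struct.unpack(f"{len(raw) // 4}f", raw))

# e.g. decode_base64_embedding(response.data[0].embedding)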