forked from phoenix-oss/llama-stack-mirror
feat: New OpenAI compat embeddings API (#2314)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 4s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Integration Tests / test-matrix (library, providers) (push) Failing after 7s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 7s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 9s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Test Llama Stack Build / build (push) Failing after 5s
Unit Tests / unit-tests (3.10) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m11s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 4s
Integration Tests / test-matrix (http, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, datasets) (push) Failing after 10s
Integration Tests / test-matrix (http, post_training) (push) Failing after 9s
Integration Tests / test-matrix (library, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, tool_runtime) (push) Failing after 8s
Integration Tests / test-matrix (http, providers) (push) Failing after 9s
Integration Tests / test-matrix (library, datasets) (push) Failing after 8s
Integration Tests / test-matrix (library, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, scoring) (push) Failing after 10s
Test Llama Stack Build / generate-matrix (push) Successful in 6s
Integration Tests / test-matrix (library, providers) (push) Failing after 7s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 6s
Integration Tests / test-matrix (library, inspect) (push) Failing after 9s
Test Llama Stack Build / build-single-provider (push) Failing after 7s
Integration Tests / test-matrix (library, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, post_training) (push) Failing after 9s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 7s
Integration Tests / test-matrix (library, tool_runtime) (push) Failing after 10s
Unit Tests / unit-tests (3.11) (push) Failing after 7s
Test Llama Stack Build / build (push) Failing after 5s
Unit Tests / unit-tests (3.10) (push) Failing after 7s
Update ReadTheDocs / update-readthedocs (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Unit Tests / unit-tests (3.13) (push) Failing after 7s
Test External Providers / test-external-providers (venv) (push) Failing after 26s
Pre-commit / pre-commit (push) Successful in 1m11s
# What does this PR do? Adds a new endpoint that is compatible with OpenAI for embeddings api. `/openai/v1/embeddings` Added providers for OpenAI, LiteLLM and SentenceTransformer. ## Test Plan ``` LLAMA_STACK_CONFIG=http://localhost:8321 pytest -sv tests/integration/inference/test_openai_embeddings.py --embedding-model all-MiniLM-L6-v2,text-embedding-3-small,gemini/text-embedding-004 ```
This commit is contained in:
parent
277f8690ef
commit
b21050935e
21 changed files with 981 additions and 0 deletions
275
tests/integration/inference/test_openai_embeddings.py
Normal file
275
tests/integration/inference/test_openai_embeddings.py
Normal file
|
@ -0,0 +1,275 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import struct
|
||||
|
||||
import pytest
|
||||
from openai import OpenAI
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
def decode_base64_to_floats(base64_string: str) -> list[float]:
|
||||
"""Helper function to decode base64 string to list of float32 values."""
|
||||
embedding_bytes = base64.b64decode(base64_string)
|
||||
float_count = len(embedding_bytes) // 4 # 4 bytes per float32
|
||||
embedding_floats = struct.unpack(f"{float_count}f", embedding_bytes)
|
||||
return list(embedding_floats)
|
||||
|
||||
|
||||
def provider_from_model(client_with_models, model_id):
|
||||
models = {m.identifier: m for m in client_with_models.models.list()}
|
||||
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
|
||||
provider_id = models[model_id].provider_id
|
||||
providers = {p.provider_id: p for p in client_with_models.providers.list()}
|
||||
return providers[provider_id]
|
||||
|
||||
|
||||
def skip_if_model_doesnt_support_variable_dimensions(model_id):
|
||||
if "text-embedding-3" not in model_id:
|
||||
pytest.skip("{model_id} does not support variable output embedding dimensions")
|
||||
|
||||
|
||||
def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id):
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI embeddings are not supported when testing with library client yet.")
|
||||
|
||||
provider = provider_from_model(client_with_models, model_id)
|
||||
if provider.provider_type in (
|
||||
"inline::meta-reference",
|
||||
"remote::bedrock",
|
||||
"remote::cerebras",
|
||||
"remote::databricks",
|
||||
"remote::runpod",
|
||||
"remote::sambanova",
|
||||
"remote::tgi",
|
||||
"remote::ollama",
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI embeddings.")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="fake")
|
||||
|
||||
|
||||
def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with a single string input."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_text = "Hello, world!"
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.model == embedding_model_id
|
||||
assert len(response.data) == 1
|
||||
assert response.data[0].object == "embedding"
|
||||
assert response.data[0].index == 0
|
||||
assert isinstance(response.data[0].embedding, list)
|
||||
assert len(response.data[0].embedding) > 0
|
||||
assert all(isinstance(x, float) for x in response.data[0].embedding)
|
||||
|
||||
|
||||
def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with multiple string inputs."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_texts = ["Hello, world!", "How are you today?", "This is a test."]
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_texts,
|
||||
)
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.model == embedding_model_id
|
||||
assert len(response.data) == len(input_texts)
|
||||
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
assert embedding_data.object == "embedding"
|
||||
assert embedding_data.index == i
|
||||
assert isinstance(embedding_data.embedding, list)
|
||||
assert len(embedding_data.embedding) > 0
|
||||
assert all(isinstance(x, float) for x in embedding_data.embedding)
|
||||
|
||||
|
||||
def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with float encoding format."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_text = "Test encoding format"
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
assert response.object == "list"
|
||||
assert len(response.data) == 1
|
||||
assert isinstance(response.data[0].embedding, list)
|
||||
assert all(isinstance(x, float) for x in response.data[0].embedding)
|
||||
|
||||
|
||||
def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with custom dimensions parameter."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
|
||||
|
||||
input_text = "Test dimensions parameter"
|
||||
dimensions = 16
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text,
|
||||
dimensions=dimensions,
|
||||
)
|
||||
|
||||
assert response.object == "list"
|
||||
assert len(response.data) == 1
|
||||
# Note: Not all models support custom dimensions, so we don't assert the exact dimension
|
||||
assert isinstance(response.data[0].embedding, list)
|
||||
assert len(response.data[0].embedding) > 0
|
||||
|
||||
|
||||
def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with user parameter."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_text = "Test user parameter"
|
||||
user_id = "test-user-123"
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text,
|
||||
user=user_id,
|
||||
)
|
||||
|
||||
assert response.object == "list"
|
||||
assert len(response.data) == 1
|
||||
assert isinstance(response.data[0].embedding, list)
|
||||
assert len(response.data[0].embedding) > 0
|
||||
|
||||
|
||||
def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test that empty list input raises an appropriate error."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=[],
|
||||
)
|
||||
|
||||
|
||||
def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test that invalid model ID raises an appropriate error."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
openai_client.embeddings.create(
|
||||
model="invalid-model-id",
|
||||
input="Test text",
|
||||
)
|
||||
|
||||
|
||||
def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test that different inputs produce different embeddings."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_text1 = "This is the first text"
|
||||
input_text2 = "This is completely different content"
|
||||
|
||||
response1 = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text1,
|
||||
)
|
||||
|
||||
response2 = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text2,
|
||||
)
|
||||
|
||||
embedding1 = response1.data[0].embedding
|
||||
embedding2 = response2.data[0].embedding
|
||||
|
||||
assert len(embedding1) == len(embedding2)
|
||||
# Embeddings should be different for different inputs
|
||||
assert embedding1 != embedding2
|
||||
|
||||
|
||||
def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with base64 encoding format."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
skip_if_model_doesnt_support_variable_dimensions(embedding_model_id)
|
||||
|
||||
input_text = "Test base64 encoding format"
|
||||
dimensions = 12
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_text,
|
||||
encoding_format="base64",
|
||||
dimensions=dimensions,
|
||||
)
|
||||
|
||||
# Validate response structure
|
||||
assert response.object == "list"
|
||||
assert len(response.data) == 1
|
||||
|
||||
# With base64 encoding, embedding should be a string, not a list
|
||||
embedding_data = response.data[0]
|
||||
assert embedding_data.object == "embedding"
|
||||
assert embedding_data.index == 0
|
||||
assert isinstance(embedding_data.embedding, str)
|
||||
|
||||
# Verify it's valid base64 and decode to floats
|
||||
embedding_floats = decode_base64_to_floats(embedding_data.embedding)
|
||||
|
||||
# Verify we got valid floats
|
||||
assert len(embedding_floats) == dimensions, f"Got embedding length {len(embedding_floats)}, expected {dimensions}"
|
||||
assert all(isinstance(x, float) for x in embedding_floats)
|
||||
|
||||
|
||||
def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id):
|
||||
"""Test OpenAI embeddings endpoint with base64 encoding for batch processing."""
|
||||
skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id)
|
||||
|
||||
input_texts = ["First text for base64", "Second text for base64", "Third text for base64"]
|
||||
|
||||
response = openai_client.embeddings.create(
|
||||
model=embedding_model_id,
|
||||
input=input_texts,
|
||||
encoding_format="base64",
|
||||
)
|
||||
|
||||
# Validate response structure
|
||||
assert response.object == "list"
|
||||
assert response.model == embedding_model_id
|
||||
assert len(response.data) == len(input_texts)
|
||||
|
||||
# Validate each embedding in the batch
|
||||
embedding_dimensions = []
|
||||
for i, embedding_data in enumerate(response.data):
|
||||
assert embedding_data.object == "embedding"
|
||||
assert embedding_data.index == i
|
||||
|
||||
# With base64 encoding, embedding should be a string, not a list
|
||||
assert isinstance(embedding_data.embedding, str)
|
||||
embedding_floats = decode_base64_to_floats(embedding_data.embedding)
|
||||
assert len(embedding_floats) > 0
|
||||
assert all(isinstance(x, float) for x in embedding_floats)
|
||||
embedding_dimensions.append(len(embedding_floats))
|
||||
|
||||
# All embeddings should have the same dimensionality
|
||||
assert all(dim == embedding_dimensions[0] for dim in embedding_dimensions)
|
Loading…
Add table
Add a link
Reference in a new issue