feat: Add /v1/embeddings endpoint to batches API

This PR extends the Llama Stack Batches API to support the /v1/embeddings endpoint, enabling efficient batch processing of embedding requests alongside the existing /v1/chat/completions and /v1/completions support.

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-09-08 16:55:17 -07:00
parent aab22dc759
commit 531b1451dc
3 changed files with 122 additions and 6 deletions

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection # TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions"]: if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError( raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint", f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
) )
if completion_window != "24h": if completion_window != "24h":
@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
valid = False valid = False
if batch.endpoint == "/v1/chat/completions": if batch.endpoint == "/v1/chat/completions":
required_params = [ required_params: list[tuple[str, Any, str]] = [
("model", str, "a string"), ("model", str, "a string"),
# messages is specific to /v1/chat/completions # messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however, # we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong. # that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string? ("messages", list, "an array"), # TODO: allow messages to be a string?
] ]
else: # /v1/completions elif batch.endpoint == "/v1/completions":
required_params = [ required_params = [
("model", str, "a string"), ("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings?? ("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
] ]
else: # /v1/embeddings
required_params = [
("model", str, "a string"),
("input", (str, list), "a string or array of strings"),
]
for param, expected_type, type_string in required_params: for param, expected_type, type_string in required_params:
if param not in body: if param not in body:
@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(), "body": chat_response.model_dump_json(),
}, },
} }
else: # /v1/completions elif request.url == "/v1/completions":
completion_response = await self.inference_api.openai_completion(**request.body) completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type # this is for mypy, we don't allow streaming so we'll get the right type
@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
"body": completion_response.model_dump_json(), "body": completion_response.model_dump_json(),
}, },
} }
else: # /v1/embeddings
embeddings_response = await self.inference_api.openai_embeddings(**request.body)
assert hasattr(embeddings_response, "model_dump_json"), (
"Embeddings response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": embeddings_response.model_dump_json(),
},
}
except Exception as e: except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return { return {

View file

@ -323,3 +323,92 @@ class TestBatchesIntegration:
if final_batch.error_file_id is not None: if final_batch.error_file_id is not None:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id) deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted assert deleted_error_file.deleted
def test_batch_e2e_embeddings(self, openai_client, batch_helper, embedding_model_id):
"""Run an end-to-end batch with embeddings requests including both string and list inputs."""
batch_requests = [
{
"custom_id": "success-1",
"method": "POST",
"url": "/v1/embeddings",
"body": {"model": embedding_model_id, "input": "Hello world", "encoding_format": "float"},
},
{
"custom_id": "success-2",
"method": "POST",
"url": "/v1/embeddings",
"body": {
"model": embedding_model_id,
"input": ["How are you?", "Good morning", "Have a great day"],
"encoding_format": "float",
},
},
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/embeddings",
completion_window="24h",
metadata={"test": "e2e_embeddings_success"},
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60,
expected_statuses={"completed"},
timeout_action="skip",
)
assert final_batch.status == "completed"
assert final_batch.request_counts is not None
assert final_batch.request_counts.total == 2
assert final_batch.request_counts.completed == 2
assert final_batch.output_file_id is not None
output_content = openai_client.files.content(final_batch.output_file_id)
if isinstance(output_content, str):
output_text = output_content
else:
output_text = output_content.content.decode("utf-8")
output_lines = output_text.strip().split("\n")
assert len(output_lines) == 2
# Check first result (string input)
result1 = json.loads(output_lines[0])
assert result1["custom_id"] in ["success-1", "success-2"]
assert "response" in result1
assert result1["response"]["status_code"] == 200
# Verify the response body contains embeddings data
response_body1 = json.loads(result1["response"]["body"])
assert response_body1["object"] == "list"
assert "data" in response_body1
assert len(response_body1["data"]) == 1
assert "embedding" in response_body1["data"][0]
assert "index" in response_body1["data"][0]
assert response_body1["data"][0]["index"] == 0
# Check second result (list input)
result2 = json.loads(output_lines[1])
assert result2["custom_id"] in ["success-1", "success-2"]
assert "response" in result2
assert result2["response"]["status_code"] == 200
# Verify the response body contains embeddings data for list input
response_body2 = json.loads(result2["response"]["body"])
assert response_body2["object"] == "list"
assert "data" in response_body2
assert len(response_body2["data"]) == 3 # Three strings in the list
for i, embedding_data in enumerate(response_body2["data"]):
assert "embedding" in embedding_data
assert "index" in embedding_data
assert embedding_data["index"] == i
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_output_file.deleted
if final_batch.error_file_id is not None:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted

View file

@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"endpoint", "endpoint",
[ [
"/v1/embeddings",
"/v1/invalid/endpoint", "/v1/invalid/endpoint",
"", "",
], ],
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
await asyncio.sleep(0.042) # let tasks start await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}" assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
async def test_create_batch_embeddings_endpoint(self, provider):
"""Test that batch creation succeeds with embeddings endpoint."""
batch = await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/embeddings",
completion_window="24h",
)
assert batch.endpoint == "/v1/embeddings"