Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
feat: Add /v1/embeddings endpoint to batches API
This PR extends the Llama Stack Batches API to support the /v1/embeddings endpoint, enabling efficient batch processing of embedding requests alongside the existing /v1/chat/completions and /v1/completions support.

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
Parent: aab22dc759
Commit: 531b1451dc
3 changed files with 122 additions and 6 deletions
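For orientation, here is a minimal sketch of how a client might exercise the new endpoint through the OpenAI-compatible Batches API once this change lands. It is not code from this PR: the base URL, API key, and model id are illustrative placeholders.

# Sketch: submitting an embeddings batch via an OpenAI-compatible client pointed
# at a Llama Stack server. Base URL, api_key, and model name are illustrative.
import json
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# One JSONL line per request; "input" may be a string or a list of strings.
requests = [
    {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {"model": "all-MiniLM-L6-v2", "input": "Hello world", "encoding_format": "float"},
    },
]
with open("embeddings_batch.jsonl", "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

uploaded = client.files.create(file=open("embeddings_batch.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=uploaded.id,
    endpoint="/v1/embeddings",
    completion_window="24h",
)
print(batch.id, batch.status)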
@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
         # TODO: set expiration time for garbage collection

-        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
+        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
             )

         if completion_window != "24h":
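In practice the guard above means /v1/embeddings is now accepted at batch-creation time while anything else still fails fast with the ValueError shown in the diff. A rough illustration of that behavior, using the same create_batch signature as the unit test added at the end of this PR (the function and fixture names here are illustrative, not part of the change):

# Illustration only: accepted vs. rejected endpoints after this change.
import pytest

async def illustrate_endpoint_validation(provider):
    # Accepted after this change:
    await provider.create_batch(
        input_file_id="file_123", endpoint="/v1/embeddings", completion_window="24h"
    )
    # Still rejected, with the "Invalid endpoint" message from the diff above:
    with pytest.raises(ValueError, match="Invalid endpoint"):
        await provider.create_batch(
            input_file_id="file_123", endpoint="/v1/invalid/endpoint", completion_window="24h"
        )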
@@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
                 valid = False

             if batch.endpoint == "/v1/chat/completions":
-                required_params = [
+                required_params: list[tuple[str, Any, str]] = [
                     ("model", str, "a string"),
                     # messages is specific to /v1/chat/completions
                     # we could skip validating messages here and let inference fail. however,
                     # that would be a very expensive way to find out messages is wrong.
                     ("messages", list, "an array"),  # TODO: allow messages to be a string?
                 ]
-            else:  # /v1/completions
+            elif batch.endpoint == "/v1/completions":
                 required_params = [
                     ("model", str, "a string"),
                     ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
                 ]
+            else:  # /v1/embeddings
+                required_params = [
+                    ("model", str, "a string"),
+                    ("input", (str, list), "a string or array of strings"),
+                ]

             for param, expected_type, type_string in required_params:
                 if param not in body:
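The validation loop in the context lines above consumes each (param, expected_type, type_string) tuple, checking presence and type; ("input", (str, list), ...) works for both string and array inputs because isinstance accepts a tuple of types. A condensed sketch of that pattern follows; the standalone function and its error messages are illustrative, not the provider's exact code:

# Condensed sketch of the required-params check pattern used above.
from typing import Any

def check_required(body: dict[str, Any], required_params: list[tuple[str, Any, str]]) -> list[str]:
    errors: list[str] = []
    for param, expected_type, type_string in required_params:
        if param not in body:
            errors.append(f"missing required parameter: {param}")
        elif not isinstance(body[param], expected_type):
            # isinstance accepts a tuple of types, so ("input", (str, list), ...) allows either form.
            errors.append(f"{param} must be {type_string}")
    return errors

# A valid embeddings request body passes with either form of "input".
embeddings_params = [("model", str, "a string"), ("input", (str, list), "a string or array of strings")]
assert check_required({"model": "m", "input": "hi"}, embeddings_params) == []
assert check_required({"model": "m", "input": ["a", "b"]}, embeddings_params) == []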
@@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
                         "body": chat_response.model_dump_json(),
                     },
                 }
-            else:  # /v1/completions
+            elif request.url == "/v1/completions":
                 completion_response = await self.inference_api.openai_completion(**request.body)

                 # this is for mypy, we don't allow streaming so we'll get the right type
@@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
                         "body": completion_response.model_dump_json(),
                     },
                 }
+            else:  # /v1/embeddings
+                embeddings_response = await self.inference_api.openai_embeddings(**request.body)
+                assert hasattr(embeddings_response, "model_dump_json"), (
+                    "Embeddings response must have model_dump_json method"
+                )
+                return {
+                    "id": request_id,
+                    "custom_id": request.custom_id,
+                    "response": {
+                        "status_code": 200,
+                        "request_id": request_id,  # TODO: should this be different?
+                        "body": embeddings_response.model_dump_json(),
+                    },
+                }
         except Exception as e:
             logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
             return {
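Each per-request dict returned above ends up as one JSONL line in the batch output file, with "body" holding the embeddings response serialized as a JSON string (the same layout the integration test below asserts on). A small sketch of reading one such line back; the concrete values in the example line are illustrative:

# Sketch: decoding one line of the batch output file produced by the branch above.
import json

output_line = (
    '{"id": "req_1", "custom_id": "success-1", "response": {"status_code": 200, '
    '"request_id": "req_1", "body": "{\\"object\\": \\"list\\", \\"data\\": '
    '[{\\"object\\": \\"embedding\\", \\"index\\": 0, \\"embedding\\": [0.1, 0.2]}]}"}}'
)

result = json.loads(output_line)
body = json.loads(result["response"]["body"])  # "body" is itself a JSON string
assert result["response"]["status_code"] == 200
assert body["object"] == "list"
vector = body["data"][0]["embedding"]  # the embedding vector for this request's input
print(len(vector))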
@@ -323,3 +323,92 @@ class TestBatchesIntegration:
         if final_batch.error_file_id is not None:
             deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
             assert deleted_error_file.deleted
+
+    def test_batch_e2e_embeddings(self, openai_client, batch_helper, embedding_model_id):
+        """Run an end-to-end batch with embeddings requests including both string and list inputs."""
+        batch_requests = [
+            {
+                "custom_id": "success-1",
+                "method": "POST",
+                "url": "/v1/embeddings",
+                "body": {"model": embedding_model_id, "input": "Hello world", "encoding_format": "float"},
+            },
+            {
+                "custom_id": "success-2",
+                "method": "POST",
+                "url": "/v1/embeddings",
+                "body": {
+                    "model": embedding_model_id,
+                    "input": ["How are you?", "Good morning", "Have a great day"],
+                    "encoding_format": "float",
+                },
+            },
+        ]
+
+        with batch_helper.create_file(batch_requests) as uploaded_file:
+            batch = openai_client.batches.create(
+                input_file_id=uploaded_file.id,
+                endpoint="/v1/embeddings",
+                completion_window="24h",
+                metadata={"test": "e2e_embeddings_success"},
+            )
+
+            final_batch = batch_helper.wait_for(
+                batch.id,
+                max_wait_time=3 * 60,
+                expected_statuses={"completed"},
+                timeout_action="skip",
+            )
+
+        assert final_batch.status == "completed"
+        assert final_batch.request_counts is not None
+        assert final_batch.request_counts.total == 2
+        assert final_batch.request_counts.completed == 2
+        assert final_batch.output_file_id is not None
+
+        output_content = openai_client.files.content(final_batch.output_file_id)
+        if isinstance(output_content, str):
+            output_text = output_content
+        else:
+            output_text = output_content.content.decode("utf-8")
+
+        output_lines = output_text.strip().split("\n")
+        assert len(output_lines) == 2
+
+        # Check first result (string input)
+        result1 = json.loads(output_lines[0])
+        assert result1["custom_id"] in ["success-1", "success-2"]
+        assert "response" in result1
+        assert result1["response"]["status_code"] == 200
+
+        # Verify the response body contains embeddings data
+        response_body1 = json.loads(result1["response"]["body"])
+        assert response_body1["object"] == "list"
+        assert "data" in response_body1
+        assert len(response_body1["data"]) == 1
+        assert "embedding" in response_body1["data"][0]
+        assert "index" in response_body1["data"][0]
+        assert response_body1["data"][0]["index"] == 0
+
+        # Check second result (list input)
+        result2 = json.loads(output_lines[1])
+        assert result2["custom_id"] in ["success-1", "success-2"]
+        assert "response" in result2
+        assert result2["response"]["status_code"] == 200
+
+        # Verify the response body contains embeddings data for list input
+        response_body2 = json.loads(result2["response"]["body"])
+        assert response_body2["object"] == "list"
+        assert "data" in response_body2
+        assert len(response_body2["data"]) == 3  # Three strings in the list
+        for i, embedding_data in enumerate(response_body2["data"]):
+            assert "embedding" in embedding_data
+            assert "index" in embedding_data
+            assert embedding_data["index"] == i
+
+        deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
+        assert deleted_output_file.deleted
+
+        if final_batch.error_file_id is not None:
+            deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
+            assert deleted_error_file.deleted
@@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
     @pytest.mark.parametrize(
         "endpoint",
         [
-            "/v1/embeddings",
             "/v1/invalid/endpoint",
             "",
         ],
@@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
             await asyncio.sleep(0.042)  # let tasks start

         assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
+
+    async def test_create_batch_embeddings_endpoint(self, provider):
+        """Test that batch creation succeeds with embeddings endpoint."""
+        batch = await provider.create_batch(
+            input_file_id="file_123",
+            endpoint="/v1/embeddings",
+            completion_window="24h",
+        )
+        assert batch.endpoint == "/v1/embeddings"