mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: Add /v1/embeddings endpoint to batches API
This PR extends the Llama Stack Batches API to support the /v1/embeddings endpoint, enabling efficient batch processing of embedding requests alongside the existing /v1/chat/completions and /v1/completions support. Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
aab22dc759
commit
531b1451dc
3 changed files with 122 additions and 6 deletions
|
@ -323,3 +323,92 @@ class TestBatchesIntegration:
|
|||
if final_batch.error_file_id is not None:
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted
|
||||
|
||||
def test_batch_e2e_embeddings(self, openai_client, batch_helper, embedding_model_id):
    """Run an end-to-end batch with embeddings requests including both string and list inputs.

    Two requests are submitted against ``/v1/embeddings``: one with a single
    string input and one with a list of three strings. After the batch
    completes, the output file is parsed and every result is matched to its
    request by ``custom_id`` rather than by line position, because the order
    of lines in a batch output file is not guaranteed to follow the order of
    the input file.

    Args:
        openai_client: Client fixture exposing ``batches`` and ``files``.
        batch_helper: Fixture providing ``create_file`` (context manager
            yielding the uploaded input file) and ``wait_for``.
        embedding_model_id: Identifier of the embedding model to request.
    """
    # custom_id -> number of embedding entries the response must contain
    # (one per input string in the corresponding request).
    expected_embedding_counts = {"success-1": 1, "success-2": 3}

    batch_requests = [
        {
            "custom_id": "success-1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {"model": embedding_model_id, "input": "Hello world", "encoding_format": "float"},
        },
        {
            "custom_id": "success-2",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {
                "model": embedding_model_id,
                "input": ["How are you?", "Good morning", "Have a great day"],
                "encoding_format": "float",
            },
        },
    ]

    with batch_helper.create_file(batch_requests) as uploaded_file:
        batch = openai_client.batches.create(
            input_file_id=uploaded_file.id,
            endpoint="/v1/embeddings",
            completion_window="24h",
            metadata={"test": "e2e_embeddings_success"},
        )

        # NOTE(review): timeout_action="skip" presumably skips the test when the
        # batch does not reach an expected status in time -- confirm in batch_helper.
        final_batch = batch_helper.wait_for(
            batch.id,
            max_wait_time=3 * 60,
            expected_statuses={"completed"},
            timeout_action="skip",
        )

        assert final_batch.status == "completed"
        assert final_batch.request_counts is not None
        assert final_batch.request_counts.total == 2
        assert final_batch.request_counts.completed == 2
        assert final_batch.output_file_id is not None

        # files.content() may hand back either text or a response wrapper with raw bytes.
        output_content = openai_client.files.content(final_batch.output_file_id)
        if isinstance(output_content, str):
            output_text = output_content
        else:
            output_text = output_content.content.decode("utf-8")

        output_lines = output_text.strip().split("\n")
        assert len(output_lines) == 2

        # Index results by custom_id so the checks do not depend on output line order.
        results_by_id = {}
        for line in output_lines:
            result = json.loads(line)
            results_by_id[result["custom_id"]] = result

        # Both requests must be present exactly once (two lines, two distinct ids).
        assert set(results_by_id) == set(expected_embedding_counts)

        for custom_id, expected_count in expected_embedding_counts.items():
            result = results_by_id[custom_id]
            assert "response" in result
            assert result["response"]["status_code"] == 200

            # Verify the response body contains one embedding per input string,
            # indexed in order.
            response_body = json.loads(result["response"]["body"])
            assert response_body["object"] == "list"
            assert "data" in response_body
            assert len(response_body["data"]) == expected_count
            for i, embedding_data in enumerate(response_body["data"]):
                assert "embedding" in embedding_data
                assert "index" in embedding_data
                assert embedding_data["index"] == i

        # Clean up the files the batch produced.
        deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
        assert deleted_output_file.deleted

        if final_batch.error_file_id is not None:
            deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
            assert deleted_error_file.deleted
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue