mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-03 19:57:35 +00:00
Merge 531b1451dc
into d266c59c2a
This commit is contained in:
commit
dfe5c0cd75
3 changed files with 122 additions and 6 deletions
|
@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
|
|||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
|
||||
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
|
@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
|
|||
valid = False
|
||||
|
||||
if batch.endpoint == "/v1/chat/completions":
|
||||
required_params = [
|
||||
required_params: list[tuple[str, Any, str]] = [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]
|
||||
else: # /v1/completions
|
||||
elif batch.endpoint == "/v1/completions":
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||
]
|
||||
else: # /v1/embeddings
|
||||
required_params = [
|
||||
("model", str, "a string"),
|
||||
("input", (str, list), "a string or array of strings"),
|
||||
]
|
||||
|
||||
for param, expected_type, type_string in required_params:
|
||||
if param not in body:
|
||||
|
@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/completions
|
||||
elif request.url == "/v1/completions":
|
||||
completion_response = await self.inference_api.openai_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
|
@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
|
|||
"body": completion_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
else: # /v1/embeddings
|
||||
embeddings_response = await self.inference_api.openai_embeddings(**request.body)
|
||||
assert hasattr(embeddings_response, "model_dump_json"), (
|
||||
"Embeddings response must have model_dump_json method"
|
||||
)
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": embeddings_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
|
|
|
@ -323,3 +323,92 @@ class TestBatchesIntegration:
|
|||
if final_batch.error_file_id is not None:
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted
|
||||
|
||||
def test_batch_e2e_embeddings(self, openai_client, batch_helper, embedding_model_id):
|
||||
"""Run an end-to-end batch with embeddings requests including both string and list inputs."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "success-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/embeddings",
|
||||
"body": {"model": embedding_model_id, "input": "Hello world", "encoding_format": "float"},
|
||||
},
|
||||
{
|
||||
"custom_id": "success-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/embeddings",
|
||||
"body": {
|
||||
"model": embedding_model_id,
|
||||
"input": ["How are you?", "Good morning", "Have a great day"],
|
||||
"encoding_format": "float",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
metadata={"test": "e2e_embeddings_success"},
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60,
|
||||
expected_statuses={"completed"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
assert final_batch.status == "completed"
|
||||
assert final_batch.request_counts is not None
|
||||
assert final_batch.request_counts.total == 2
|
||||
assert final_batch.request_counts.completed == 2
|
||||
assert final_batch.output_file_id is not None
|
||||
|
||||
output_content = openai_client.files.content(final_batch.output_file_id)
|
||||
if isinstance(output_content, str):
|
||||
output_text = output_content
|
||||
else:
|
||||
output_text = output_content.content.decode("utf-8")
|
||||
|
||||
output_lines = output_text.strip().split("\n")
|
||||
assert len(output_lines) == 2
|
||||
|
||||
# Check first result (string input)
|
||||
result1 = json.loads(output_lines[0])
|
||||
assert result1["custom_id"] in ["success-1", "success-2"]
|
||||
assert "response" in result1
|
||||
assert result1["response"]["status_code"] == 200
|
||||
|
||||
# Verify the response body contains embeddings data
|
||||
response_body1 = json.loads(result1["response"]["body"])
|
||||
assert response_body1["object"] == "list"
|
||||
assert "data" in response_body1
|
||||
assert len(response_body1["data"]) == 1
|
||||
assert "embedding" in response_body1["data"][0]
|
||||
assert "index" in response_body1["data"][0]
|
||||
assert response_body1["data"][0]["index"] == 0
|
||||
|
||||
# Check second result (list input)
|
||||
result2 = json.loads(output_lines[1])
|
||||
assert result2["custom_id"] in ["success-1", "success-2"]
|
||||
assert "response" in result2
|
||||
assert result2["response"]["status_code"] == 200
|
||||
|
||||
# Verify the response body contains embeddings data for list input
|
||||
response_body2 = json.loads(result2["response"]["body"])
|
||||
assert response_body2["object"] == "list"
|
||||
assert "data" in response_body2
|
||||
assert len(response_body2["data"]) == 3 # Three strings in the list
|
||||
for i, embedding_data in enumerate(response_body2["data"]):
|
||||
assert "embedding" in embedding_data
|
||||
assert "index" in embedding_data
|
||||
assert embedding_data["index"] == i
|
||||
|
||||
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
|
||||
assert deleted_output_file.deleted
|
||||
|
||||
if final_batch.error_file_id is not None:
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted
|
||||
|
|
|
@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
|
|||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
|
|||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
||||
|
||||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue