Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
feat(batches, completions): add /v1/completions support to /v1/batches (#3309)
# What does this PR do?

Add support for /v1/completions to the /v1/batches API.

## Test Plan

ci
This commit is contained in:
parent e2fe39aee1
commit df1526991f

4 changed files with 205 additions and 26 deletions
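For context, batches follow the OpenAI Batch API shape: the input file is JSONL, one request per line, and each line's `url` must match the batch `endpoint`. Below is a minimal sketch of submitting a `/v1/completions` batch; the client setup, file name, and base_url are illustrative assumptions, while the request values mirror the integration test added later in this diff.

```python
import json

from openai import OpenAI

# Assumption: an OpenAI-compatible llama-stack server; base_url/api_key are placeholders.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# One request per line; "url" must match the batch endpoint.
line = {
    "custom_id": "success-1",
    "method": "POST",
    "url": "/v1/completions",
    "body": {"model": "llama3.2:3b-instruct-fp16", "prompt": "Say completions", "max_tokens": 20},
}
with open("batch_input.jsonl", "w") as f:
    f.write(json.dumps(line) + "\n")

uploaded = client.files.create(file=open("batch_input.jsonl", "rb"), purpose="batch")
batch = client.batches.create(
    input_file_id=uploaded.id,
    endpoint="/v1/completions",
    completion_window="24h",  # the only window the reference provider accepts
)
```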
@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):

         # TODO: set expiration time for garbage collection

-        if endpoint not in ["/v1/chat/completions"]:
+        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
             )

         if completion_window != "24h":
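With the endpoint list widened, anything else is still rejected with the `invalid_value` error above. A hedged negative check, reusing the client from the earlier sketch (exactly how the server surfaces the `ValueError` to the client, e.g. as a 400 response, is an assumption):

```python
# Hypothetical negative case: /v1/embeddings is not in the supported list.
try:
    client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/embeddings",
        completion_window="24h",
    )
except Exception as e:
    # Expected to mention: "Supported values: /v1/chat/completions, /v1/completions"
    print(e)
```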
@@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
                 )
                 valid = False

-            for param, expected_type, type_string in [
-                ("model", str, "a string"),
-                # messages is specific to /v1/chat/completions
-                # we could skip validating messages here and let inference fail. however,
-                # that would be a very expensive way to find out messages is wrong.
-                ("messages", list, "an array"),  # TODO: allow messages to be a string?
-            ]:
+            if batch.endpoint == "/v1/chat/completions":
+                required_params = [
+                    ("model", str, "a string"),
+                    # messages is specific to /v1/chat/completions
+                    # we could skip validating messages here and let inference fail. however,
+                    # that would be a very expensive way to find out messages is wrong.
+                    ("messages", list, "an array"),  # TODO: allow messages to be a string?
+                ]
+            else:  # /v1/completions
+                required_params = [
+                    ("model", str, "a string"),
+                    ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
+                ]
+
+            for param, expected_type, type_string in required_params:
                 if param not in body:
                     errors.append(
                         BatchError(
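The validation change boils down to choosing a `required_params` list per endpoint before a shared presence/type loop. Here is a self-contained sketch of that pattern; names are simplified, the real code lives in `ReferenceBatchesImpl._validate_input` and emits `BatchError` objects rather than strings, and the exact wording of the type errors below is illustrative (the "parameter is required" messages match the unit tests in this diff).

```python
def missing_param_errors(endpoint: str, body: dict) -> list[str]:
    """Sketch of the per-endpoint required-parameter check."""
    if endpoint == "/v1/chat/completions":
        required = [("model", str, "a string"), ("messages", list, "an array")]
    else:  # /v1/completions
        required = [("model", str, "a string"), ("prompt", str, "a string")]

    errors = []
    for param, expected_type, type_string in required:
        if param not in body:
            errors.append(f"{param.capitalize()} parameter is required")
        elif not isinstance(body[param], expected_type):
            # Type-error wording is a placeholder, not the provider's exact message.
            errors.append(f"{param.capitalize()} must be {type_string}")
    return errors


assert missing_param_errors("/v1/completions", {"model": "m"}) == ["Prompt parameter is required"]
assert missing_param_errors("/v1/chat/completions", {"model": "m", "messages": "hi"}) == ["Messages must be an array"]
```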
@@ -591,6 +599,7 @@ class ReferenceBatchesImpl(Batches):

         try:
             # TODO(SECURITY): review body for security issues
+            if request.url == "/v1/chat/completions":
                 request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
                 chat_response = await self.inference_api.openai_chat_completion(**request.body)
@@ -605,6 +614,22 @@ class ReferenceBatchesImpl(Batches):
                         "body": chat_response.model_dump_json(),
                     },
                 }
+            else:  # /v1/completions
+                completion_response = await self.inference_api.openai_completion(**request.body)
+
+                # this is for mypy, we don't allow streaming so we'll get the right type
+                assert hasattr(completion_response, "model_dump_json"), (
+                    "Completion response must have model_dump_json method"
+                )
+                return {
+                    "id": request_id,
+                    "custom_id": request.custom_id,
+                    "response": {
+                        "status_code": 200,
+                        "request_id": request_id,
+                        "body": completion_response.model_dump_json(),
+                    },
+                }
         except Exception as e:
             logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
             return {
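Either branch returns the same record shape, which the batch machinery writes to the output file as one JSON line per request; note that `body` is itself a JSON string produced by `model_dump_json()`. An illustrative record (ids made up):

```python
# Illustrative success record for a /v1/completions request; "id" and
# "request_id" carry the same generated request id, per the code above.
record = {
    "id": "batch_req_abc123",  # made-up id
    "custom_id": "success-1",
    "response": {
        "status_code": 200,
        "request_id": "batch_req_abc123",
        "body": '{"id": "cmpl-271", "object": "text_completion", ...}',  # serialized Completion, truncated here
    },
}
```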
@@ -268,3 +268,58 @@ class TestBatchesIntegration:

         deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
         assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
+
+    def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
+        """Run an end-to-end batch with a single successful text completion request."""
+        request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
+
+        batch_requests = [
+            {
+                "custom_id": "success-1",
+                "method": "POST",
+                "url": "/v1/completions",
+                "body": request_body,
+            }
+        ]
+
+        with batch_helper.create_file(batch_requests) as uploaded_file:
+            batch = openai_client.batches.create(
+                input_file_id=uploaded_file.id,
+                endpoint="/v1/completions",
+                completion_window="24h",
+                metadata={"test": "e2e_completions_success"},
+            )
+
+            final_batch = batch_helper.wait_for(
+                batch.id,
+                max_wait_time=3 * 60,
+                expected_statuses={"completed"},
+                timeout_action="skip",
+            )
+
+        assert final_batch.status == "completed"
+        assert final_batch.request_counts is not None
+        assert final_batch.request_counts.total == 1
+        assert final_batch.request_counts.completed == 1
+        assert final_batch.output_file_id is not None
+
+        output_content = openai_client.files.content(final_batch.output_file_id)
+        if isinstance(output_content, str):
+            output_text = output_content
+        else:
+            output_text = output_content.content.decode("utf-8")
+
+        output_lines = output_text.strip().split("\n")
+        assert len(output_lines) == 1
+
+        result = json.loads(output_lines[0])
+        assert result["custom_id"] == "success-1"
+        assert "response" in result
+        assert result["response"]["status_code"] == 200
+
+        deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
+        assert deleted_output_file.deleted
+
+        if final_batch.error_file_id is not None:
+            deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
+            assert deleted_error_file.deleted
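Because `response.body` is double-encoded, consumers decode twice. A minimal sketch continuing from the test's `output_lines`; the field access assumes the Completion shape shown in the recording below.

```python
result = json.loads(output_lines[0])
completion = json.loads(result["response"]["body"])  # nested JSON string
print(completion["choices"][0]["text"])
```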
tests/integration/recordings/responses/41e27b9b5d09.json (new file, 42 lines)

@@ -0,0 +1,42 @@
+{
+  "request": {
+    "method": "POST",
+    "url": "http://0.0.0.0:11434/v1/v1/completions",
+    "headers": {},
+    "body": {
+      "model": "llama3.2:3b-instruct-fp16",
+      "prompt": "Say completions",
+      "max_tokens": 20
+    },
+    "endpoint": "/v1/completions",
+    "model": "llama3.2:3b-instruct-fp16"
+  },
+  "response": {
+    "body": {
+      "__type__": "openai.types.completion.Completion",
+      "__data__": {
+        "id": "cmpl-271",
+        "choices": [
+          {
+            "finish_reason": "length",
+            "index": 0,
+            "logprobs": null,
+            "text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
+          }
+        ],
+        "created": 1756846620,
+        "model": "llama3.2:3b-instruct-fp16",
+        "object": "text_completion",
+        "system_fingerprint": "fp_ollama",
+        "usage": {
+          "completion_tokens": 20,
+          "prompt_tokens": 28,
+          "total_tokens": 48,
+          "completion_tokens_details": null,
+          "prompt_tokens_details": null
+        }
+      }
+    },
+    "is_streaming": false
+  }
+}
@@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
 * test_validate_input_url_mismatch (negative)
 * test_validate_input_multiple_errors_per_request (negative)
 * test_validate_input_invalid_request_format (negative)
-* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
+* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
+* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
 * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)

 The tests use temporary SQLite databases for isolation and mock external
@@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
         "endpoint",
         [
             "/v1/embeddings",
-            "/v1/completions",
             "/v1/invalid/endpoint",
             "",
         ],
@@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
             ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
         ],
     )
-    async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
-        """Test _validate_input when file contains request with missing required parameters."""
+    async def test_validate_input_missing_parameters_chat_completions(
+        self, provider, param_name, param_path, error_code, error_message
+    ):
+        """Test _validate_input when file contains request with missing required parameters for chat completions."""
         provider.files_api.openai_retrieve_file = AsyncMock()
         mock_response = MagicMock()
@@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
         assert errors[0].message == error_message
         assert errors[0].param == param_path

+    @pytest.mark.parametrize(
+        "param_name,param_path,error_code,error_message",
+        [
+            ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
+            ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
+            ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
+            ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
+            ("model", "body.model", "invalid_request", "Model parameter is required"),
+            ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
+        ],
+    )
+    async def test_validate_input_missing_parameters_completions(
+        self, provider, param_name, param_path, error_code, error_message
+    ):
+        """Test _validate_input when file contains request with missing required parameters for text completions."""
+        provider.files_api.openai_retrieve_file = AsyncMock()
+        mock_response = MagicMock()
+
+        base_request = {
+            "custom_id": "req-1",
+            "method": "POST",
+            "url": "/v1/completions",
+            "body": {"model": "test-model", "prompt": "Hello"},
+        }
+
+        # Remove the specific parameter being tested
+        if "." in param_path:
+            top_level, nested_param = param_path.split(".", 1)
+            del base_request[top_level][nested_param]
+        else:
+            del base_request[param_name]
+
+        mock_response.body = json.dumps(base_request).encode()
+        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+        batch = BatchObject(
+            id="batch_test",
+            object="batch",
+            endpoint="/v1/completions",
+            input_file_id=f"missing_{param_name}_file",
+            completion_window="24h",
+            status="validating",
+            created_at=1234567890,
+        )
+
+        errors, requests = await provider._validate_input(batch)
+
+        assert len(errors) == 1
+        assert len(requests) == 0
+
+        assert errors[0].code == error_code
+        assert errors[0].line == 1
+        assert errors[0].message == error_message
+        assert errors[0].param == param_path
+
     async def test_validate_input_url_mismatch(self, provider):
         """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
         provider.files_api.openai_retrieve_file = AsyncMock()