From df1526991f6d5cfca05d3c4f1077b67f4832d93e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 5 Sep 2025 14:59:57 -0400
Subject: [PATCH] feat(batches, completions): add /v1/completions support to
 /v1/batches (#3309)

# What does this PR do?

Add support for /v1/completions to the /v1/batches API.

## Test Plan

ci
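For reviewers, a minimal sketch of what this enables, mirroring the integration test added below (the client setup, base URL, file id, and the model named inside the input file are assumptions, not part of this change):

```python
# Sketch only: submit a batch of text completions through an
# OpenAI-compatible client pointed at a Llama Stack server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed deployment

batch = client.batches.create(
    input_file_id="file-abc123",  # a previously uploaded JSONL file of /v1/completions requests
    endpoint="/v1/completions",  # newly accepted alongside /v1/chat/completions
    completion_window="24h",  # the only window the reference provider accepts
)
print(batch.id, batch.status)
```

Polling, output files, and error files behave the same as for existing /v1/chat/completions batches.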
- "body": chat_response.model_dump_json(), - }, - } + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + else: # /v1/completions + completion_response = await self.inference_api.openai_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(completion_response, "model_dump_json"), ( + "Completion response must have model_dump_json method" + ) + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, + "body": completion_response.model_dump_json(), + }, + } except Exception as e: logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") return { diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py index 59811b7a4..d55a68bd3 100644 --- a/tests/integration/batches/test_batches.py +++ b/tests/integration/batches/test_batches.py @@ -268,3 +268,58 @@ class TestBatchesIntegration: deleted_error_file = openai_client.files.delete(final_batch.error_file_id) assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" + + def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id): + """Run an end-to-end batch with a single successful text completion request.""" + request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20} + + batch_requests = [ + { + "custom_id": "success-1", + "method": "POST", + "url": "/v1/completions", + "body": request_body, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/completions", + completion_window="24h", + metadata={"test": "e2e_completions_success"}, + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, + expected_statuses={"completed"}, + timeout_action="skip", + ) + + assert final_batch.status == "completed" + assert final_batch.request_counts is not None + assert final_batch.request_counts.total == 1 + assert final_batch.request_counts.completed == 1 + assert final_batch.output_file_id is not None + + output_content = openai_client.files.content(final_batch.output_file_id) + if isinstance(output_content, str): + output_text = output_content + else: + output_text = output_content.content.decode("utf-8") + + output_lines = output_text.strip().split("\n") + assert len(output_lines) == 1 + + result = json.loads(output_lines[0]) + assert result["custom_id"] == "success-1" + assert "response" in result + assert result["response"]["status_code"] == 200 + + deleted_output_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_output_file.deleted + + if final_batch.error_file_id is not None: + deleted_error_file = openai_client.files.delete(final_batch.error_file_id) + assert deleted_error_file.deleted diff --git a/tests/integration/recordings/responses/41e27b9b5d09.json b/tests/integration/recordings/responses/41e27b9b5d09.json new file mode 100644 index 000000000..45d140843 --- /dev/null +++ b/tests/integration/recordings/responses/41e27b9b5d09.json @@ -0,0 +1,42 @@ +{ + 
"request": { + "method": "POST", + "url": "http://0.0.0.0:11434/v1/v1/completions", + "headers": {}, + "body": { + "model": "llama3.2:3b-instruct-fp16", + "prompt": "Say completions", + "max_tokens": 20 + }, + "endpoint": "/v1/completions", + "model": "llama3.2:3b-instruct-fp16" + }, + "response": { + "body": { + "__type__": "openai.types.completion.Completion", + "__data__": { + "id": "cmpl-271", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": "You want me to respond with a completion, but you didn't specify what I should complete. Could" + } + ], + "created": 1756846620, + "model": "llama3.2:3b-instruct-fp16", + "object": "text_completion", + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 28, + "total_tokens": 48, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + }, + "is_streaming": false + } +} diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py index 0ca866f7b..dfef5e040 100644 --- a/tests/unit/providers/batches/test_reference.py +++ b/tests/unit/providers/batches/test_reference.py @@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated: * test_validate_input_url_mismatch (negative) * test_validate_input_multiple_errors_per_request (negative) * test_validate_input_invalid_request_format (negative) - * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) + * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions) + * test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions) * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) The tests use temporary SQLite databases for isolation and mock external @@ -213,7 +214,6 @@ class TestReferenceBatchesImpl: "endpoint", [ "/v1/embeddings", - "/v1/completions", "/v1/invalid/endpoint", "", ], @@ -499,8 +499,10 @@ class TestReferenceBatchesImpl: ("messages", "body.messages", "invalid_request", "Messages parameter is required"), ], ) - async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): - """Test _validate_input when file contains request with missing required parameters.""" + async def test_validate_input_missing_parameters_chat_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for chat completions.""" provider.files_api.openai_retrieve_file = AsyncMock() mock_response = MagicMock() @@ -541,6 +543,61 @@ class TestReferenceBatchesImpl: assert errors[0].message == error_message assert errors[0].param == param_path + @pytest.mark.parametrize( + "param_name,param_path,error_code,error_message", + [ + ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), + ("method", "method", "missing_required_parameter", "Missing required parameter: method"), + ("url", "url", "missing_required_parameter", "Missing required parameter: url"), + ("body", "body", "missing_required_parameter", "Missing required parameter: body"), + ("model", "body.model", "invalid_request", "Model 
diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py
index 0ca866f7b..dfef5e040 100644
--- a/tests/unit/providers/batches/test_reference.py
+++ b/tests/unit/providers/batches/test_reference.py
@@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
     * test_validate_input_url_mismatch (negative)
     * test_validate_input_multiple_errors_per_request (negative)
     * test_validate_input_invalid_request_format (negative)
-    * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
+    * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
+    * test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
     * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
 
 The tests use temporary SQLite databases for isolation and mock external
@@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
         "endpoint",
         [
             "/v1/embeddings",
-            "/v1/completions",
             "/v1/invalid/endpoint",
             "",
         ],
@@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
             ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
         ],
     )
-    async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
-        """Test _validate_input when file contains request with missing required parameters."""
+    async def test_validate_input_missing_parameters_chat_completions(
+        self, provider, param_name, param_path, error_code, error_message
+    ):
+        """Test _validate_input when file contains request with missing required parameters for chat completions."""
         provider.files_api.openai_retrieve_file = AsyncMock()
         mock_response = MagicMock()
 
@@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
         assert errors[0].message == error_message
         assert errors[0].param == param_path
 
+    @pytest.mark.parametrize(
+        "param_name,param_path,error_code,error_message",
+        [
+            ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
+            ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
+            ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
+            ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
+            ("model", "body.model", "invalid_request", "Model parameter is required"),
+            ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
+        ],
+    )
+    async def test_validate_input_missing_parameters_completions(
+        self, provider, param_name, param_path, error_code, error_message
+    ):
+        """Test _validate_input when file contains request with missing required parameters for text completions."""
+        provider.files_api.openai_retrieve_file = AsyncMock()
+        mock_response = MagicMock()
+
+        base_request = {
+            "custom_id": "req-1",
+            "method": "POST",
+            "url": "/v1/completions",
+            "body": {"model": "test-model", "prompt": "Hello"},
+        }
+
+        # Remove the specific parameter being tested
+        if "." in param_path:
+            top_level, nested_param = param_path.split(".", 1)
+            del base_request[top_level][nested_param]
+        else:
+            del base_request[param_name]
+
+        mock_response.body = json.dumps(base_request).encode()
+        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+        batch = BatchObject(
+            id="batch_test",
+            object="batch",
+            endpoint="/v1/completions",
+            input_file_id=f"missing_{param_name}_file",
+            completion_window="24h",
+            status="validating",
+            created_at=1234567890,
+        )
+
+        errors, requests = await provider._validate_input(batch)
+
+        assert len(errors) == 1
+        assert len(requests) == 0
+
+        assert errors[0].code == error_code
+        assert errors[0].line == 1
+        assert errors[0].message == error_message
+        assert errors[0].param == param_path
+
     async def test_validate_input_url_mismatch(self, provider):
         """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
         provider.files_api.openai_retrieve_file = AsyncMock()
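Taken together, the new parametrized cases pin down the validation contract for /v1/completions inputs. As a worked example of what the test asserts, a request line missing `prompt` should produce exactly one error whose fields look like this (a sketch of the asserted fields as a plain dict, not the provider's BatchError type):

```python
# Expected error fields for the ("prompt", "body.prompt", ...) case in
# test_validate_input_missing_parameters_completions.
expected_error = {
    "code": "invalid_request",
    "line": 1,  # the offending line number within the input file
    "message": "Prompt parameter is required",
    "param": "body.prompt",
}
```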