mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat(batches, completions): add /v1/completions support to /v1/batches (#3309)
# What does this PR do? add support for /v1/completions to the /v1/batches api ## Test Plan ci
This commit is contained in:
parent
e2fe39aee1
commit
df1526991f
4 changed files with 205 additions and 26 deletions
|
@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
|
|||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
|
||||
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
|
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
|
|||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
|
|||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
async def test_validate_input_missing_parameters_chat_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
|
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
|
|||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters_completions(
|
||||
self, provider, param_name, param_path, error_code, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with missing required parameters for text completions."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {"model": "test-model", "prompt": "Hello"},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue