feat(batches, completions): add /v1/completions support to /v1/batches

Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-04 12:07:34 +00:00
Parent: faf891b40c
Commit: 6ae7509a47
4 changed files with 205 additions and 26 deletions
@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
         # TODO: set expiration time for garbage collection
-        if endpoint not in ["/v1/chat/completions"]:
+        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
             )

         if completion_window != "24h":
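
For context, a minimal sketch of how this check surfaces to a client, using the same OpenAI-compatible client as the integration test further down; openai_client and uploaded_file are assumed fixtures, not part of this patch:

    batch = openai_client.batches.create(
        input_file_id=uploaded_file.id,   # JSONL file of batch request objects
        endpoint="/v1/completions",       # now accepted alongside /v1/chat/completions
        completion_window="24h",          # the only supported window
    )
    # An unsupported endpoint such as "/v1/embeddings" is still rejected with the
    # "Invalid endpoint ... Code: invalid_value. Param: endpoint" error above.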
@@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
                 )
                 valid = False

-            for param, expected_type, type_string in [
+            if batch.endpoint == "/v1/chat/completions":
+                required_params = [
                     ("model", str, "a string"),
                     # messages is specific to /v1/chat/completions
                     # we could skip validating messages here and let inference fail. however,
                     # that would be a very expensive way to find out messages is wrong.
                     ("messages", list, "an array"),  # TODO: allow messages to be a string?
-            ]:
+                ]
+            else:  # /v1/completions
+                required_params = [
+                    ("model", str, "a string"),
+                    ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
+                ]
+
+            for param, expected_type, type_string in required_params:
                 if param not in body:
                     errors.append(
                         BatchError(
@@ -591,6 +599,7 @@ class ReferenceBatchesImpl(Batches):

         try:
             # TODO(SECURITY): review body for security issues
+            if request.url == "/v1/chat/completions":
                 request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
                 chat_response = await self.inference_api.openai_chat_completion(**request.body)
@@ -605,6 +614,22 @@ class ReferenceBatchesImpl(Batches):
                         "body": chat_response.model_dump_json(),
                     },
                 }
+            else:  # /v1/completions
+                completion_response = await self.inference_api.openai_completion(**request.body)
+
+                # this is for mypy, we don't allow streaming so we'll get the right type
+                assert hasattr(completion_response, "model_dump_json"), (
+                    "Completion response must have model_dump_json method"
+                )
+                return {
+                    "id": request_id,
+                    "custom_id": request.custom_id,
+                    "response": {
+                        "status_code": 200,
+                        "request_id": request_id,
+                        "body": completion_response.model_dump_json(),
+                    },
+                }
         except Exception as e:
             logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
             return {
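
The dict returned from this branch becomes one line of the batch's output file. A hedged sketch of that record's shape for a successful /v1/completions request; field names follow the code above, while the request_id format and the file-writing machinery are assumptions:

    import json

    request_id = "batch_req_example"  # assumed identifier, format not specified here
    completion_body = {"object": "text_completion", "choices": [{"index": 0, "text": "..."}]}
    result = {
        "id": request_id,
        "custom_id": "success-1",
        "response": {
            "status_code": 200,
            "request_id": request_id,
            "body": json.dumps(completion_body),  # serialized completion response in the real code
        },
    }
    output_line = json.dumps(result)  # one JSONL line in the batch output file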
@@ -268,3 +268,58 @@ class TestBatchesIntegration:

         deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
         assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
+
+    def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
+        """Run an end-to-end batch with a single successful text completion request."""
+        request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
+
+        batch_requests = [
+            {
+                "custom_id": "success-1",
+                "method": "POST",
+                "url": "/v1/completions",
+                "body": request_body,
+            }
+        ]
+
+        with batch_helper.create_file(batch_requests) as uploaded_file:
+            batch = openai_client.batches.create(
+                input_file_id=uploaded_file.id,
+                endpoint="/v1/completions",
+                completion_window="24h",
+                metadata={"test": "e2e_completions_success"},
+            )
+
+            final_batch = batch_helper.wait_for(
+                batch.id,
+                max_wait_time=3 * 60,
+                expected_statuses={"completed"},
+                timeout_action="skip",
+            )
+
+            assert final_batch.status == "completed"
+            assert final_batch.request_counts is not None
+            assert final_batch.request_counts.total == 1
+            assert final_batch.request_counts.completed == 1
+            assert final_batch.output_file_id is not None
+
+            output_content = openai_client.files.content(final_batch.output_file_id)
+            if isinstance(output_content, str):
+                output_text = output_content
+            else:
+                output_text = output_content.content.decode("utf-8")
+
+            output_lines = output_text.strip().split("\n")
+            assert len(output_lines) == 1
+
+            result = json.loads(output_lines[0])
+            assert result["custom_id"] == "success-1"
+            assert "response" in result
+            assert result["response"]["status_code"] == 200
+
+            deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
+            assert deleted_output_file.deleted
+
+            if final_batch.error_file_id is not None:
+                deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
+                assert deleted_error_file.deleted
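
For reference, batch_helper.create_file uploads a JSONL file with one request object per line; a minimal sketch of the single line this test would produce (the helper's serialization details are assumed, the field values mirror batch_requests above, and "llama3.2:3b-instruct-fp16" stands in for the text_model_id fixture as captured in the recording below):

    import json

    line = json.dumps(
        {
            "custom_id": "success-1",
            "method": "POST",
            "url": "/v1/completions",
            "body": {"model": "llama3.2:3b-instruct-fp16", "prompt": "Say completions", "max_tokens": 20},
        }
    )
    print(line)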
tests/integration/recordings/responses/41e27b9b5d09.json (new file, 42 lines)
@@ -0,0 +1,42 @@
+{
+  "request": {
+    "method": "POST",
+    "url": "http://0.0.0.0:11434/v1/v1/completions",
+    "headers": {},
+    "body": {
+      "model": "llama3.2:3b-instruct-fp16",
+      "prompt": "Say completions",
+      "max_tokens": 20
+    },
+    "endpoint": "/v1/completions",
+    "model": "llama3.2:3b-instruct-fp16"
+  },
+  "response": {
+    "body": {
+      "__type__": "openai.types.completion.Completion",
+      "__data__": {
+        "id": "cmpl-271",
+        "choices": [
+          {
+            "finish_reason": "length",
+            "index": 0,
+            "logprobs": null,
+            "text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
+          }
+        ],
+        "created": 1756846620,
+        "model": "llama3.2:3b-instruct-fp16",
+        "object": "text_completion",
+        "system_fingerprint": "fp_ollama",
+        "usage": {
+          "completion_tokens": 20,
+          "prompt_tokens": 28,
+          "total_tokens": 48,
+          "completion_tokens_details": null,
+          "prompt_tokens_details": null
+        }
+      }
+    },
+    "is_streaming": false
+  }
+}
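
The recorded body names its type in __type__, so it can be rehydrated with the OpenAI Python client's pydantic models; a small sketch, assuming the openai package is installed and the file path above is readable from the repository root:

    import json
    from openai.types.completion import Completion

    with open("tests/integration/recordings/responses/41e27b9b5d09.json") as f:
        recording = json.load(f)

    completion = Completion.model_validate(recording["response"]["body"]["__data__"])
    print(completion.choices[0].text, completion.usage.total_tokens)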
@@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
 * test_validate_input_url_mismatch (negative)
 * test_validate_input_multiple_errors_per_request (negative)
 * test_validate_input_invalid_request_format (negative)
-* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
+* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
+* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
 * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)

 The tests use temporary SQLite databases for isolation and mock external
@@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
         "endpoint",
         [
             "/v1/embeddings",
-            "/v1/completions",
             "/v1/invalid/endpoint",
             "",
         ],
@@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
             ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
         ],
     )
-    async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
-        """Test _validate_input when file contains request with missing required parameters."""
+    async def test_validate_input_missing_parameters_chat_completions(
+        self, provider, param_name, param_path, error_code, error_message
+    ):
+        """Test _validate_input when file contains request with missing required parameters for chat completions."""
         provider.files_api.openai_retrieve_file = AsyncMock()
         mock_response = MagicMock()
@@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
         assert errors[0].message == error_message
         assert errors[0].param == param_path

+    @pytest.mark.parametrize(
+        "param_name,param_path,error_code,error_message",
+        [
+            ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
+            ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
+            ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
+            ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
+            ("model", "body.model", "invalid_request", "Model parameter is required"),
+            ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
+        ],
+    )
+    async def test_validate_input_missing_parameters_completions(
+        self, provider, param_name, param_path, error_code, error_message
+    ):
+        """Test _validate_input when file contains request with missing required parameters for text completions."""
+        provider.files_api.openai_retrieve_file = AsyncMock()
+        mock_response = MagicMock()
+
+        base_request = {
+            "custom_id": "req-1",
+            "method": "POST",
+            "url": "/v1/completions",
+            "body": {"model": "test-model", "prompt": "Hello"},
+        }
+
+        # Remove the specific parameter being tested
+        if "." in param_path:
+            top_level, nested_param = param_path.split(".", 1)
+            del base_request[top_level][nested_param]
+        else:
+            del base_request[param_name]
+
+        mock_response.body = json.dumps(base_request).encode()
+        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+        batch = BatchObject(
+            id="batch_test",
+            object="batch",
+            endpoint="/v1/completions",
+            input_file_id=f"missing_{param_name}_file",
+            completion_window="24h",
+            status="validating",
+            created_at=1234567890,
+        )
+
+        errors, requests = await provider._validate_input(batch)
+
+        assert len(errors) == 1
+        assert len(requests) == 0
+
+        assert errors[0].code == error_code
+        assert errors[0].line == 1
+        assert errors[0].message == error_message
+        assert errors[0].param == param_path
+
     async def test_validate_input_url_mismatch(self, provider):
         """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
         provider.files_api.openai_retrieve_file = AsyncMock()
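
Both parametrized missing-parameter tests rely on the same dotted param_path convention: a path like "body.prompt" removes a nested key from the request body, while a bare name like "custom_id" removes a top-level key. A standalone sketch of that removal step (illustrative; the tests above inline it):

    import json

    def remove_param(request: dict, param_path: str) -> dict:
        request = json.loads(json.dumps(request))  # cheap deep copy so the original stays intact
        if "." in param_path:
            top_level, nested_param = param_path.split(".", 1)
            del request[top_level][nested_param]
        else:
            del request[param_path]
        return request

    base = {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/completions",
        "body": {"model": "test-model", "prompt": "Hello"},
    }
    assert "prompt" not in remove_param(base, "body.prompt")["body"]
    assert "custom_id" not in remove_param(base, "custom_id")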