diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py
index e049518a4..39f45d7d1 100644
--- a/llama_stack/providers/inline/batches/reference/batches.py
+++ b/llama_stack/providers/inline/batches/reference/batches.py
@@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
 
         # TODO: set expiration time for garbage collection
 
-        if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
+        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
             )
 
         if completion_window != "24h":
@@ -425,18 +425,23 @@ class ReferenceBatchesImpl(Batches):
                     valid = False
 
             if batch.endpoint == "/v1/chat/completions":
-                required_params = [
+                required_params: list[tuple[str, Any, str]] = [
                     ("model", str, "a string"),
                     # messages is specific to /v1/chat/completions
                     # we could skip validating messages here and let inference fail. however,
                     # that would be a very expensive way to find out messages is wrong.
                     ("messages", list, "an array"),  # TODO: allow messages to be a string?
                 ]
-            else:  # /v1/completions
+            elif batch.endpoint == "/v1/completions":
                 required_params = [
                     ("model", str, "a string"),
                     ("prompt", str, "a string"),  # TODO: allow prompt to be a list of strings??
                 ]
+            else:  # /v1/embeddings
+                required_params = [
+                    ("model", str, "a string"),
+                    ("input", (str, list), "a string or array of strings"),
+                ]
 
             for param, expected_type, type_string in required_params:
                 if param not in body:
@@ -614,7 +619,7 @@ class ReferenceBatchesImpl(Batches):
                         "body": chat_response.model_dump_json(),
                     },
                 }
-            else:  # /v1/completions
+            elif request.url == "/v1/completions":
                 completion_response = await self.inference_api.openai_completion(**request.body)
 
                 # this is for mypy, we don't allow streaming so we'll get the right type
@@ -630,6 +635,20 @@ class ReferenceBatchesImpl(Batches):
                         "body": completion_response.model_dump_json(),
                     },
                 }
+            else:  # /v1/embeddings
+                embeddings_response = await self.inference_api.openai_embeddings(**request.body)
+                assert hasattr(embeddings_response, "model_dump_json"), (
+                    "Embeddings response must have model_dump_json method"
+                )
+                return {
+                    "id": request_id,
+                    "custom_id": request.custom_id,
+                    "response": {
+                        "status_code": 200,
+                        "request_id": request_id,  # TODO: should this be different?
+                        "body": embeddings_response.model_dump_json(),
+                    },
+                }
         except Exception as e:
             logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
             return {
diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py
index d55a68bd3..2ff838bdd 100644
--- a/tests/integration/batches/test_batches.py
+++ b/tests/integration/batches/test_batches.py
@@ -323,3 +323,92 @@ class TestBatchesIntegration:
         if final_batch.error_file_id is not None:
             deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
             assert deleted_error_file.deleted
+
+    def test_batch_e2e_embeddings(self, openai_client, batch_helper, embedding_model_id):
+        """Run an end-to-end batch with embeddings requests including both string and list inputs."""
+        batch_requests = [
+            {
+                "custom_id": "success-1",
+                "method": "POST",
+                "url": "/v1/embeddings",
+                "body": {"model": embedding_model_id, "input": "Hello world", "encoding_format": "float"},
+            },
+            {
+                "custom_id": "success-2",
+                "method": "POST",
+                "url": "/v1/embeddings",
+                "body": {
+                    "model": embedding_model_id,
+                    "input": ["How are you?", "Good morning", "Have a great day"],
+                    "encoding_format": "float",
+                },
+            },
+        ]
+
+        with batch_helper.create_file(batch_requests) as uploaded_file:
+            batch = openai_client.batches.create(
+                input_file_id=uploaded_file.id,
+                endpoint="/v1/embeddings",
+                completion_window="24h",
+                metadata={"test": "e2e_embeddings_success"},
+            )
+
+            final_batch = batch_helper.wait_for(
+                batch.id,
+                max_wait_time=3 * 60,
+                expected_statuses={"completed"},
+                timeout_action="skip",
+            )
+
+        assert final_batch.status == "completed"
+        assert final_batch.request_counts is not None
+        assert final_batch.request_counts.total == 2
+        assert final_batch.request_counts.completed == 2
+        assert final_batch.output_file_id is not None
+
+        output_content = openai_client.files.content(final_batch.output_file_id)
+        if isinstance(output_content, str):
+            output_text = output_content
+        else:
+            output_text = output_content.content.decode("utf-8")
+
+        output_lines = output_text.strip().split("\n")
+        assert len(output_lines) == 2
+
+        # Check first result (string input)
+        result1 = json.loads(output_lines[0])
+        assert result1["custom_id"] in ["success-1", "success-2"]
+        assert "response" in result1
+        assert result1["response"]["status_code"] == 200
+
+        # Verify the response body contains embeddings data
+        response_body1 = json.loads(result1["response"]["body"])
+        assert response_body1["object"] == "list"
+        assert "data" in response_body1
+        assert len(response_body1["data"]) == 1
+        assert "embedding" in response_body1["data"][0]
+        assert "index" in response_body1["data"][0]
+        assert response_body1["data"][0]["index"] == 0
+
+        # Check second result (list input)
+        result2 = json.loads(output_lines[1])
+        assert result2["custom_id"] in ["success-1", "success-2"]
+        assert "response" in result2
+        assert result2["response"]["status_code"] == 200
+
+        # Verify the response body contains embeddings data for list input
+        response_body2 = json.loads(result2["response"]["body"])
+        assert response_body2["object"] == "list"
+        assert "data" in response_body2
+        assert len(response_body2["data"]) == 3  # Three strings in the list
+        for i, embedding_data in enumerate(response_body2["data"]):
+            assert "embedding" in embedding_data
+            assert "index" in embedding_data
+            assert embedding_data["index"] == i
+
+        deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
+        assert deleted_output_file.deleted
+
+        if final_batch.error_file_id is not None:
+            deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
+            assert deleted_error_file.deleted
diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py
index dfef5e040..89cb1af9d 100644
--- a/tests/unit/providers/batches/test_reference.py
+++ b/tests/unit/providers/batches/test_reference.py
@@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
     @pytest.mark.parametrize(
         "endpoint",
         [
-            "/v1/embeddings",
            "/v1/invalid/endpoint",
             "",
         ],
@@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
         await asyncio.sleep(0.042)  # let tasks start
 
         assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
+
+    async def test_create_batch_embeddings_endpoint(self, provider):
+        """Test that batch creation succeeds with the embeddings endpoint."""
+        batch = await provider.create_batch(
+            input_file_id="file_123",
+            endpoint="/v1/embeddings",
+            completion_window="24h",
+        )
+        assert batch.endpoint == "/v1/embeddings"
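
For reference, a minimal client-side sketch of the workflow this patch enables, written against the OpenAI Python SDK pointed at a Llama Stack server. The base URL, api_key, and "all-MiniLM-L6-v2" model id are placeholder assumptions, not part of the patch; substitute whatever your deployment actually serves.

    import json
    import tempfile

    from openai import OpenAI

    # Placeholder connection details; adjust to your Llama Stack deployment.
    client = OpenAI(base_url="http://localhost:8321/v1", api_key="unused")

    # One JSONL line per request, carrying the fields the provider validates:
    # custom_id, method, url, and a body with model and input (string or list).
    requests = [
        {
            "custom_id": "embed-1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {"model": "all-MiniLM-L6-v2", "input": "Hello world", "encoding_format": "float"},
        }
    ]

    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
        f.write("\n".join(json.dumps(r) for r in requests))

    with open(f.name, "rb") as fh:
        uploaded = client.files.create(file=fh, purpose="batch")

    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/embeddings",  # newly accepted by the endpoint validation above
        completion_window="24h",  # the only window the reference provider accepts
    )
    print(batch.id, batch.status)

Once the batch reaches "completed", each line of the output file is a JSON object whose response.body is an embeddings response with object == "list", as the integration test above checks.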