From fa5dfee07b251b1fcb85e7d42377aee29e268cd9 Mon Sep 17 00:00:00 2001
From: Ben Browning
Date: Wed, 23 Apr 2025 11:48:32 -0400
Subject: [PATCH] fix: Return HTTP 400 for OpenAI API validation errors (#2002)

# What does this PR do?

When clients called the OpenAI API with invalid input that wasn't caught by our own Pydantic API validation but was only caught by the backend inference provider, that provider returned an HTTP 400 error. However, we were wrapping that in an HTTP 500 error, obfuscating the actual issue from calling clients and triggering OpenAI client retry logic.

This change adjusts our existing `translate_exception` method in `server.py` to wrap `openai.BadRequestError` as an HTTP 400 error, passing the string representation of the error message through to the calling client so they can see the actual input validation error and correct it. I tried changing this in a few other places, but ultimately `translate_exception` was the only real place to handle this for both streaming and non-streaming requests across all inference providers that use the OpenAI server APIs.

This also tightens up our validation a bit for the OpenAI chat completions API, to catch empty `messages` parameters, invalid `tool_choice` parameters, invalid `tools` items, and passing `tool_choice` when `tools` isn't given.

Lastly, this extends our OpenAI API chat completions verifications to also check for consistent input validation across providers. Providers behind Llama Stack should automatically pass all the new tests thanks to the input validation added here, but some providers fail these tests when not run behind Llama Stack, due to differences in how they handle input validation and errors.

(Closes #1951)

## Test Plan

To test this, start an OpenAI API verification stack:

```
llama stack run --image-type venv tests/verifications/openai-api-verification-run.yaml
```

Then, run the new verification tests with your provider(s) of choice:

```
python -m pytest -s -v \
  tests/verifications/openai_api/test_chat_completion.py \
  --provider openai-llama-stack

python -m pytest -s -v \
  tests/verifications/openai_api/test_chat_completion.py \
  --provider together-llama-stack
```
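As a quick manual sanity check of the new behavior (illustrative only, not part of the patch), the snippet below drives one of the invalid cases with a plain OpenAI client. It assumes the verification stack started above exposes its OpenAI-compatible endpoint at `http://localhost:8321/v1/openai/v1` and uses a placeholder model id — both are assumptions, so adjust them to your deployment:

```
# Illustrative sketch: confirm invalid input now surfaces as HTTP 400.
# The base_url and model id below are assumptions -- adjust to your stack.
from openai import BadRequestError, OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

try:
    client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct",  # placeholder model id
        messages=[{"role": "user", "content": "Which planet do humans live on?"}],
        tool_choice="required",  # invalid: tool_choice without any tools
    )
except BadRequestError as err:
    # Previously this came back as a generic 500; now the underlying
    # validation message is returned with status 400.
    print(err.status_code, err.message)
```

Before this change the same request surfaced as an HTTP 500; with it, the client sees the 400 and the underlying validation message, which is what the new verification tests assert.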
Signed-off-by: Ben Browning
---
 llama_stack/distribution/routers/routers.py      | 17 ++++++-
 llama_stack/distribution/server/server.py        |  3 ++
 .../fixtures/test_cases/chat_completion.yaml     | 46 +++++++++++++++++++
 .../openai_api/test_chat_completion.py           | 45 ++++++++++++++++++
 4 files changed, 110 insertions(+), 1 deletion(-)

diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 17aecdaf8..d88df00bd 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -8,6 +8,11 @@ import asyncio
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
+from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
+from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
+from pydantic import Field, TypeAdapter
+from typing_extensions import Annotated
+
 from llama_stack.apis.common.content_types import (
     URL,
     InterleavedContent,
@@ -526,7 +531,7 @@ class InferenceRouter(Inference):
     async def openai_chat_completion(
         self,
         model: str,
-        messages: List[OpenAIMessageParam],
+        messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
         frequency_penalty: Optional[float] = None,
         function_call: Optional[Union[str, Dict[str, Any]]] = None,
         functions: Optional[List[Dict[str, Any]]] = None,
@@ -558,6 +563,16 @@ class InferenceRouter(Inference):
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
 
+        # Use the OpenAI client for a bit of extra input validation without
+        # exposing the OpenAI client itself as part of our API surface
+        if tool_choice:
+            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
+            if tools is None:
+                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
+        if tools:
+            for tool in tools:
+                TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
+
         params = dict(
             model=model_obj.identifier,
             messages=messages,
diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py
index 50cf44ec9..2942920d4 100644
--- a/llama_stack/distribution/server/server.py
+++ b/llama_stack/distribution/server/server.py
@@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse, StreamingResponse
+from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
 from typing_extensions import Annotated
 
@@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
         )
     elif isinstance(exc, ValueError):
         return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
+    elif isinstance(exc, BadRequestError):
+        return HTTPException(status_code=400, detail=str(exc))
     elif isinstance(exc, PermissionError):
         return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
     elif isinstance(exc, TimeoutError):
diff --git a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml
index 1ace76e34..0c9f1fe9e 100644
--- a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml
+++ b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml
@@ -15,6 +15,52 @@ test_chat_basic:
             S?
           role: user
       output: Saturn
+test_chat_input_validation:
+  test_name: test_chat_input_validation
+  test_params:
+    case:
+    - case_id: "messages_missing"
+      input:
+        messages: []
+      output:
+        error:
+          status_code: 400
+    - case_id: "messages_role_invalid"
+      input:
+        messages:
+        - content: Which planet do humans live on?
+          role: fake_role
+      output:
+        error:
+          status_code: 400
+    - case_id: "tool_choice_invalid"
+      input:
+        messages:
+        - content: Which planet do humans live on?
+          role: user
+        tool_choice: invalid
+      output:
+        error:
+          status_code: 400
+    - case_id: "tool_choice_no_tools"
+      input:
+        messages:
+        - content: Which planet do humans live on?
+          role: user
+        tool_choice: required
+      output:
+        error:
+          status_code: 400
+    - case_id: "tools_type_invalid"
+      input:
+        messages:
+        - content: Which planet do humans live on?
+          role: user
+        tools:
+        - type: invalid
+      output:
+        error:
+          status_code: 400
 test_chat_image:
   test_name: test_chat_image
   test_params:
diff --git a/tests/verifications/openai_api/test_chat_completion.py b/tests/verifications/openai_api/test_chat_completion.py
index 3a311667a..277eaafa3 100644
--- a/tests/verifications/openai_api/test_chat_completion.py
+++ b/tests/verifications/openai_api/test_chat_completion.py
@@ -12,6 +12,7 @@ from pathlib import Path
 from typing import Any
 
 import pytest
+from openai import APIError
 from pydantic import BaseModel
 
 from tests.verifications.openai_api.fixtures.fixtures import (
@@ -136,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat
         assert case["output"].lower() in content.lower()
 
 
+@pytest.mark.parametrize(
+    "case",
+    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    with pytest.raises(APIError) as e:
+        openai_client.chat.completions.create(
+            model=model,
+            messages=case["input"]["messages"],
+            stream=False,
+            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
+            tools=case["input"]["tools"] if "tools" in case["input"] else None,
+        )
+    assert case["output"]["error"]["status_code"] == e.value.status_code
+
+
+@pytest.mark.parametrize(
+    "case",
+    chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
+    ids=case_id_generator,
+)
+def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
+    test_name_base = get_base_test_name(request)
+    if should_skip_test(verification_config, provider, model, test_name_base):
+        pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
+
+    with pytest.raises(APIError) as e:
+        response = openai_client.chat.completions.create(
+            model=model,
+            messages=case["input"]["messages"],
+            stream=True,
+            tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
+            tools=case["input"]["tools"] if "tools" in case["input"] else None,
+        )
+        for _chunk in response:
+            pass
+    assert str(case["output"]["error"]["status_code"]) in e.value.message
+
+
 @pytest.mark.parametrize(
     "case",
     chat_completion_test_cases["test_chat_image"]["test_params"]["case"],