diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 17aecdaf8..d88df00bd 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,6 +8,11 @@ import asyncio import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam +from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated + from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -526,7 +531,7 @@ class InferenceRouter(Inference): async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], + messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, @@ -558,6 +563,16 @@ class InferenceRouter(Inference): if model_obj.model_type == ModelType.embedding: raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + # Use the OpenAI client for a bit of extra input validation without + # exposing the OpenAI client itself as part of our API surface + if tool_choice: + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) + if tools is None: + raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") + if tools: + for tool in tools: + TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) + params = dict( model=model_obj.identifier, messages=messages, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 50cf44ec9..2942920d4 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse +from openai import BadRequestError from pydantic import BaseModel, ValidationError from typing_extensions import Annotated @@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) elif isinstance(exc, ValueError): return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + elif isinstance(exc, BadRequestError): + return HTTPException(status_code=400, detail=str(exc)) elif isinstance(exc, PermissionError): return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, TimeoutError): diff --git a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml index 1ace76e34..0c9f1fe9e 100644 --- a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml +++ b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml @@ -15,6 +15,52 @@ test_chat_basic: S? role: user output: Saturn +test_chat_input_validation: + test_name: test_chat_input_validation + test_params: + case: + - case_id: "messages_missing" + input: + messages: [] + output: + error: + status_code: 400 + - case_id: "messages_role_invalid" + input: + messages: + - content: Which planet do humans live on? + role: fake_role + output: + error: + status_code: 400 + - case_id: "tool_choice_invalid" + input: + messages: + - content: Which planet do humans live on? + role: user + tool_choice: invalid + output: + error: + status_code: 400 + - case_id: "tool_choice_no_tools" + input: + messages: + - content: Which planet do humans live on? + role: user + tool_choice: required + output: + error: + status_code: 400 + - case_id: "tools_type_invalid" + input: + messages: + - content: Which planet do humans live on? + role: user + tools: + - type: invalid + output: + error: + status_code: 400 test_chat_image: test_name: test_chat_image test_params: diff --git a/tests/verifications/openai_api/test_chat_completion.py b/tests/verifications/openai_api/test_chat_completion.py index 3a311667a..277eaafa3 100644 --- a/tests/verifications/openai_api/test_chat_completion.py +++ b/tests/verifications/openai_api/test_chat_completion.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any import pytest +from openai import APIError from pydantic import BaseModel from tests.verifications.openai_api.fixtures.fixtures import ( @@ -136,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat assert case["output"].lower() in content.lower() +@pytest.mark.parametrize( + "case", + chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"], + ids=case_id_generator, +) +def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with pytest.raises(APIError) as e: + openai_client.chat.completions.create( + model=model, + messages=case["input"]["messages"], + stream=False, + tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None, + tools=case["input"]["tools"] if "tools" in case["input"] else None, + ) + assert case["output"]["error"]["status_code"] == e.value.status_code + + +@pytest.mark.parametrize( + "case", + chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"], + ids=case_id_generator, +) +def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with pytest.raises(APIError) as e: + response = openai_client.chat.completions.create( + model=model, + messages=case["input"]["messages"], + stream=True, + tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None, + tools=case["input"]["tools"] if "tools" in case["input"] else None, + ) + for _chunk in response: + pass + assert str(case["output"]["error"]["status_code"]) in e.value.message + + @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_image"]["test_params"]["case"],