Merge pull request #9220 from BerriAI/litellm_qa_responses_api

[Fixes] Responses API - allow /responses and subpaths as LLM API route + Add exception mapping for responses API
Ishaan Jaff 2025-03-13 21:36:59 -07:00 committed by GitHub
commit ceb8668e4a
16 changed files with 340 additions and 76 deletions
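
In practice, the exception-mapping half of this change means Responses API errors now surface as LiteLLM's OpenAI-style exception classes. A minimal sketch of the caller-facing behavior, based on the tests added in this commit (model name and parameter values are only illustrative):

```python
import litellm

try:
    # temperature=2000 is out of range, so the provider rejects the request with a 400
    litellm.responses(model="gpt-4o", input="hello", temperature=2000)
except litellm.BadRequestError as e:
    # With this change, the raw provider error is re-raised as a
    # LiteLLM/OpenAI-format exception instead of propagating unmapped
    print(type(e).__name__, str(e)[:120])
```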


@@ -1357,7 +1357,7 @@ jobs:
      # Store test results
      - store_test_results:
          path: test-results
-  e2e_openai_misc_endpoints:
+  e2e_openai_endpoints:
    machine:
      image: ubuntu-2204:2023.10.1
    resource_class: xlarge
@@ -1474,7 +1474,7 @@ jobs:
          command: |
            pwd
            ls
-            python -m pytest -s -vv tests/openai_misc_endpoints_tests --junitxml=test-results/junit.xml --durations=5
+            python -m pytest -s -vv tests/openai_endpoints_tests --junitxml=test-results/junit.xml --durations=5
          no_output_timeout: 120m
      # Store test results
@@ -2429,7 +2429,7 @@ workflows:
            only:
              - main
              - /litellm_.*/
-      - e2e_openai_misc_endpoints:
+      - e2e_openai_endpoints:
          filters:
            branches:
              only:
@@ -2571,7 +2571,7 @@ workflows:
        requires:
          - local_testing
          - build_and_test
-          - e2e_openai_misc_endpoints
+          - e2e_openai_endpoints
          - load_testing
          - test_bad_database_url
          - llm_translation_testing


@@ -127,7 +127,7 @@ def exception_type(  # type: ignore  # noqa: PLR0915
    completion_kwargs={},
    extra_kwargs={},
):
    """Maps an LLM Provider Exception to OpenAI Exception Format"""
    if any(
        isinstance(original_exception, exc_type)
        for exc_type in litellm.LITELLM_EXCEPTION_TYPES
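
For context, a rough sketch of how this helper is applied by the Responses API changes further down in this commit. The `custom_llm_provider` value here is an assumption for illustration (in the real code it is resolved via `litellm.get_llm_provider`); exceptions that are already LiteLLM types pass through unchanged, per the `isinstance` check above.

```python
import litellm

def call_responses_with_mapping(model: str, **kwargs):
    try:
        return litellm.responses(model=model, **kwargs)
    except Exception as e:
        # Map the provider error to its OpenAI-format equivalent; if it is
        # already a LiteLLM exception type, exception_type re-raises it as-is
        raise litellm.exception_type(
            model=model,
            custom_llm_provider="openai",  # assumption for illustration
            original_exception=e,
            completion_kwargs={},
            extra_kwargs=kwargs,
        )
```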


@@ -248,6 +248,13 @@ class LiteLLMRoutes(enum.Enum):
        "/v1/realtime",
        "/realtime?{model}",
        "/v1/realtime?{model}",
        # responses API
        "/responses",
        "/v1/responses",
        "/responses/{response_id}",
        "/v1/responses/{response_id}",
        "/responses/{response_id}/input_items",
        "/v1/responses/{response_id}/input_items",
    ]

    mapped_pass_through_routes = [
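
The effect of listing these subpaths is that proxy virtual keys can call them like any other LLM API route. A hedged sketch against a local proxy (the GET/DELETE handlers added below are still stubs in this commit, so the point is only that the route is accepted by the proxy's auth layer rather than rejected as a non-LLM route):

```python
import httpx

# Assumes a LiteLLM proxy on localhost:4000 and a virtual key; the response id
# is illustrative.
resp = httpx.get(
    "http://localhost:4000/v1/responses/resp_abc123",
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.status_code)
```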


@@ -78,3 +78,93 @@ async def responses_api(
        proxy_logging_obj=proxy_logging_obj,
        version=version,
    )


@router.get(
"/v1/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.get(
"/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def get_response(
response_id: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get a response by ID.
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/get
```bash
curl -X GET http://localhost:4000/v1/responses/resp_abc123 \
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement response retrieval logic
    pass


@router.delete(
"/v1/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.delete(
"/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def delete_response(
response_id: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Delete a response by ID.
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/delete
```bash
curl -X DELETE http://localhost:4000/v1/responses/resp_abc123 \
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement response deletion logic
    pass


@router.get(
"/v1/responses/{response_id}/input_items",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.get(
"/responses/{response_id}/input_items",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def get_response_input_items(
response_id: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get input items for a response.
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/input-items
```bash
curl -X GET http://localhost:4000/v1/responses/resp_abc123/input_items \
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement input items retrieval logic
pass
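
Equivalent calls from the OpenAI SDK, for reference. This assumes a recent `openai` Python client with Responses API support, and that the handlers above have been implemented (in this commit they are still TODO stubs):

```python
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

retrieved = client.responses.retrieve("resp_abc123")      # GET /v1/responses/{response_id}
items = client.responses.input_items.list("resp_abc123")  # GET /v1/responses/{response_id}/input_items
deleted = client.responses.delete("resp_abc123")          # DELETE /v1/responses/{response_id}
```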


@@ -58,15 +58,24 @@ async def aresponses(
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
    """
    Async: Handles responses API requests by reusing the synchronous function
    """
    local_vars = locals()
    try:
        loop = asyncio.get_event_loop()
        kwargs["aresponses"] = True

        # get custom llm provider so we can use this for mapping exceptions
        if custom_llm_provider is None:
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(
                model=model, api_base=local_vars.get("base_url", None)
            )

        func = partial(
            responses,
            input=input,
@@ -91,6 +100,7 @@ async def aresponses(
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )
@@ -104,7 +114,13 @@ async def aresponses(
            response = init_response
        return response
    except Exception as e:
-        raise e
+        raise litellm.exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=local_vars,
+            extra_kwargs=kwargs,
+        )


@client
@@ -133,85 +149,97 @@ def responses(
    extra_query: Optional[Dict[str, Any]] = None,
    extra_body: Optional[Dict[str, Any]] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
    # LiteLLM specific params,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
):
    """
    Synchronous version of the Responses API.
    Uses the synchronous HTTP handler to make requests.
    """
-    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
-    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
-    _is_async = kwargs.pop("aresponses", False) is True
-
-    # get llm provider logic
-    litellm_params = GenericLiteLLMParams(**kwargs)
-    model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
-        litellm.get_llm_provider(
-            model=model,
-            custom_llm_provider=kwargs.get("custom_llm_provider", None),
-            api_base=litellm_params.api_base,
-            api_key=litellm_params.api_key,
-        )
-    )
-
-    # get provider config
-    responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
-        ProviderConfigManager.get_provider_responses_api_config(
-            model=model,
-            provider=litellm.LlmProviders(custom_llm_provider),
-        )
-    )
-
-    if responses_api_provider_config is None:
-        raise litellm.BadRequestError(
-            model=model,
-            llm_provider=custom_llm_provider,
-            message=f"Responses API not available for custom_llm_provider={custom_llm_provider}, model: {model}",
-        )
-
-    # Get all parameters using locals() and combine with kwargs
-    local_vars = locals()
-    local_vars.update(kwargs)
-    # Get ResponsesAPIOptionalRequestParams with only valid parameters
-    response_api_optional_params: ResponsesAPIOptionalRequestParams = (
-        ResponsesAPIRequestUtils.get_requested_response_api_optional_param(local_vars)
-    )
-
-    # Get optional parameters for the responses API
-    responses_api_request_params: Dict = (
-        ResponsesAPIRequestUtils.get_optional_params_responses_api(
-            model=model,
-            responses_api_provider_config=responses_api_provider_config,
-            response_api_optional_params=response_api_optional_params,
-        )
-    )
-
-    # Pre Call logging
-    litellm_logging_obj.update_environment_variables(
-        model=model,
-        user=user,
-        optional_params=dict(responses_api_request_params),
-        litellm_params={
-            "litellm_call_id": litellm_call_id,
-            **responses_api_request_params,
-        },
-        custom_llm_provider=custom_llm_provider,
-    )
-
-    # Call the handler with _is_async flag instead of directly calling the async handler
-    response = base_llm_http_handler.response_api_handler(
-        model=model,
-        input=input,
-        responses_api_provider_config=responses_api_provider_config,
-        response_api_optional_request_params=responses_api_request_params,
-        custom_llm_provider=custom_llm_provider,
-        litellm_params=litellm_params,
-        logging_obj=litellm_logging_obj,
-        extra_headers=extra_headers,
-        extra_body=extra_body,
-        timeout=timeout or request_timeout,
-        _is_async=_is_async,
-        client=kwargs.get("client"),
-    )
-
-    return response
+    local_vars = locals()
+    try:
+        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
+        litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+        _is_async = kwargs.pop("aresponses", False) is True
+
+        # get llm provider logic
+        litellm_params = GenericLiteLLMParams(**kwargs)
+        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
+            litellm.get_llm_provider(
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                api_base=litellm_params.api_base,
+                api_key=litellm_params.api_key,
+            )
+        )
+
+        # get provider config
+        responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
+            ProviderConfigManager.get_provider_responses_api_config(
+                model=model,
+                provider=litellm.LlmProviders(custom_llm_provider),
+            )
+        )
+
+        if responses_api_provider_config is None:
+            raise litellm.BadRequestError(
+                model=model,
+                llm_provider=custom_llm_provider,
+                message=f"Responses API not available for custom_llm_provider={custom_llm_provider}, model: {model}",
+            )
+
+        local_vars.update(kwargs)
+        # Get ResponsesAPIOptionalRequestParams with only valid parameters
+        response_api_optional_params: ResponsesAPIOptionalRequestParams = (
+            ResponsesAPIRequestUtils.get_requested_response_api_optional_param(
+                local_vars
+            )
+        )
+
+        # Get optional parameters for the responses API
+        responses_api_request_params: Dict = (
+            ResponsesAPIRequestUtils.get_optional_params_responses_api(
+                model=model,
+                responses_api_provider_config=responses_api_provider_config,
+                response_api_optional_params=response_api_optional_params,
+            )
+        )
+
+        # Pre Call logging
+        litellm_logging_obj.update_environment_variables(
+            model=model,
+            user=user,
+            optional_params=dict(responses_api_request_params),
+            litellm_params={
+                "litellm_call_id": litellm_call_id,
+                **responses_api_request_params,
+            },
+            custom_llm_provider=custom_llm_provider,
+        )
+
+        # Call the handler with _is_async flag instead of directly calling the async handler
+        response = base_llm_http_handler.response_api_handler(
+            model=model,
+            input=input,
+            responses_api_provider_config=responses_api_provider_config,
+            response_api_optional_request_params=responses_api_request_params,
+            custom_llm_provider=custom_llm_provider,
+            litellm_params=litellm_params,
+            logging_obj=litellm_logging_obj,
+            extra_headers=extra_headers,
+            extra_body=extra_body,
+            timeout=timeout or request_timeout,
+            _is_async=_is_async,
+            client=kwargs.get("client"),
+        )
+
+        return response
+    except Exception as e:
+        raise litellm.exception_type(
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs=local_vars,
+            extra_kwargs=kwargs,
+        )
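
Both `aresponses` and `responses` now resolve `custom_llm_provider` up front so that `litellm.exception_type` can apply the right provider-specific mapping. A small sketch of the provider-resolution call they rely on (the four-element unpacking matches the calls above; output shown for an OpenAI model):

```python
import litellm

model, custom_llm_provider, dynamic_api_key, dynamic_api_base = litellm.get_llm_provider(
    model="gpt-4o"
)
print(custom_llm_provider)  # -> "openai"
```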


@@ -795,3 +795,34 @@ async def test_openai_responses_litellm_router_with_metadata():
        loaded_request_body["metadata"] == test_metadata
    ), "metadata in request body should match what was passed"
    mock_post.assert_called_once()


def test_bad_request_bad_param_error():
"""Raise a BadRequestError when an invalid parameter value is provided"""
try:
litellm.responses(model="gpt-4o", input="This should fail", temperature=2000)
pytest.fail("Expected BadRequestError but no exception was raised")
except litellm.BadRequestError as e:
print(f"Exception raised: {e}")
print(f"Exception type: {type(e)}")
print(f"Exception args: {e.args}")
print(f"Exception details: {e.__dict__}")
except Exception as e:
        pytest.fail(f"Unexpected exception raised: {e}")


@pytest.mark.asyncio()
async def test_async_bad_request_bad_param_error():
"""Raise a BadRequestError when an invalid parameter value is provided"""
try:
await litellm.aresponses(
model="gpt-4o", input="This should fail", temperature=2000
)
pytest.fail("Expected BadRequestError but no exception was raised")
except litellm.BadRequestError as e:
print(f"Exception raised: {e}")
print(f"Exception type: {type(e)}")
print(f"Exception args: {e.args}")
print(f"Exception details: {e.__dict__}")
except Exception as e:
pytest.fail(f"Unexpected exception raised: {e}")


@@ -0,0 +1,108 @@
import httpx
from openai import OpenAI, BadRequestError
import pytest


def generate_key():
"""Generate a key for testing"""
url = "http://0.0.0.0:4000/key/generate"
headers = {
"Authorization": "Bearer sk-1234",
"Content-Type": "application/json",
}
data = {}
response = httpx.post(url, headers=headers, json=data)
if response.status_code != 200:
raise Exception(f"Key generation failed with status: {response.status_code}")
    return response.json()["key"]


def get_test_client():
"""Create OpenAI client with generated key"""
key = generate_key()
    return OpenAI(api_key=key, base_url="http://0.0.0.0:4000")


def validate_response(response):
"""
Validate basic response structure from OpenAI responses API
"""
assert response is not None
assert hasattr(response, "choices")
assert len(response.choices) > 0
assert hasattr(response.choices[0], "message")
assert hasattr(response.choices[0].message, "content")
assert isinstance(response.choices[0].message.content, str)
assert hasattr(response, "id")
assert isinstance(response.id, str)
assert hasattr(response, "model")
assert isinstance(response.model, str)
assert hasattr(response, "created")
assert isinstance(response.created, int)
assert hasattr(response, "usage")
assert hasattr(response.usage, "prompt_tokens")
assert hasattr(response.usage, "completion_tokens")
    assert hasattr(response.usage, "total_tokens")


def validate_stream_chunk(chunk):
"""
Validate streaming chunk structure from OpenAI responses API
"""
assert chunk is not None
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "delta")
# Some chunks might not have content in the delta
if (
hasattr(chunk.choices[0].delta, "content")
and chunk.choices[0].delta.content is not None
):
assert isinstance(chunk.choices[0].delta.content, str)
assert hasattr(chunk, "id")
assert isinstance(chunk.id, str)
assert hasattr(chunk, "model")
assert isinstance(chunk.model, str)
assert hasattr(chunk, "created")
    assert isinstance(chunk.created, int)


def test_basic_response():
client = get_test_client()
response = client.responses.create(
model="gpt-4o", input="just respond with the word 'ping'"
)
print("basic response=", response)
def test_streaming_response():
client = get_test_client()
stream = client.responses.create(
model="gpt-4o", input="just respond with the word 'ping'", stream=True
)
collected_chunks = []
for chunk in stream:
print("stream chunk=", chunk)
collected_chunks.append(chunk)
    assert len(collected_chunks) > 0


def test_bad_request_error():
client = get_test_client()
with pytest.raises(BadRequestError):
# Trigger error with invalid model name
client.responses.create(model="non-existent-model", input="This should fail")
def test_bad_request_bad_param_error():
client = get_test_client()
with pytest.raises(BadRequestError):
        # Trigger error with an out-of-range parameter value
client.responses.create(
model="gpt-4o", input="This should fail", temperature=2000
)