Merge pull request #9220 from BerriAI/litellm_qa_responses_api

[Fixes] Responses API - allow /responses and its subpaths as LLM API routes + add exception mapping for the Responses API
Authored by Ishaan Jaff on 2025-03-13 21:36:59 -07:00; committed by GitHub
commit ceb8668e4a
16 changed files with 340 additions and 76 deletions
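In caller-facing terms, the exception-mapping change means provider failures from `litellm.responses()` / `litellm.aresponses()` now surface as LiteLLM's OpenAI-style exception classes instead of raw provider or httpx errors. A minimal caller-side sketch, mirroring the tests added below (the model name and parameter value are illustrative):

```python
import litellm

# Sketch of the behavior this PR tests for: an invalid request parameter is
# expected to surface as litellm.BadRequestError once exception mapping is
# in place. Model name and temperature value are illustrative only.
try:
    litellm.responses(model="gpt-4o", input="ping", temperature=2000)
except litellm.BadRequestError as e:
    print("mapped to OpenAI-style error:", type(e).__name__)
```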

@@ -1357,7 +1357,7 @@ jobs:
# Store test results
- store_test_results:
path: test-results
e2e_openai_misc_endpoints:
e2e_openai_endpoints:
machine:
image: ubuntu-2204:2023.10.1
resource_class: xlarge
@@ -1474,7 +1474,7 @@ jobs:
command: |
pwd
ls
python -m pytest -s -vv tests/openai_misc_endpoints_tests --junitxml=test-results/junit.xml --durations=5
python -m pytest -s -vv tests/openai_endpoints_tests --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
# Store test results
@@ -2429,7 +2429,7 @@ workflows:
only:
- main
- /litellm_.*/
- e2e_openai_misc_endpoints:
- e2e_openai_endpoints:
filters:
branches:
only:
@@ -2571,7 +2571,7 @@ workflows:
requires:
- local_testing
- build_and_test
- e2e_openai_misc_endpoints
- e2e_openai_endpoints
- load_testing
- test_bad_database_url
- llm_translation_testing

@@ -127,7 +127,7 @@ def exception_type( # type: ignore # noqa: PLR0915
completion_kwargs={},
extra_kwargs={},
):
"""Maps an LLM Provider Exception to OpenAI Exception Format"""
if any(
isinstance(original_exception, exc_type)
for exc_type in litellm.LITELLM_EXCEPTION_TYPES
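The new one-line docstring describes what `exception_type` does; the Responses API changes later in this diff rely on it via a `raise litellm.exception_type(...)` pattern. A hedged sketch of that pattern in isolation (the stand-in error and the argument values are illustrative, not taken from this commit):

```python
import litellm

# Illustrative only: wrap a caught provider error into LiteLLM's
# OpenAI-format exception hierarchy, as the Responses API code paths
# below do in their except blocks.
try:
    raise ValueError("provider rejected the request")  # stand-in for a real provider error
except Exception as original:
    try:
        raise litellm.exception_type(
            model="gpt-4o",
            custom_llm_provider="openai",
            original_exception=original,
            completion_kwargs={},
            extra_kwargs={},
        )
    except Exception as mapped:
        print("mapped exception:", type(mapped).__name__)
```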

@@ -248,6 +248,13 @@ class LiteLLMRoutes(enum.Enum):
"/v1/realtime",
"/realtime?{model}",
"/v1/realtime?{model}",
# responses API
"/responses",
"/v1/responses",
"/responses/{response_id}",
"/v1/responses/{response_id}",
"/responses/{response_id}/input_items",
"/v1/responses/{response_id}/input_items",
]
mapped_pass_through_routes = [
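These templates, including the `{response_id}` placeholder forms, are what lets the proxy's route auth treat Responses API calls as LLM API routes. A minimal sketch of matching a concrete request path against such templates (an illustration, not LiteLLM's actual route checker):

```python
import re

# Illustrative matcher: turn "{response_id}"-style placeholders into a
# single-path-segment wildcard and compare against the request path.
def matches_template(path: str, template: str) -> bool:
    pattern = re.sub(r"\{[^/]+\}", r"[^/]+", template)
    return re.fullmatch(pattern, path) is not None

print(matches_template("/v1/responses/resp_abc123", "/v1/responses/{response_id}"))  # True
print(matches_template(
    "/v1/responses/resp_abc123/input_items",
    "/v1/responses/{response_id}/input_items",
))  # True
```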

@@ -78,3 +78,93 @@ async def responses_api(
proxy_logging_obj=proxy_logging_obj,
version=version,
)
@router.get(
"/v1/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.get(
"/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def get_response(
response_id: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get a response by ID.
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/get
```bash
curl -X GET http://localhost:4000/v1/responses/resp_abc123 \
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement response retrieval logic
pass
@router.delete(
"/v1/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.delete(
"/responses/{response_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def delete_response(
response_id: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Delete a response by ID.
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/delete
```bash
curl -X DELETE http://localhost:4000/v1/responses/resp_abc123 \
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement response deletion logic
pass
@router.get(
"/v1/responses/{response_id}/input_items",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
@router.get(
"/responses/{response_id}/input_items",
dependencies=[Depends(user_api_key_auth)],
tags=["responses"],
)
async def get_response_input_items(
response_id: str,
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Get input items for a response.
Follows the OpenAI Responses API spec: https://platform.openai.com/docs/api-reference/responses/input-items
```bash
curl -X GET http://localhost:4000/v1/responses/resp_abc123/input_items \
-H "Authorization: Bearer sk-1234"
```
"""
# TODO: Implement input items retrieval logic
pass
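The three stubs above mirror OpenAI's retrieve / delete / list-input-items operations for responses. A forward-looking sketch of how they would be exercised once implemented, assuming the OpenAI Python SDK's responses helpers and an illustrative proxy URL, key, and response id:

```python
from openai import OpenAI

# Forward-looking usage sketch against the proxy endpoints stubbed above.
# Values are illustrative; the SDK helper names are assumed from the
# OpenAI Responses API client, not from this commit.
client = OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

retrieved = client.responses.retrieve("resp_abc123")        # GET    /v1/responses/{response_id}
items = client.responses.input_items.list("resp_abc123")    # GET    /v1/responses/{response_id}/input_items
client.responses.delete("resp_abc123")                      # DELETE /v1/responses/{response_id}
```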

@@ -58,15 +58,24 @@ async def aresponses(
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params
custom_llm_provider: Optional[str] = None,
**kwargs,
) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
"""
Async: Handles responses API requests by reusing the synchronous function
"""
local_vars = locals()
try:
loop = asyncio.get_event_loop()
kwargs["aresponses"] = True
# get custom llm provider so we can use this for mapping exceptions
if custom_llm_provider is None:
_, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, api_base=local_vars.get("base_url", None)
)
func = partial(
responses,
input=input,
@@ -91,6 +100,7 @@ async def aresponses(
extra_query=extra_query,
extra_body=extra_body,
timeout=timeout,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
@@ -104,7 +114,13 @@ async def aresponses(
response = init_response
return response
except Exception as e:
raise e
raise litellm.exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
@client
@@ -133,12 +149,16 @@ def responses(
extra_query: Optional[Dict[str, Any]] = None,
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
# LiteLLM specific params
custom_llm_provider: Optional[str] = None,
**kwargs,
):
"""
Synchronous version of the Responses API.
Uses the synchronous HTTP handler to make requests.
"""
local_vars = locals()
try:
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
_is_async = kwargs.pop("aresponses", False) is True
@@ -148,7 +168,7 @@ def responses(
model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get("custom_llm_provider", None),
custom_llm_provider=custom_llm_provider,
api_base=litellm_params.api_base,
api_key=litellm_params.api_key,
)
@@ -169,12 +189,12 @@ def responses(
message=f"Responses API not available for custom_llm_provider={custom_llm_provider}, model: {model}",
)
# Get all parameters using locals() and combine with kwargs
local_vars = locals()
local_vars.update(kwargs)
# Get ResponsesAPIOptionalRequestParams with only valid parameters
response_api_optional_params: ResponsesAPIOptionalRequestParams = (
ResponsesAPIRequestUtils.get_requested_response_api_optional_param(local_vars)
ResponsesAPIRequestUtils.get_requested_response_api_optional_param(
local_vars
)
)
# Get optional parameters for the responses API
@@ -215,3 +235,11 @@ def responses(
)
return response
except Exception as e:
raise litellm.exception_type(
model=model,
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
)
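Both `responses()` and `aresponses()` now resolve `custom_llm_provider` up front so their `except` blocks can map errors for the correct provider. A short sketch of that resolution step on its own (the model name and printed values are illustrative):

```python
import litellm

# get_llm_provider returns (model, custom_llm_provider, dynamic_api_key, api_base);
# the async path uses it to fill custom_llm_provider before any request is made.
model, provider, _api_key, _api_base = litellm.get_llm_provider(model="gpt-4o")
print(model, provider)  # e.g. "gpt-4o" "openai"
```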

@@ -795,3 +795,34 @@ async def test_openai_responses_litellm_router_with_metadata():
loaded_request_body["metadata"] == test_metadata
), "metadata in request body should match what was passed"
mock_post.assert_called_once()
def test_bad_request_bad_param_error():
"""Raise a BadRequestError when an invalid parameter value is provided"""
try:
litellm.responses(model="gpt-4o", input="This should fail", temperature=2000)
pytest.fail("Expected BadRequestError but no exception was raised")
except litellm.BadRequestError as e:
print(f"Exception raised: {e}")
print(f"Exception type: {type(e)}")
print(f"Exception args: {e.args}")
print(f"Exception details: {e.__dict__}")
except Exception as e:
pytest.fail(f"Unexpected exception raised: {e}")
@pytest.mark.asyncio()
async def test_async_bad_request_bad_param_error():
"""Raise a BadRequestError when an invalid parameter value is provided"""
try:
await litellm.aresponses(
model="gpt-4o", input="This should fail", temperature=2000
)
pytest.fail("Expected BadRequestError but no exception was raised")
except litellm.BadRequestError as e:
print(f"Exception raised: {e}")
print(f"Exception type: {type(e)}")
print(f"Exception args: {e.args}")
print(f"Exception details: {e.__dict__}")
except Exception as e:
pytest.fail(f"Unexpected exception raised: {e}")

@@ -0,0 +1,108 @@
import httpx
from openai import OpenAI, BadRequestError
import pytest
def generate_key():
"""Generate a key for testing"""
url = "http://0.0.0.0:4000/key/generate"
headers = {
"Authorization": "Bearer sk-1234",
"Content-Type": "application/json",
}
data = {}
response = httpx.post(url, headers=headers, json=data)
if response.status_code != 200:
raise Exception(f"Key generation failed with status: {response.status_code}")
return response.json()["key"]
def get_test_client():
"""Create OpenAI client with generated key"""
key = generate_key()
return OpenAI(api_key=key, base_url="http://0.0.0.0:4000")
def validate_response(response):
"""
Validate basic response structure from OpenAI responses API
"""
assert response is not None
assert hasattr(response, "choices")
assert len(response.choices) > 0
assert hasattr(response.choices[0], "message")
assert hasattr(response.choices[0].message, "content")
assert isinstance(response.choices[0].message.content, str)
assert hasattr(response, "id")
assert isinstance(response.id, str)
assert hasattr(response, "model")
assert isinstance(response.model, str)
assert hasattr(response, "created")
assert isinstance(response.created, int)
assert hasattr(response, "usage")
assert hasattr(response.usage, "prompt_tokens")
assert hasattr(response.usage, "completion_tokens")
assert hasattr(response.usage, "total_tokens")
def validate_stream_chunk(chunk):
"""
Validate streaming chunk structure from OpenAI responses API
"""
assert chunk is not None
assert hasattr(chunk, "choices")
assert len(chunk.choices) > 0
assert hasattr(chunk.choices[0], "delta")
# Some chunks might not have content in the delta
if (
hasattr(chunk.choices[0].delta, "content")
and chunk.choices[0].delta.content is not None
):
assert isinstance(chunk.choices[0].delta.content, str)
assert hasattr(chunk, "id")
assert isinstance(chunk.id, str)
assert hasattr(chunk, "model")
assert isinstance(chunk.model, str)
assert hasattr(chunk, "created")
assert isinstance(chunk.created, int)
def test_basic_response():
client = get_test_client()
response = client.responses.create(
model="gpt-4o", input="just respond with the word 'ping'"
)
print("basic response=", response)
def test_streaming_response():
client = get_test_client()
stream = client.responses.create(
model="gpt-4o", input="just respond with the word 'ping'", stream=True
)
collected_chunks = []
for chunk in stream:
print("stream chunk=", chunk)
collected_chunks.append(chunk)
assert len(collected_chunks) > 0
def test_bad_request_error():
client = get_test_client()
with pytest.raises(BadRequestError):
# Trigger error with invalid model name
client.responses.create(model="non-existent-model", input="This should fail")
def test_bad_request_bad_param_error():
client = get_test_client()
with pytest.raises(BadRequestError):
# Trigger error with an out-of-range temperature value
client.responses.create(
model="gpt-4o", input="This should fail", temperature=2000
)