diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index a303c4572c..6f7671c369 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -1,6 +1,6 @@
 import io
 import json
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
 
 import httpx  # type: ignore
 
@@ -18,7 +18,11 @@ from litellm.llms.custom_httpx.http_handler import (
     _get_httpx_client,
     get_async_httpx_client,
 )
-from litellm.responses.streaming_iterator import ResponsesAPIStreamingIterator
+from litellm.responses.streaming_iterator import (
+    BaseResponsesAPIStreamingIterator,
+    ResponsesAPIStreamingIterator,
+    SyncResponsesAPIStreamingIterator,
+)
 from litellm.types.llms.openai import (
     ResponseInputParam,
     ResponsesAPIOptionalRequestParams,
@@ -961,32 +965,164 @@ class BaseLLMHTTPHandler:
             return returned_response
         return model_response
 
-    async def async_response_api_handler(
+    def response_api_handler(
         self,
         model: str,
-        custom_llm_provider: str,
         input: Union[str, ResponseInputParam],
         responses_api_provider_config: BaseResponsesAPIConfig,
         response_api_optional_request_params: Dict,
-        logging_obj: LiteLLMLoggingObj,
+        custom_llm_provider: str,
         litellm_params: GenericLiteLLMParams,
-        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        logging_obj: LiteLLMLoggingObj,
         extra_headers: Optional[Dict[str, Any]] = None,
         extra_body: Optional[Dict[str, Any]] = None,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
-    ) -> Union[ResponsesAPIResponse, ResponsesAPIStreamingIterator]:
-        if client is None or not isinstance(client, AsyncHTTPHandler):
-            async_httpx_client = get_async_httpx_client(
-                llm_provider=litellm.LlmProviders(custom_llm_provider)
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        _is_async: bool = False,
+    ) -> Union[
+        ResponsesAPIResponse,
+        BaseResponsesAPIStreamingIterator,
+        Coroutine[
+            Any, Any, Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]
+        ],
+    ]:
+        """
+        Handles responses API requests.
+        When _is_async=True, returns a coroutine instead of making the call directly.
+ """ + if _is_async: + # Return the async coroutine if called with _is_async=True + return self.async_response_api_handler( + model=model, + input=input, + responses_api_provider_config=responses_api_provider_config, + response_api_optional_request_params=response_api_optional_request_params, + custom_llm_provider=custom_llm_provider, + litellm_params=litellm_params, + logging_obj=logging_obj, + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + client=client if isinstance(client, AsyncHTTPHandler) else None, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client( + params={"ssl_verify": litellm_params.get("ssl_verify", None)} ) else: - async_httpx_client = client + sync_httpx_client = client + headers = responses_api_provider_config.validate_environment( api_key=litellm_params.api_key, headers=response_api_optional_request_params.get("extra_headers", {}) or {}, model=model, ) + if extra_headers: + headers.update(extra_headers) + + api_base = responses_api_provider_config.get_complete_url( + api_base=litellm_params.api_base, + model=model, + ) + + data = responses_api_provider_config.transform_responses_api_request( + model=model, + input=input, + response_api_optional_request_params=response_api_optional_request_params, + litellm_params=litellm_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + # Check if streaming is requested + stream = response_api_optional_request_params.get("stream", False) + + try: + if stream: + # For streaming, use stream=True in the request + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout + or response_api_optional_request_params.get("timeout"), + stream=True, + ) + + return SyncResponsesAPIStreamingIterator( + response=response, + model=model, + logging_obj=logging_obj, + responses_api_provider_config=responses_api_provider_config, + ) + else: + # For non-streaming requests + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout + or response_api_optional_request_params.get("timeout"), + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=responses_api_provider_config, + ) + + return responses_api_provider_config.transform_response_api_response( + model=model, + raw_response=response, + logging_obj=logging_obj, + ) + + async def async_response_api_handler( + self, + model: str, + input: Union[str, ResponseInputParam], + responses_api_provider_config: BaseResponsesAPIConfig, + response_api_optional_request_params: Dict, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]: + """ + Async version of the responses API handler. + Uses async HTTP client to make requests. 
+ """ + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider), + params={"ssl_verify": litellm_params.get("ssl_verify", None)}, + ) + else: + async_httpx_client = client + + headers = responses_api_provider_config.validate_environment( + api_key=litellm_params.api_key, + headers=response_api_optional_request_params.get("extra_headers", {}) or {}, + model=model, + ) + + if extra_headers: + headers.update(extra_headers) + api_base = responses_api_provider_config.get_complete_url( api_base=litellm_params.api_base, model=model, @@ -1021,7 +1157,8 @@ class BaseLLMHTTPHandler: url=api_base, headers=headers, data=json.dumps(data), - timeout=response_api_optional_request_params.get("timeout"), + timeout=timeout + or response_api_optional_request_params.get("timeout"), stream=True, ) @@ -1038,7 +1175,8 @@ class BaseLLMHTTPHandler: url=api_base, headers=headers, data=json.dumps(data), - timeout=response_api_optional_request_params.get("timeout"), + timeout=timeout + or response_api_optional_request_params.get("timeout"), ) except Exception as e: raise self._handle_error( diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 337e7fc3b0..62d3ddf215 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -1,3 +1,6 @@ +import asyncio +import contextvars +from functools import partial from typing import Any, Dict, Iterable, List, Literal, Optional, Union, get_type_hints import httpx @@ -23,7 +26,10 @@ from litellm.types.llms.openai import ( from litellm.types.router import GenericLiteLLMParams from litellm.utils import ProviderConfigManager, client -from .streaming_iterator import ResponsesAPIStreamingIterator +from .streaming_iterator import ( + BaseResponsesAPIStreamingIterator, + ResponsesAPIStreamingIterator, +) ####### ENVIRONMENT VARIABLES ################### # Initialize any necessary instances or variables here @@ -75,9 +81,89 @@ async def aresponses( extra_body: Optional[Dict[str, Any]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None, **kwargs, -) -> Union[ResponsesAPIResponse, ResponsesAPIStreamingIterator]: +) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]: + """ + Async: Handles responses API requests by reusing the synchronous function + """ + try: + loop = asyncio.get_event_loop() + kwargs["aresponses"] = True + + func = partial( + responses, + input, + model, + include, + instructions, + max_output_tokens, + metadata, + parallel_tool_calls, + previous_response_id, + reasoning, + store, + stream, + temperature, + text, + tool_choice, + tools, + top_p, + truncation, + user, + extra_headers, + extra_query, + extra_body, + timeout, + **kwargs, + ) + + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + init_response = await loop.run_in_executor(None, func_with_context) + + if asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response + return response + except Exception as e: + raise e + + +@client +def responses( + input: Union[str, ResponseInputParam], + model: str, + include: Optional[List[ResponseIncludable]] = None, + instructions: Optional[str] = None, + max_output_tokens: Optional[int] = None, + metadata: Optional[Dict[str, Any]] = None, + parallel_tool_calls: Optional[bool] = None, + previous_response_id: Optional[str] = None, + reasoning: Optional[Reasoning] = None, + store: Optional[bool] = None, + stream: Optional[bool] = None, + 
temperature: Optional[float] = None, + text: Optional[ResponseTextConfigParam] = None, + tool_choice: Optional[ToolChoice] = None, + tools: Optional[Iterable[ToolParam]] = None, + top_p: Optional[float] = None, + truncation: Optional[Literal["auto", "disabled"]] = None, + user: Optional[str] = None, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Optional[Dict[str, Any]] = None, + extra_query: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + **kwargs, +): + """ + Synchronous version of the Responses API. + Uses the synchronous HTTP handler to make requests. + """ litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) + _is_async = kwargs.pop("aresponses", False) is True # get llm provider logic litellm_params = GenericLiteLLMParams(**kwargs) @@ -132,7 +218,11 @@ async def aresponses( custom_llm_provider=custom_llm_provider, ) - response = await base_llm_http_handler.async_response_api_handler( + # Get an instance of BaseLLMHTTPHandler + base_llm_http_handler_instance = BaseLLMHTTPHandler() + + # Call the handler with _is_async flag instead of directly calling the async handler + response = base_llm_http_handler_instance.response_api_handler( model=model, input=input, responses_api_provider_config=responses_api_provider_config, @@ -143,34 +233,8 @@ async def aresponses( extra_headers=extra_headers, extra_body=extra_body, timeout=timeout, + _is_async=_is_async, + client=kwargs.get("client"), ) + return response - - -def responses( - input: Union[str, ResponseInputParam], - model: str, - include: Optional[List[ResponseIncludable]] = None, - instructions: Optional[str] = None, - max_output_tokens: Optional[int] = None, - metadata: Optional[Dict[str, Any]] = None, - parallel_tool_calls: Optional[bool] = None, - previous_response_id: Optional[str] = None, - reasoning: Optional[Reasoning] = None, - store: Optional[bool] = None, - stream: Optional[bool] = None, - temperature: Optional[float] = None, - text: Optional[ResponseTextConfigParam] = None, - tool_choice: Optional[ToolChoice] = None, - tools: Optional[Iterable[ToolParam]] = None, - top_p: Optional[float] = None, - truncation: Optional[Literal["auto", "disabled"]] = None, - user: Optional[str] = None, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
-    extra_headers: Optional[Dict[str, Any]] = None,
-    extra_query: Optional[Dict[str, Any]] = None,
-    extra_body: Optional[Dict[str, Any]] = None,
-    timeout: Optional[Union[float, httpx.Timeout]] = None,
-):
-    pass
diff --git a/tests/llm_responses_api_testing/test_openai_responses_api.py b/tests/llm_responses_api_testing/test_openai_responses_api.py
index 9745269bef..15192c1b55 100644
--- a/tests/llm_responses_api_testing/test_openai_responses_api.py
+++ b/tests/llm_responses_api_testing/test_openai_responses_api.py
@@ -12,28 +12,40 @@ from litellm.types.utils import StandardLoggingPayload
 from litellm.types.llms.openai import ResponseCompletedEvent, ResponsesAPIResponse
 
 
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_basic_openai_responses_api():
+async def test_basic_openai_responses_api(sync_mode):
     litellm._turn_on_debug()
-    response = await litellm.aresponses(
-        model="gpt-4o", input="Tell me a three sentence bedtime story about a unicorn."
-    )
+
+    if sync_mode:
+        response = litellm.responses(model="gpt-4o", input="Basic ping")
+    else:
+        response = await litellm.aresponses(model="gpt-4o", input="Basic ping")
+
     print("litellm response=", json.dumps(response, indent=4, default=str))
 
-    # validate_responses_api_response()
-
 
+@pytest.mark.parametrize("sync_mode", [True])
 @pytest.mark.asyncio
-async def test_basic_openai_responses_api_streaming():
+async def test_basic_openai_responses_api_streaming(sync_mode):
     litellm._turn_on_debug()
-    response = await litellm.aresponses(
-        model="gpt-4o",
-        input="Tell me a three sentence bedtime story about a unicorn.",
-        stream=True,
-    )
-    async for event in response:
-        print("litellm response=", json.dumps(event, indent=4, default=str))
+    if sync_mode:
+        response = litellm.responses(
+            model="gpt-4o",
+            input="Basic ping",
+            stream=True,
+        )
+        for event in response:
+            print("litellm response=", json.dumps(event, indent=4, default=str))
+    else:
+        response = await litellm.aresponses(
+            model="gpt-4o",
+            input="Basic ping",
+            stream=True,
+        )
+        async for event in response:
+            print("litellm response=", json.dumps(event, indent=4, default=str))
 
 
 class TestCustomLogger(CustomLogger):
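
For reference, a minimal usage sketch of the sync and async entry points this patch wires up; the model name and inputs are illustrative and mirror the tests above, not part of the patch itself:

    import asyncio
    import litellm

    # Synchronous, non-streaming call (new in this patch)
    response = litellm.responses(model="gpt-4o", input="Basic ping")
    print(response)

    # Synchronous streaming call yields events from a SyncResponsesAPIStreamingIterator
    for event in litellm.responses(model="gpt-4o", input="Basic ping", stream=True):
        print(event)

    # The async path still goes through litellm.aresponses, which now runs the
    # sync function in an executor and awaits the coroutine returned when
    # _is_async=True is passed to the handler
    async def main():
        return await litellm.aresponses(model="gpt-4o", input="Basic ping")

    asyncio.run(main())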