Merge pull request #9419 from BerriAI/litellm_streaming_o1_pro

[Feat] OpenAI o1-pro Responses API streaming support
Ishaan Jaff 2025-03-20 21:54:43 -07:00 committed by GitHub
commit c44fe8bd90
11 changed files with 491 additions and 20 deletions
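For context, a minimal usage sketch of what this PR enables (hedged: assumes an `OPENAI_API_KEY` in the environment; the exact event class printed depends on the installed litellm version). Requesting `stream=True` for a model that cannot stream natively, such as o1-pro, now returns a faked stream that yields the complete response as a single completed event.

```python
import litellm

# Hedged sketch: o1-pro does not support native streaming, so litellm sends a
# regular (non-streaming) request and replays it as a one-event stream.
stream = litellm.responses(
    model="o1-pro",
    input="Tell me a three sentence bedtime story about a unicorn.",
    stream=True,
)

for event in stream:
    # Expect a single completed-response event wrapping the full response.
    print(type(event).__name__)
```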

View file

@@ -7,7 +7,6 @@ import httpx
from litellm.types.llms.openai import (
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIRequestParams,
ResponsesAPIResponse,
ResponsesAPIStreamingResponse,
)
@@ -97,7 +96,7 @@ class BaseResponsesAPIConfig(ABC):
response_api_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> ResponsesAPIRequestParams:
) -> Dict:
pass
@abstractmethod
@@ -131,3 +130,12 @@ class BaseResponsesAPIConfig(ABC):
message=error_message,
headers=headers,
)
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
"""Returns True if litellm should fake a stream for the given model and stream value"""
return False
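The hook defaults to `False` on the base config; providers opt in by overriding it. A hedged sketch of such an override is below (the subclass name and model set are hypothetical; the real OpenAI override appears later in this PR).

```python
from typing import Optional

from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig


class MyProviderResponsesAPIConfig(BaseResponsesAPIConfig):  # hypothetical provider config
    # Hypothetical set of models that cannot stream natively.
    NON_STREAMING_MODELS = {"my-non-streaming-model"}

    # ... other abstract methods omitted for brevity ...

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Fake a stream only when the caller asked for one and the model cannot stream.
        return stream is True and model in self.NON_STREAMING_MODELS
```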

View file

@@ -20,6 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.responses.streaming_iterator import (
BaseResponsesAPIStreamingIterator,
MockResponsesAPIStreamingIterator,
ResponsesAPIStreamingIterator,
SyncResponsesAPIStreamingIterator,
)
@@ -978,6 +979,7 @@ class BaseLLMHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
_is_async: bool = False,
fake_stream: bool = False,
) -> Union[
ResponsesAPIResponse,
BaseResponsesAPIStreamingIterator,
@@ -1003,6 +1005,7 @@
extra_body=extra_body,
timeout=timeout,
client=client if isinstance(client, AsyncHTTPHandler) else None,
fake_stream=fake_stream,
)
if client is None or not isinstance(client, HTTPHandler):
@@ -1051,14 +1054,27 @@
try:
if stream:
# For streaming, use stream=True in the request
if fake_stream is True:
stream, data = self._prepare_fake_stream_request(
stream=stream,
data=data,
fake_stream=fake_stream,
)
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout
or response_api_optional_request_params.get("timeout"),
stream=True,
stream=stream,
)
if fake_stream is True:
return MockResponsesAPIStreamingIterator(
response=response,
model=model,
logging_obj=logging_obj,
responses_api_provider_config=responses_api_provider_config,
)
return SyncResponsesAPIStreamingIterator(
response=response,
@@ -1100,6 +1116,7 @@
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
fake_stream: bool = False,
) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
"""
Async version of the responses API handler.
@@ -1145,22 +1162,36 @@
"headers": headers,
},
)
# Check if streaming is requested
stream = response_api_optional_request_params.get("stream", False)
try:
if stream:
# For streaming, we need to use stream=True in the request
if fake_stream is True:
stream, data = self._prepare_fake_stream_request(
stream=stream,
data=data,
fake_stream=fake_stream,
)
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout
or response_api_optional_request_params.get("timeout"),
stream=True,
stream=stream,
)
if fake_stream is True:
return MockResponsesAPIStreamingIterator(
response=response,
model=model,
logging_obj=logging_obj,
responses_api_provider_config=responses_api_provider_config,
)
# Return the streaming iterator
return ResponsesAPIStreamingIterator(
response=response,
@@ -1177,6 +1208,7 @@
timeout=timeout
or response_api_optional_request_params.get("timeout"),
)
except Exception as e:
raise self._handle_error(
e=e,
@@ -1189,6 +1221,21 @@
logging_obj=logging_obj,
)
def _prepare_fake_stream_request(
self,
stream: bool,
data: dict,
fake_stream: bool,
) -> Tuple[bool, dict]:
"""
Handles preparing a request when `fake_stream` is True.
"""
if fake_stream is True:
stream = False
data.pop("stream", None)
return stream, data
return stream, data
def _handle_error(
self,
e: Exception,
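The new `_prepare_fake_stream_request` helper downgrades the wire request when a stream is being faked. A standalone sketch of that contract (the function below mirrors the private helper; the payload values are illustrative):

```python
from typing import Tuple


def prepare_fake_stream_request(stream: bool, data: dict, fake_stream: bool) -> Tuple[bool, dict]:
    """Mirror of the new helper: send a regular request and drop the 'stream' flag."""
    if fake_stream is True:
        stream = False
        data.pop("stream", None)
    return stream, data


# The HTTP request goes out non-streaming, while the caller still receives a
# (mock) streaming iterator from the handler.
stream, data = prepare_fake_stream_request(
    stream=True,
    data={"model": "o1-pro", "input": "hi", "stream": True},
    fake_stream=True,
)
assert stream is False and "stream" not in data
```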

View file

@@ -65,10 +65,12 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
response_api_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> ResponsesAPIRequestParams:
) -> Dict:
"""No transform applied since inputs are in OpenAI spec already"""
return ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
return dict(
ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
)
)
def transform_response_api_response(
@@ -188,3 +190,27 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
raise ValueError(f"Unknown event type: {event_type}")
return model_class
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
if stream is not True:
return False
if model is not None:
try:
if (
litellm.utils.supports_native_streaming(
model=model,
custom_llm_provider=custom_llm_provider,
)
is False
):
return True
except Exception as e:
verbose_logger.debug(
f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
)
return False
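A hedged illustration of the decision this override makes (the import path below is an assumption about where `OpenAIResponsesAPIConfig` lives; whether a given model is flagged depends on the entries in the model cost map):

```python
# Assumed import path, for illustration only.
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig

config = OpenAIResponsesAPIConfig()

# Expected True: stream requested, and o1-pro is marked supports_native_streaming=False.
print(config.should_fake_stream(model="o1-pro", stream=True, custom_llm_provider="openai"))

# Expected False: gpt-4o streams natively.
print(config.should_fake_stream(model="gpt-4o", stream=True, custom_llm_provider="openai"))

# Always False when streaming was not requested.
print(config.should_fake_stream(model="o1-pro", stream=False, custom_llm_provider="openai"))
```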

View file

@@ -232,6 +232,9 @@ def responses(
timeout=timeout or request_timeout,
_is_async=_is_async,
client=kwargs.get("client"),
fake_stream=responses_api_provider_config.should_fake_stream(
model=model, stream=stream, custom_llm_provider=custom_llm_provider
),
)
return response
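The `fake_stream` flag is resolved per request from the provider config, so callers do not change anything. A hedged async sketch, assuming the `litellm.aresponses` entrypoint that drives the `_is_async=True` path shown above:

```python
import asyncio

import litellm


async def main() -> None:
    # Assumed async entrypoint mirroring responses(); for o1-pro the request is
    # sent non-streaming and replayed as a single-event async stream.
    stream = await litellm.aresponses(
        model="o1-pro",
        input="Write a haiku about streaming.",
        stream=True,
    )
    async for event in stream:
        print(type(event).__name__)


asyncio.run(main())
```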

View file

@@ -11,6 +11,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.types.llms.openai import (
ResponseCompletedEvent,
ResponsesAPIStreamEvents,
ResponsesAPIStreamingResponse,
)
@@ -207,3 +208,63 @@ class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
start_time=self.start_time,
end_time=datetime.now(),
)
class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
"""
Mock streaming iterator for models that do not support native streaming (e.g. o1-pro); litellm fakes a stream by returning the complete response as a single event.
"""
def __init__(
self,
response: httpx.Response,
model: str,
responses_api_provider_config: BaseResponsesAPIConfig,
logging_obj: LiteLLMLoggingObj,
):
self.raw_http_response = response
super().__init__(
response=response,
model=model,
responses_api_provider_config=responses_api_provider_config,
logging_obj=logging_obj,
)
self.is_done = False
def __aiter__(self):
return self
async def __anext__(self) -> ResponsesAPIStreamingResponse:
if self.is_done:
raise StopAsyncIteration
self.is_done = True
transformed_response = (
self.responses_api_provider_config.transform_response_api_response(
model=self.model,
raw_response=self.raw_http_response,
logging_obj=self.logging_obj,
)
)
return ResponseCompletedEvent(
type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
response=transformed_response,
)
def __iter__(self):
return self
def __next__(self) -> ResponsesAPIStreamingResponse:
if self.is_done:
raise StopIteration
self.is_done = True
transformed_response = (
self.responses_api_provider_config.transform_response_api_response(
model=self.model,
raw_response=self.raw_http_response,
logging_obj=self.logging_obj,
)
)
return ResponseCompletedEvent(
type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
response=transformed_response,
)
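The mock iterator yields exactly one completed-response event and then stops, in both sync and async iteration. A minimal self-contained sketch of that pattern (names here are illustrative, not litellm's API):

```python
class OneShotStream:
    """Toy version of the 'single completed event' pattern used by the mock iterator."""

    def __init__(self, final_payload: dict):
        self.final_payload = final_payload
        self.is_done = False

    def __iter__(self):
        return self

    def __next__(self) -> dict:
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return {"type": "response.completed", "response": self.final_payload}


events = list(OneShotStream({"id": "resp_123", "output_text": "hello"}))
assert len(events) == 1 and events[0]["type"] == "response.completed"
```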

View file

@@ -1975,6 +1975,39 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
)
def supports_native_streaming(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model supports native streaming and return a boolean value.
Parameters:
model (str): The model name to be checked.
custom_llm_provider (str): The provider to be checked.
Returns:
bool: True if the model supports native streaming, False otherwise.
Raises:
Exception: If the given model is not found in model_prices_and_context_window.json.
"""
try:
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
model_info = _get_model_info_helper(
model=model, custom_llm_provider=custom_llm_provider
)
supports_native_streaming = model_info.get("supports_native_streaming", True)
if supports_native_streaming is None:
supports_native_streaming = True
return supports_native_streaming
except Exception as e:
verbose_logger.debug(
f"Model not found or error in checking supports_native_streaming support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
)
return False
def supports_response_schema(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
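
The OpenAI Responses config above consults this helper through `litellm.utils.supports_native_streaming`. A hedged usage sketch (whether a given model is flagged depends on its entry in model_prices_and_context_window.json):

```python
import litellm

# Reads "supports_native_streaming" from the model cost map; a missing field
# defaults to True, and lookup errors return False.
print(litellm.utils.supports_native_streaming(model="o1-pro", custom_llm_provider="openai"))
print(litellm.utils.supports_native_streaming(model="gpt-4o", custom_llm_provider="openai"))
```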