Merge pull request #9419 from BerriAI/litellm_streaming_o1_pro

[Feat] OpenAI o1-pro Responses API streaming support
Ishaan Jaff 2025-03-20 21:54:43 -07:00 committed by GitHub
commit c44fe8bd90
11 changed files with 491 additions and 20 deletions
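In short: `litellm.responses` / `litellm.aresponses` now accept `stream=True` for models that cannot stream natively (such as `o1-pro`). LiteLLM sends a regular non-streaming request and wraps the result in a mock stream that emits a single `response.completed` event. A minimal usage sketch, modeled on the tests added in this PR (model and prompt are illustrative):

```python
import asyncio

import litellm


async def main():
    # o1-pro does not stream natively; LiteLLM fakes the stream and yields
    # one response.completed event carrying the full response.
    response = await litellm.aresponses(
        model="openai/o1-pro",
        input="Write a haiku about streaming",
        stream=True,
    )
    async for event in response:
        print(event.type, event)


asyncio.run(main())
```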


@@ -49,7 +49,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.66.1
pip install openai==1.67.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@@ -168,7 +168,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.66.1
pip install openai==1.67.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@@ -268,7 +268,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.66.1
pip install openai==1.67.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@@ -513,7 +513,7 @@ jobs:
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
pip install openai==1.66.1
pip install openai==1.67.0
pip install prisma==0.11.0
pip install "detect_secrets==1.5.0"
pip install "httpx==0.24.1"
@@ -1278,7 +1278,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.66.1"
pip install "openai==1.67.0"
- run:
name: Install Grype
command: |
@@ -1414,7 +1414,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.66.1"
pip install "openai==1.67.0"
# Run pytest and generate JUnit XML report
- run:
name: Build Docker image
@@ -1536,7 +1536,7 @@ jobs:
pip install "aiodynamo==23.10.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "openai==1.66.1"
pip install "openai==1.67.0"
- run:
name: Build Docker image
command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@@ -1965,7 +1965,7 @@ jobs:
pip install "pytest-asyncio==0.21.1"
pip install "google-cloud-aiplatform==1.43.0"
pip install aiohttp
pip install "openai==1.66.1"
pip install "openai==1.67.0"
pip install "assemblyai==0.37.0"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
@@ -2241,7 +2241,7 @@ jobs:
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install aiohttp
pip install "openai==1.66.1"
pip install "openai==1.67.0"
python -m pip install --upgrade pip
pip install "pydantic==2.7.1"
pip install "pytest==7.3.1"


@@ -1,5 +1,5 @@
# used by CI/CD testing
openai==1.66.1
openai==1.67.0
python-dotenv
tiktoken
importlib_metadata


@@ -7,7 +7,6 @@ import httpx
from litellm.types.llms.openai import (
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIRequestParams,
ResponsesAPIResponse,
ResponsesAPIStreamingResponse,
)
@@ -97,7 +96,7 @@ class BaseResponsesAPIConfig(ABC):
response_api_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> ResponsesAPIRequestParams:
) -> Dict:
pass
@abstractmethod
@@ -131,3 +130,12 @@ class BaseResponsesAPIConfig(ABC):
message=error_message,
headers=headers,
)
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
"""Returns True if litellm should fake a stream for the given model and stream value"""
return False
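The base config returns False by default, so existing providers are unaffected; a provider only opts in by overriding the hook. A hypothetical override for illustration (not part of this diff; the other abstract methods are omitted):

```python
from typing import Optional

from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig


class AlwaysFakeStreamConfig(BaseResponsesAPIConfig):
    """Hypothetical provider whose Responses API never streams natively."""

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Fake the stream whenever the caller asked for streaming.
        return stream is True
```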


@@ -20,6 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.responses.streaming_iterator import (
BaseResponsesAPIStreamingIterator,
MockResponsesAPIStreamingIterator,
ResponsesAPIStreamingIterator,
SyncResponsesAPIStreamingIterator,
)
@@ -978,6 +979,7 @@ class BaseLLMHTTPHandler:
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
_is_async: bool = False,
fake_stream: bool = False,
) -> Union[
ResponsesAPIResponse,
BaseResponsesAPIStreamingIterator,
@@ -1003,6 +1005,7 @@ class BaseLLMHTTPHandler:
extra_body=extra_body,
timeout=timeout,
client=client if isinstance(client, AsyncHTTPHandler) else None,
fake_stream=fake_stream,
)
if client is None or not isinstance(client, HTTPHandler):
@@ -1051,14 +1054,27 @@
try:
if stream:
# Streaming requested: send a real streamed request unless fake_stream downgrades it
if fake_stream is True:
stream, data = self._prepare_fake_stream_request(
stream=stream,
data=data,
fake_stream=fake_stream,
)
response = sync_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout
or response_api_optional_request_params.get("timeout"),
stream=True,
stream=stream,
)
if fake_stream is True:
return MockResponsesAPIStreamingIterator(
response=response,
model=model,
logging_obj=logging_obj,
responses_api_provider_config=responses_api_provider_config,
)
return SyncResponsesAPIStreamingIterator(
response=response,
@@ -1100,6 +1116,7 @@
extra_body: Optional[Dict[str, Any]] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
fake_stream: bool = False,
) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
"""
Async version of the responses API handler.
@@ -1145,22 +1162,36 @@
"headers": headers,
},
)
# Check if streaming is requested
stream = response_api_optional_request_params.get("stream", False)
try:
if stream:
# Streaming requested: send a real streamed request unless fake_stream downgrades it
if fake_stream is True:
stream, data = self._prepare_fake_stream_request(
stream=stream,
data=data,
fake_stream=fake_stream,
)
response = await async_httpx_client.post(
url=api_base,
headers=headers,
data=json.dumps(data),
timeout=timeout
or response_api_optional_request_params.get("timeout"),
stream=True,
stream=stream,
)
if fake_stream is True:
return MockResponsesAPIStreamingIterator(
response=response,
model=model,
logging_obj=logging_obj,
responses_api_provider_config=responses_api_provider_config,
)
# Return the streaming iterator
return ResponsesAPIStreamingIterator(
response=response,
@@ -1177,6 +1208,7 @@
timeout=timeout
or response_api_optional_request_params.get("timeout"),
)
except Exception as e:
raise self._handle_error(
e=e,
@@ -1189,6 +1221,21 @@
logging_obj=logging_obj,
)
def _prepare_fake_stream_request(
self,
stream: bool,
data: dict,
fake_stream: bool,
) -> Tuple[bool, dict]:
"""
Handles preparing a request when `fake_stream` is True.
"""
if fake_stream is True:
stream = False
data.pop("stream", None)
return stream, data
return stream, data
def _handle_error(
self,
e: Exception,
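The fake-stream path reduces to the helper above: strip streaming from the outbound request, then wrap the plain HTTP response in `MockResponsesAPIStreamingIterator`. A quick sketch of the helper's contract (mirrors the unit test added further down):

```python
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler

handler = BaseLLMHTTPHandler()

# With fake_stream=True the call becomes a non-streaming request:
# stream flips to False and the "stream" key is dropped from the body.
stream, data = handler._prepare_fake_stream_request(
    stream=True,
    data={"model": "o1-pro", "input": "ping", "stream": True},
    fake_stream=True,
)
assert stream is False
assert "stream" not in data
```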


@@ -65,10 +65,12 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
response_api_optional_request_params: Dict,
litellm_params: GenericLiteLLMParams,
headers: dict,
) -> ResponsesAPIRequestParams:
) -> Dict:
"""No transform applied since inputs are in OpenAI spec already"""
return ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
return dict(
ResponsesAPIRequestParams(
model=model, input=input, **response_api_optional_request_params
)
)
def transform_response_api_response(
@@ -188,3 +190,27 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
raise ValueError(f"Unknown event type: {event_type}")
return model_class
def should_fake_stream(
self,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
if stream is not True:
return False
if model is not None:
try:
if (
litellm.utils.supports_native_streaming(
model=model,
custom_llm_provider=custom_llm_provider,
)
is False
):
return True
except Exception as e:
verbose_logger.debug(
f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
)
return False
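The decision is deliberately conservative: a stream is only faked when streaming was requested and the model map explicitly marks the model as unable to stream natively; lookup errors fall through to False. A usage sketch (the import path is assumed from this PR's file layout, and the o1-pro result assumes the model map carries `supports_native_streaming: false` for it):

```python
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig

config = OpenAIResponsesAPIConfig()

# No faking when streaming was not requested.
assert config.should_fake_stream(model="o1-pro", stream=False) is False

# Expected True for o1-pro once the model map flags it as non-native-streaming.
print(config.should_fake_stream(
    model="o1-pro", stream=True, custom_llm_provider="openai"
))
```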


@@ -232,6 +232,9 @@ def responses(
timeout=timeout or request_timeout,
_is_async=_is_async,
client=kwargs.get("client"),
fake_stream=responses_api_provider_config.should_fake_stream(
model=model, stream=stream, custom_llm_provider=custom_llm_provider
),
)
return response


@@ -11,6 +11,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.types.llms.openai import (
ResponseCompletedEvent,
ResponsesAPIStreamEvents,
ResponsesAPIStreamingResponse,
)
@@ -207,3 +208,63 @@ class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
start_time=self.start_time,
end_time=datetime.now(),
)
class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
"""
mock iterator - some models like o1-pro do not support streaming, we need to fake a stream
"""
def __init__(
self,
response: httpx.Response,
model: str,
responses_api_provider_config: BaseResponsesAPIConfig,
logging_obj: LiteLLMLoggingObj,
):
self.raw_http_response = response
super().__init__(
response=response,
model=model,
responses_api_provider_config=responses_api_provider_config,
logging_obj=logging_obj,
)
self.is_done = False
def __aiter__(self):
return self
async def __anext__(self) -> ResponsesAPIStreamingResponse:
if self.is_done:
raise StopAsyncIteration
self.is_done = True
transformed_response = (
self.responses_api_provider_config.transform_response_api_response(
model=self.model,
raw_response=self.raw_http_response,
logging_obj=self.logging_obj,
)
)
return ResponseCompletedEvent(
type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
response=transformed_response,
)
def __iter__(self):
return self
def __next__(self) -> ResponsesAPIStreamingResponse:
if self.is_done:
raise StopIteration
self.is_done = True
transformed_response = (
self.responses_api_provider_config.transform_response_api_response(
model=self.model,
raw_response=self.raw_http_response,
logging_obj=self.logging_obj,
)
)
return ResponseCompletedEvent(
type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
response=transformed_response,
)
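Whichever way it is consumed, the mock iterator yields exactly one `ResponseCompletedEvent` carrying the fully transformed response, then stops. From the caller's side (sketch, assuming an OpenAI key is configured):

```python
import litellm

response = litellm.responses(
    model="openai/o1-pro",
    input="Basic ping",
    stream=True,  # faked for o1-pro, per should_fake_stream
)

events = list(response)    # the mock iterator stops after one event
assert len(events) == 1    # a single ResponseCompletedEvent
print(events[0].response)  # the full, non-streamed ResponsesAPIResponse
```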


@@ -1975,6 +1975,39 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
)
def supports_native_streaming(model: str, custom_llm_provider: Optional[str]) -> bool:
"""
Check if the given model supports native streaming and return a boolean value.
Parameters:
model (str): The model name to be checked.
custom_llm_provider (str): The provider to be checked.
Returns:
bool: True if the model supports native streaming, False otherwise.
Note:
If the model cannot be resolved in model_prices_and_context_window.json, the error is logged at debug level and False is returned (no exception is raised).
"""
try:
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, custom_llm_provider=custom_llm_provider
)
model_info = _get_model_info_helper(
model=model, custom_llm_provider=custom_llm_provider
)
supports_native_streaming = model_info.get("supports_native_streaming", True)
if supports_native_streaming is None:
supports_native_streaming = True
return supports_native_streaming
except Exception as e:
verbose_logger.debug(
f"Model not found or error in checking supports_native_streaming support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
)
return False
def supports_response_schema(
model: str, custom_llm_provider: Optional[str] = None
) -> bool:
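This helper is what the Responses API path consults: a missing flag defaults to True (assume native streaming), and an unresolvable model returns False rather than raising. For example (the o1-pro result assumes its model map entry carries `supports_native_streaming: false`):

```python
import litellm

# gpt-4o streams natively; the flag is absent, so the default of True applies.
print(litellm.utils.supports_native_streaming(
    model="gpt-4o", custom_llm_provider="openai"
))

# Expected False for o1-pro once it is flagged in
# model_prices_and_context_window.json.
print(litellm.utils.supports_native_streaming(
    model="o1-pro", custom_llm_provider="openai"
))
```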


@@ -1,7 +1,7 @@
# LITELLM PROXY DEPENDENCIES #
anyio==4.4.0 # openai + http req.
httpx==0.27.0 # Pin Httpx dependency
openai==1.66.1 # openai req.
openai==1.67.0 # openai req.
fastapi==0.115.5 # server dep
backoff==2.2.1 # server dep
pyyaml==6.0.2 # server dep


@@ -0,0 +1,77 @@
import io
import os
import pathlib
import ssl
import sys
from unittest.mock import MagicMock
import pytest
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
def test_prepare_fake_stream_request():
# Initialize the BaseLLMHTTPHandler
handler = BaseLLMHTTPHandler()
# Test case 1: fake_stream is True
stream = True
data = {
"stream": True,
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
}
fake_stream = True
result_stream, result_data = handler._prepare_fake_stream_request(
stream=stream, data=data, fake_stream=fake_stream
)
# Verify that stream is set to False
assert result_stream is False
# Verify that "stream" key is removed from data
assert "stream" not in result_data
# Verify other data remains unchanged
assert result_data["model"] == "gpt-4"
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
# Test case 2: fake_stream is False
stream = True
data = {
"stream": True,
"model": "gpt-4",
"messages": [{"role": "user", "content": "Hello"}],
}
fake_stream = False
result_stream, result_data = handler._prepare_fake_stream_request(
stream=stream, data=data, fake_stream=fake_stream
)
# Verify that stream remains True
assert result_stream is True
# Verify that data remains unchanged
assert "stream" in result_data
assert result_data["stream"] is True
assert result_data["model"] == "gpt-4"
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
# Test case 3: data doesn't have stream key but fake_stream is True
stream = True
data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}
fake_stream = True
result_stream, result_data = handler._prepare_fake_stream_request(
stream=stream, data=data, fake_stream=fake_stream
)
# Verify that stream is set to False
assert result_stream is False
# Verify that data remains unchanged (since there was no stream key to remove)
assert "stream" not in result_data
assert result_data["model"] == "gpt-4"
assert result_data["messages"] == [{"role": "user", "content": "Hello"}]


@@ -94,7 +94,7 @@ def validate_responses_api_response(response, final_chunk: bool = False):
@pytest.mark.asyncio
async def test_basic_openai_responses_api(sync_mode):
litellm._turn_on_debug()
litellm.set_verbose = True
if sync_mode:
response = litellm.responses(
model="gpt-4o", input="Basic ping", max_output_tokens=20
@@ -826,3 +826,219 @@ async def test_async_bad_request_bad_param_error():
print(f"Exception details: {e.__dict__}")
except Exception as e:
pytest.fail(f"Unexpected exception raised: {e}")
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_openai_o1_pro_response_api(sync_mode):
"""
Test that LiteLLM correctly handles an incomplete response from OpenAI's o1-pro model
due to reaching max_output_tokens limit.
"""
# Mock response from o1-pro
mock_response = {
"id": "resp_67dc3dd77b388190822443a85252da5a0e13d8bdc0e28d88",
"object": "response",
"created_at": 1742486999,
"status": "incomplete",
"error": None,
"incomplete_details": {"reason": "max_output_tokens"},
"instructions": None,
"max_output_tokens": 20,
"model": "o1-pro-2025-03-19",
"output": [
{
"type": "reasoning",
"id": "rs_67dc3de50f64819097450ed50a33d5f90e13d8bdc0e28d88",
"summary": [],
}
],
"parallel_tool_calls": True,
"previous_response_id": None,
"reasoning": {"effort": "medium", "generate_summary": None},
"store": True,
"temperature": 1.0,
"text": {"format": {"type": "text"}},
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"truncation": "disabled",
"usage": {
"input_tokens": 73,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens": 20,
"output_tokens_details": {"reasoning_tokens": 0},
"total_tokens": 93,
},
"user": None,
"metadata": {},
}
class MockResponse:
def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = json.dumps(json_data)
def json(self): # Changed from async to sync
return self._json_data
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)
litellm._turn_on_debug()
litellm.set_verbose = True
# Call o1-pro with max_output_tokens=20
response = await litellm.aresponses(
model="openai/o1-pro",
input="Write a detailed essay about artificial intelligence and its impact on society",
max_output_tokens=20,
)
# Verify the request was made correctly
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
assert request_body["model"] == "o1-pro"
assert request_body["max_output_tokens"] == 20
# Validate the response
print("Response:", json.dumps(response, indent=4, default=str))
# Check that the response has the expected structure
assert response["id"] == mock_response["id"]
assert response["status"] == "incomplete"
assert response["incomplete_details"].reason == "max_output_tokens"
assert response["max_output_tokens"] == 20
# Validate usage information
assert response["usage"]["input_tokens"] == 73
assert response["usage"]["output_tokens"] == 20
assert response["usage"]["total_tokens"] == 93
# Validate that the response is properly identified as incomplete
validate_responses_api_response(response, final_chunk=True)
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_openai_o1_pro_response_api_streaming(sync_mode):
"""
Test that LiteLLM correctly handles an incomplete response from OpenAI's o1-pro model
due to reaching max_output_tokens limit in both sync and async streaming modes.
"""
# Mock response from o1-pro
mock_response = {
"id": "resp_67dc3dd77b388190822443a85252da5a0e13d8bdc0e28d88",
"object": "response",
"created_at": 1742486999,
"status": "incomplete",
"error": None,
"incomplete_details": {"reason": "max_output_tokens"},
"instructions": None,
"max_output_tokens": 20,
"model": "o1-pro-2025-03-19",
"output": [
{
"type": "reasoning",
"id": "rs_67dc3de50f64819097450ed50a33d5f90e13d8bdc0e28d88",
"summary": [],
}
],
"parallel_tool_calls": True,
"previous_response_id": None,
"reasoning": {"effort": "medium", "generate_summary": None},
"store": True,
"temperature": 1.0,
"text": {"format": {"type": "text"}},
"tool_choice": "auto",
"tools": [],
"top_p": 1.0,
"truncation": "disabled",
"usage": {
"input_tokens": 73,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens": 20,
"output_tokens_details": {"reasoning_tokens": 0},
"total_tokens": 93,
},
"user": None,
"metadata": {},
}
class MockResponse:
def __init__(self, json_data, status_code):
self._json_data = json_data
self.status_code = status_code
self.text = json.dumps(json_data)
def json(self):
return self._json_data
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
new_callable=AsyncMock,
) as mock_post:
# Configure the mock to return our response
mock_post.return_value = MockResponse(mock_response, 200)
litellm._turn_on_debug()
litellm.set_verbose = True
# Verify the request was made correctly
if sync_mode:
# For sync mode, we need to patch the sync HTTP handler
with patch(
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
return_value=MockResponse(mock_response, 200),
) as mock_sync_post:
response = litellm.responses(
model="openai/o1-pro",
input="Write a detailed essay about artificial intelligence and its impact on society",
max_output_tokens=20,
stream=True,
)
# Process the sync stream
event_count = 0
for event in response:
print(
f"Sync litellm response #{event_count}:",
json.dumps(event, indent=4, default=str),
)
event_count += 1
# Verify the sync request was made correctly
mock_sync_post.assert_called_once()
request_body = json.loads(mock_sync_post.call_args.kwargs["data"])
assert request_body["model"] == "o1-pro"
assert request_body["max_output_tokens"] == 20
assert "stream" not in request_body
else:
# For async mode
response = await litellm.aresponses(
model="openai/o1-pro",
input="Write a detailed essay about artificial intelligence and its impact on society",
max_output_tokens=20,
stream=True,
)
# Process the async stream
event_count = 0
async for event in response:
print(
f"Async litellm response #{event_count}:",
json.dumps(event, indent=4, default=str),
)
event_count += 1
# Verify the async request was made correctly
mock_post.assert_called_once()
request_body = json.loads(mock_post.call_args.kwargs["data"])
assert request_body["model"] == "o1-pro"
assert request_body["max_output_tokens"] == 20
assert "stream" not in request_body