Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

Merge pull request #9419 from BerriAI/litellm_streaming_o1_pro
[Feat] OpenAI o1-pro Responses API streaming support

Commit c44fe8bd90
11 changed files with 491 additions and 20 deletions
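
In short: callers can now pass stream=True for Responses API models that do not stream natively (such as o1-pro); LiteLLM sends one regular request and replays the full result as a single response.completed event. A minimal usage sketch, mirroring the tests added in this PR (assumes this branch is installed and OPENAI_API_KEY is set; the input text is illustrative):

    import litellm

    # o1-pro has no native streaming, so LiteLLM "fakes" the stream:
    # one request is sent, and the iterator yields a single completed-response event.
    response = litellm.responses(
        model="openai/o1-pro",
        input="Write a short note about streaming",
        max_output_tokens=20,
        stream=True,
    )
    for event in response:
        print(event)  # a single ResponseCompletedEvent carrying the full response
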
@@ -49,7 +49,7 @@ jobs:
 pip install opentelemetry-api==1.25.0
 pip install opentelemetry-sdk==1.25.0
 pip install opentelemetry-exporter-otlp==1.25.0
-pip install openai==1.66.1
+pip install openai==1.67.0
 pip install prisma==0.11.0
 pip install "detect_secrets==1.5.0"
 pip install "httpx==0.24.1"
@@ -168,7 +168,7 @@ jobs:
 pip install opentelemetry-api==1.25.0
 pip install opentelemetry-sdk==1.25.0
 pip install opentelemetry-exporter-otlp==1.25.0
-pip install openai==1.66.1
+pip install openai==1.67.0
 pip install prisma==0.11.0
 pip install "detect_secrets==1.5.0"
 pip install "httpx==0.24.1"
@@ -268,7 +268,7 @@ jobs:
 pip install opentelemetry-api==1.25.0
 pip install opentelemetry-sdk==1.25.0
 pip install opentelemetry-exporter-otlp==1.25.0
-pip install openai==1.66.1
+pip install openai==1.67.0
 pip install prisma==0.11.0
 pip install "detect_secrets==1.5.0"
 pip install "httpx==0.24.1"
@@ -513,7 +513,7 @@ jobs:
 pip install opentelemetry-api==1.25.0
 pip install opentelemetry-sdk==1.25.0
 pip install opentelemetry-exporter-otlp==1.25.0
-pip install openai==1.66.1
+pip install openai==1.67.0
 pip install prisma==0.11.0
 pip install "detect_secrets==1.5.0"
 pip install "httpx==0.24.1"
@@ -1278,7 +1278,7 @@ jobs:
 pip install "aiodynamo==23.10.1"
 pip install "asyncio==3.4.3"
 pip install "PyGithub==1.59.1"
-pip install "openai==1.66.1"
+pip install "openai==1.67.0"
 - run:
     name: Install Grype
     command: |
@@ -1414,7 +1414,7 @@ jobs:
 pip install "aiodynamo==23.10.1"
 pip install "asyncio==3.4.3"
 pip install "PyGithub==1.59.1"
-pip install "openai==1.66.1"
+pip install "openai==1.67.0"
 # Run pytest and generate JUnit XML report
 - run:
     name: Build Docker image
@@ -1536,7 +1536,7 @@ jobs:
 pip install "aiodynamo==23.10.1"
 pip install "asyncio==3.4.3"
 pip install "PyGithub==1.59.1"
-pip install "openai==1.66.1"
+pip install "openai==1.67.0"
 - run:
     name: Build Docker image
     command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@@ -1965,7 +1965,7 @@ jobs:
 pip install "pytest-asyncio==0.21.1"
 pip install "google-cloud-aiplatform==1.43.0"
 pip install aiohttp
-pip install "openai==1.66.1"
+pip install "openai==1.67.0"
 pip install "assemblyai==0.37.0"
 python -m pip install --upgrade pip
 pip install "pydantic==2.7.1"
@@ -2241,7 +2241,7 @@ jobs:
 pip install "pytest-retry==1.6.3"
 pip install "pytest-asyncio==0.21.1"
 pip install aiohttp
-pip install "openai==1.66.1"
+pip install "openai==1.67.0"
 python -m pip install --upgrade pip
 pip install "pydantic==2.7.1"
 pip install "pytest==7.3.1"

@@ -1,5 +1,5 @@
 # used by CI/CD testing
-openai==1.66.1
+openai==1.67.0
 python-dotenv
 tiktoken
 importlib_metadata

@@ -7,7 +7,6 @@ import httpx
 from litellm.types.llms.openai import (
     ResponseInputParam,
     ResponsesAPIOptionalRequestParams,
-    ResponsesAPIRequestParams,
     ResponsesAPIResponse,
     ResponsesAPIStreamingResponse,
 )
@@ -97,7 +96,7 @@ class BaseResponsesAPIConfig(ABC):
         response_api_optional_request_params: Dict,
         litellm_params: GenericLiteLLMParams,
         headers: dict,
-    ) -> ResponsesAPIRequestParams:
+    ) -> Dict:
        pass

    @abstractmethod
@@ -131,3 +130,12 @@ class BaseResponsesAPIConfig(ABC):
             message=error_message,
             headers=headers,
         )
+
+    def should_fake_stream(
+        self,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        """Returns True if litellm should fake a stream for the given model and stream value"""
+        return False

@@ -20,6 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.responses.streaming_iterator import (
     BaseResponsesAPIStreamingIterator,
+    MockResponsesAPIStreamingIterator,
     ResponsesAPIStreamingIterator,
     SyncResponsesAPIStreamingIterator,
 )
@@ -978,6 +979,7 @@ class BaseLLMHTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
         _is_async: bool = False,
+        fake_stream: bool = False,
     ) -> Union[
         ResponsesAPIResponse,
         BaseResponsesAPIStreamingIterator,
@@ -1003,6 +1005,7 @@ class BaseLLMHTTPHandler:
                 extra_body=extra_body,
                 timeout=timeout,
                 client=client if isinstance(client, AsyncHTTPHandler) else None,
+                fake_stream=fake_stream,
             )

         if client is None or not isinstance(client, HTTPHandler):
@@ -1051,14 +1054,27 @@ class BaseLLMHTTPHandler:
         try:
             if stream:
                 # For streaming, use stream=True in the request
+                if fake_stream is True:
+                    stream, data = self._prepare_fake_stream_request(
+                        stream=stream,
+                        data=data,
+                        fake_stream=fake_stream,
+                    )
                 response = sync_httpx_client.post(
                     url=api_base,
                     headers=headers,
                     data=json.dumps(data),
                     timeout=timeout
                     or response_api_optional_request_params.get("timeout"),
-                    stream=True,
+                    stream=stream,
                 )
+                if fake_stream is True:
+                    return MockResponsesAPIStreamingIterator(
+                        response=response,
+                        model=model,
+                        logging_obj=logging_obj,
+                        responses_api_provider_config=responses_api_provider_config,
+                    )

                 return SyncResponsesAPIStreamingIterator(
                     response=response,
@@ -1100,6 +1116,7 @@ class BaseLLMHTTPHandler:
         extra_body: Optional[Dict[str, Any]] = None,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        fake_stream: bool = False,
     ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
         """
         Async version of the responses API handler.
@@ -1145,22 +1162,36 @@ class BaseLLMHTTPHandler:
                 "headers": headers,
             },
         )

         # Check if streaming is requested
         stream = response_api_optional_request_params.get("stream", False)

         try:
             if stream:
                 # For streaming, we need to use stream=True in the request
+                if fake_stream is True:
+                    stream, data = self._prepare_fake_stream_request(
+                        stream=stream,
+                        data=data,
+                        fake_stream=fake_stream,
+                    )
+
                 response = await async_httpx_client.post(
                     url=api_base,
                     headers=headers,
                     data=json.dumps(data),
                     timeout=timeout
                     or response_api_optional_request_params.get("timeout"),
-                    stream=True,
+                    stream=stream,
                 )
+
+                if fake_stream is True:
+                    return MockResponsesAPIStreamingIterator(
+                        response=response,
+                        model=model,
+                        logging_obj=logging_obj,
+                        responses_api_provider_config=responses_api_provider_config,
+                    )
+
                 # Return the streaming iterator
                 return ResponsesAPIStreamingIterator(
                     response=response,
@@ -1177,6 +1208,7 @@ class BaseLLMHTTPHandler:
                     timeout=timeout
                     or response_api_optional_request_params.get("timeout"),
                 )

         except Exception as e:
             raise self._handle_error(
                 e=e,
@@ -1189,6 +1221,21 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
         )

+    def _prepare_fake_stream_request(
+        self,
+        stream: bool,
+        data: dict,
+        fake_stream: bool,
+    ) -> Tuple[bool, dict]:
+        """
+        Handles preparing a request when `fake_stream` is True.
+        """
+        if fake_stream is True:
+            stream = False
+            data.pop("stream", None)
+            return stream, data
+        return stream, data
+
     def _handle_error(
         self,
         e: Exception,

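For reference, the new _prepare_fake_stream_request helper simply downgrades the outgoing request to a non-streaming one. A quick sketch of its effect, along the lines of the unit test added later in this PR (the payload below is illustrative):

    from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler

    handler = BaseLLMHTTPHandler()
    stream, data = handler._prepare_fake_stream_request(
        stream=True,
        data={"model": "o1-pro", "input": "hi", "stream": True},
        fake_stream=True,
    )
    print(stream)  # False -- the upstream call is made without streaming
    print(data)    # {'model': 'o1-pro', 'input': 'hi'} -- the "stream" key is dropped
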
@@ -65,10 +65,12 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
         response_api_optional_request_params: Dict,
         litellm_params: GenericLiteLLMParams,
         headers: dict,
-    ) -> ResponsesAPIRequestParams:
+    ) -> Dict:
         """No transform applied since inputs are in OpenAI spec already"""
-        return ResponsesAPIRequestParams(
-            model=model, input=input, **response_api_optional_request_params
+        return dict(
+            ResponsesAPIRequestParams(
+                model=model, input=input, **response_api_optional_request_params
+            )
         )

     def transform_response_api_response(
@@ -188,3 +190,27 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
             raise ValueError(f"Unknown event type: {event_type}")

         return model_class
+
+    def should_fake_stream(
+        self,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        if stream is not True:
+            return False
+        if model is not None:
+            try:
+                if (
+                    litellm.utils.supports_native_streaming(
+                        model=model,
+                        custom_llm_provider=custom_llm_provider,
+                    )
+                    is False
+                ):
+                    return True
+            except Exception as e:
+                verbose_logger.debug(
+                    f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
+                )
+        return False

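The should_fake_stream override above is the switch for this feature: it returns True only when the caller requested stream=True and litellm's model map reports that the model cannot stream natively. Paraphrased as a standalone sketch (the helper name _would_fake_stream is hypothetical, for illustration only):

    from typing import Optional

    import litellm


    def _would_fake_stream(
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Fake the stream only when streaming was requested and the model map
        # says the model does not support native streaming.
        if stream is not True or model is None:
            return False
        return (
            litellm.utils.supports_native_streaming(
                model=model, custom_llm_provider=custom_llm_provider
            )
            is False
        )
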
@@ -232,6 +232,9 @@ def responses(
             timeout=timeout or request_timeout,
             _is_async=_is_async,
             client=kwargs.get("client"),
+            fake_stream=responses_api_provider_config.should_fake_stream(
+                model=model, stream=stream, custom_llm_provider=custom_llm_provider
+            ),
         )

         return response

@@ -11,6 +11,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
 from litellm.types.llms.openai import (
+    ResponseCompletedEvent,
     ResponsesAPIStreamEvents,
     ResponsesAPIStreamingResponse,
 )
@@ -207,3 +208,63 @@ class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
             start_time=self.start_time,
             end_time=datetime.now(),
         )
+
+
+class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
+    """
+    mock iterator - some models like o1-pro do not support streaming, we need to fake a stream
+    """
+
+    def __init__(
+        self,
+        response: httpx.Response,
+        model: str,
+        responses_api_provider_config: BaseResponsesAPIConfig,
+        logging_obj: LiteLLMLoggingObj,
+    ):
+        self.raw_http_response = response
+        super().__init__(
+            response=response,
+            model=model,
+            responses_api_provider_config=responses_api_provider_config,
+            logging_obj=logging_obj,
+        )
+        self.is_done = False
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> ResponsesAPIStreamingResponse:
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        transformed_response = (
+            self.responses_api_provider_config.transform_response_api_response(
+                model=self.model,
+                raw_response=self.raw_http_response,
+                logging_obj=self.logging_obj,
+            )
+        )
+        return ResponseCompletedEvent(
+            type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
+            response=transformed_response,
+        )
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> ResponsesAPIStreamingResponse:
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        transformed_response = (
+            self.responses_api_provider_config.transform_response_api_response(
+                model=self.model,
+                raw_response=self.raw_http_response,
+                logging_obj=self.logging_obj,
+            )
+        )
+        return ResponseCompletedEvent(
+            type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
+            response=transformed_response,
+        )

@@ -1975,6 +1975,39 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) -> bool:
     )


+def supports_native_streaming(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports native streaming and return a boolean value.
+
+    Parameters:
+        model (str): The model name to be checked.
+        custom_llm_provider (str): The provider to be checked.
+
+    Returns:
+        bool: True if the model supports native streaming, False otherwise.
+
+    Raises:
+        Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        model_info = _get_model_info_helper(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        supports_native_streaming = model_info.get("supports_native_streaming", True)
+        if supports_native_streaming is None:
+            supports_native_streaming = True
+        return supports_native_streaming
+    except Exception as e:
+        verbose_logger.debug(
+            f"Model not found or error in checking supports_native_streaming support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
+        )
+        return False
+
+
 def supports_response_schema(
     model: str, custom_llm_provider: Optional[str] = None
 ) -> bool:

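supports_native_streaming treats a missing or null supports_native_streaming entry in the model map as True, so only models explicitly flagged otherwise take the fake-stream path, and any lookup failure is logged and reported as False. A small usage sketch (the model names are examples; the result depends on the local model map):

    import litellm

    for model in ("openai/gpt-4o", "openai/o1-pro"):
        print(
            model,
            litellm.utils.supports_native_streaming(
                model=model, custom_llm_provider=None
            ),
        )
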
@@ -1,7 +1,7 @@
 # LITELLM PROXY DEPENDENCIES #
 anyio==4.4.0 # openai + http req.
 httpx==0.27.0 # Pin Httpx dependency
-openai==1.66.1 # openai req.
+openai==1.67.0 # openai req.
 fastapi==0.115.5 # server dep
 backoff==2.2.1 # server dep
 pyyaml==6.0.2 # server dep

tests/litellm/llms/custom_httpx/test_llm_http_handler.py (new file, 77 lines)
@@ -0,0 +1,77 @@
+import io
+import os
+import pathlib
+import ssl
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
+
+
+def test_prepare_fake_stream_request():
+    # Initialize the BaseLLMHTTPHandler
+    handler = BaseLLMHTTPHandler()
+
+    # Test case 1: fake_stream is True
+    stream = True
+    data = {
+        "stream": True,
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+    }
+    fake_stream = True
+
+    result_stream, result_data = handler._prepare_fake_stream_request(
+        stream=stream, data=data, fake_stream=fake_stream
+    )
+
+    # Verify that stream is set to False
+    assert result_stream is False
+    # Verify that "stream" key is removed from data
+    assert "stream" not in result_data
+    # Verify other data remains unchanged
+    assert result_data["model"] == "gpt-4"
+    assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
+
+    # Test case 2: fake_stream is False
+    stream = True
+    data = {
+        "stream": True,
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+    }
+    fake_stream = False
+
+    result_stream, result_data = handler._prepare_fake_stream_request(
+        stream=stream, data=data, fake_stream=fake_stream
+    )
+
+    # Verify that stream remains True
+    assert result_stream is True
+    # Verify that data remains unchanged
+    assert "stream" in result_data
+    assert result_data["stream"] is True
+    assert result_data["model"] == "gpt-4"
+    assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
+
+    # Test case 3: data doesn't have stream key but fake_stream is True
+    stream = True
+    data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}
+    fake_stream = True
+
+    result_stream, result_data = handler._prepare_fake_stream_request(
+        stream=stream, data=data, fake_stream=fake_stream
+    )
+
+    # Verify that stream is set to False
+    assert result_stream is False
+    # Verify that data remains unchanged (since there was no stream key to remove)
+    assert "stream" not in result_data
+    assert result_data["model"] == "gpt-4"
+    assert result_data["messages"] == [{"role": "user", "content": "Hello"}]

@@ -94,7 +94,7 @@ def validate_responses_api_response(response, final_chunk: bool = False):
 @pytest.mark.asyncio
 async def test_basic_openai_responses_api(sync_mode):
     litellm._turn_on_debug()

     litellm.set_verbose = True
     if sync_mode:
         response = litellm.responses(
             model="gpt-4o", input="Basic ping", max_output_tokens=20
@@ -826,3 +826,219 @@ async def test_async_bad_request_bad_param_error():
         print(f"Exception details: {e.__dict__}")
     except Exception as e:
         pytest.fail(f"Unexpected exception raised: {e}")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_openai_o1_pro_response_api(sync_mode):
+    """
+    Test that LiteLLM correctly handles an incomplete response from OpenAI's o1-pro model
+    due to reaching max_output_tokens limit.
+    """
+    # Mock response from o1-pro
+    mock_response = {
+        "id": "resp_67dc3dd77b388190822443a85252da5a0e13d8bdc0e28d88",
+        "object": "response",
+        "created_at": 1742486999,
+        "status": "incomplete",
+        "error": None,
+        "incomplete_details": {"reason": "max_output_tokens"},
+        "instructions": None,
+        "max_output_tokens": 20,
+        "model": "o1-pro-2025-03-19",
+        "output": [
+            {
+                "type": "reasoning",
+                "id": "rs_67dc3de50f64819097450ed50a33d5f90e13d8bdc0e28d88",
+                "summary": [],
+            }
+        ],
+        "parallel_tool_calls": True,
+        "previous_response_id": None,
+        "reasoning": {"effort": "medium", "generate_summary": None},
+        "store": True,
+        "temperature": 1.0,
+        "text": {"format": {"type": "text"}},
+        "tool_choice": "auto",
+        "tools": [],
+        "top_p": 1.0,
+        "truncation": "disabled",
+        "usage": {
+            "input_tokens": 73,
+            "input_tokens_details": {"cached_tokens": 0},
+            "output_tokens": 20,
+            "output_tokens_details": {"reasoning_tokens": 0},
+            "total_tokens": 93,
+        },
+        "user": None,
+        "metadata": {},
+    }
+
+    class MockResponse:
+        def __init__(self, json_data, status_code):
+            self._json_data = json_data
+            self.status_code = status_code
+            self.text = json.dumps(json_data)
+
+        def json(self):  # Changed from async to sync
+            return self._json_data
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        new_callable=AsyncMock,
+    ) as mock_post:
+        # Configure the mock to return our response
+        mock_post.return_value = MockResponse(mock_response, 200)
+
+        litellm._turn_on_debug()
+        litellm.set_verbose = True
+
+        # Call o1-pro with max_output_tokens=20
+        response = await litellm.aresponses(
+            model="openai/o1-pro",
+            input="Write a detailed essay about artificial intelligence and its impact on society",
+            max_output_tokens=20,
+        )
+
+        # Verify the request was made correctly
+        mock_post.assert_called_once()
+        request_body = json.loads(mock_post.call_args.kwargs["data"])
+        assert request_body["model"] == "o1-pro"
+        assert request_body["max_output_tokens"] == 20
+
+        # Validate the response
+        print("Response:", json.dumps(response, indent=4, default=str))
+
+        # Check that the response has the expected structure
+        assert response["id"] == mock_response["id"]
+        assert response["status"] == "incomplete"
+        assert response["incomplete_details"].reason == "max_output_tokens"
+        assert response["max_output_tokens"] == 20
+
+        # Validate usage information
+        assert response["usage"]["input_tokens"] == 73
+        assert response["usage"]["output_tokens"] == 20
+        assert response["usage"]["total_tokens"] == 93
+
+        # Validate that the response is properly identified as incomplete
+        validate_responses_api_response(response, final_chunk=True)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_openai_o1_pro_response_api_streaming(sync_mode):
+    """
+    Test that LiteLLM correctly handles an incomplete response from OpenAI's o1-pro model
+    due to reaching max_output_tokens limit in both sync and async streaming modes.
+    """
+    # Mock response from o1-pro
+    mock_response = {
+        "id": "resp_67dc3dd77b388190822443a85252da5a0e13d8bdc0e28d88",
+        "object": "response",
+        "created_at": 1742486999,
+        "status": "incomplete",
+        "error": None,
+        "incomplete_details": {"reason": "max_output_tokens"},
+        "instructions": None,
+        "max_output_tokens": 20,
+        "model": "o1-pro-2025-03-19",
+        "output": [
+            {
+                "type": "reasoning",
+                "id": "rs_67dc3de50f64819097450ed50a33d5f90e13d8bdc0e28d88",
+                "summary": [],
+            }
+        ],
+        "parallel_tool_calls": True,
+        "previous_response_id": None,
+        "reasoning": {"effort": "medium", "generate_summary": None},
+        "store": True,
+        "temperature": 1.0,
+        "text": {"format": {"type": "text"}},
+        "tool_choice": "auto",
+        "tools": [],
+        "top_p": 1.0,
+        "truncation": "disabled",
+        "usage": {
+            "input_tokens": 73,
+            "input_tokens_details": {"cached_tokens": 0},
+            "output_tokens": 20,
+            "output_tokens_details": {"reasoning_tokens": 0},
+            "total_tokens": 93,
+        },
+        "user": None,
+        "metadata": {},
+    }
+
+    class MockResponse:
+        def __init__(self, json_data, status_code):
+            self._json_data = json_data
+            self.status_code = status_code
+            self.text = json.dumps(json_data)
+
+        def json(self):
+            return self._json_data
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        new_callable=AsyncMock,
+    ) as mock_post:
+        # Configure the mock to return our response
+        mock_post.return_value = MockResponse(mock_response, 200)
+
+        litellm._turn_on_debug()
+        litellm.set_verbose = True
+
+        # Verify the request was made correctly
+        if sync_mode:
+            # For sync mode, we need to patch the sync HTTP handler
+            with patch(
+                "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+                return_value=MockResponse(mock_response, 200),
+            ) as mock_sync_post:
+                response = litellm.responses(
+                    model="openai/o1-pro",
+                    input="Write a detailed essay about artificial intelligence and its impact on society",
+                    max_output_tokens=20,
+                    stream=True,
+                )
+
+                # Process the sync stream
+                event_count = 0
+                for event in response:
+                    print(
+                        f"Sync litellm response #{event_count}:",
+                        json.dumps(event, indent=4, default=str),
+                    )
+                    event_count += 1
+
+                # Verify the sync request was made correctly
+                mock_sync_post.assert_called_once()
+                request_body = json.loads(mock_sync_post.call_args.kwargs["data"])
+                assert request_body["model"] == "o1-pro"
+                assert request_body["max_output_tokens"] == 20
+                assert "stream" not in request_body
+        else:
+            # For async mode
+            response = await litellm.aresponses(
+                model="openai/o1-pro",
+                input="Write a detailed essay about artificial intelligence and its impact on society",
+                max_output_tokens=20,
+                stream=True,
+            )
+
+            # Process the async stream
+            event_count = 0
+            async for event in response:
+                print(
+                    f"Async litellm response #{event_count}:",
+                    json.dumps(event, indent=4, default=str),
+                )
+                event_count += 1
+
+            # Verify the async request was made correctly
+            mock_post.assert_called_once()
+            request_body = json.loads(mock_post.call_args.kwargs["data"])
+            assert request_body["model"] == "o1-pro"
+            assert request_body["max_output_tokens"] == 20
+            assert "stream" not in request_body