Mirror of https://github.com/BerriAI/litellm.git

Merge pull request #9419 from BerriAI/litellm_streaming_o1_pro

[Feat] OpenAI o1-pro Responses API streaming support

Commit c44fe8bd90
11 changed files with 491 additions and 20 deletions
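For context, a minimal usage sketch of what this PR enables (the model, prompt, and token limit mirror the streaming test added below; error handling is omitted):

import litellm

# o1-pro does not support native streaming on the Responses API, so LiteLLM
# "fakes" the stream: it sends a single non-streaming request and yields one
# response.completed event.
response = litellm.responses(
    model="openai/o1-pro",
    input="Write a detailed essay about artificial intelligence and its impact on society",
    max_output_tokens=20,
    stream=True,
)
for event in response:
    print(event)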
@@ -49,7 +49,7 @@ jobs:
             pip install opentelemetry-api==1.25.0
             pip install opentelemetry-sdk==1.25.0
             pip install opentelemetry-exporter-otlp==1.25.0
-            pip install openai==1.66.1
+            pip install openai==1.67.0
             pip install prisma==0.11.0
             pip install "detect_secrets==1.5.0"
             pip install "httpx==0.24.1"
@@ -168,7 +168,7 @@ jobs:
             pip install opentelemetry-api==1.25.0
             pip install opentelemetry-sdk==1.25.0
             pip install opentelemetry-exporter-otlp==1.25.0
-            pip install openai==1.66.1
+            pip install openai==1.67.0
             pip install prisma==0.11.0
             pip install "detect_secrets==1.5.0"
             pip install "httpx==0.24.1"
@@ -268,7 +268,7 @@ jobs:
             pip install opentelemetry-api==1.25.0
             pip install opentelemetry-sdk==1.25.0
             pip install opentelemetry-exporter-otlp==1.25.0
-            pip install openai==1.66.1
+            pip install openai==1.67.0
             pip install prisma==0.11.0
             pip install "detect_secrets==1.5.0"
             pip install "httpx==0.24.1"
@@ -513,7 +513,7 @@ jobs:
             pip install opentelemetry-api==1.25.0
             pip install opentelemetry-sdk==1.25.0
             pip install opentelemetry-exporter-otlp==1.25.0
-            pip install openai==1.66.1
+            pip install openai==1.67.0
             pip install prisma==0.11.0
             pip install "detect_secrets==1.5.0"
             pip install "httpx==0.24.1"
@@ -1278,7 +1278,7 @@ jobs:
             pip install "aiodynamo==23.10.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
-            pip install "openai==1.66.1"
+            pip install "openai==1.67.0"
      - run:
          name: Install Grype
          command: |
@@ -1414,7 +1414,7 @@ jobs:
             pip install "aiodynamo==23.10.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
-            pip install "openai==1.66.1"
+            pip install "openai==1.67.0"
             # Run pytest and generate JUnit XML report
      - run:
          name: Build Docker image
@@ -1536,7 +1536,7 @@ jobs:
             pip install "aiodynamo==23.10.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
-            pip install "openai==1.66.1"
+            pip install "openai==1.67.0"
      - run:
          name: Build Docker image
          command: docker build -t my-app:latest -f ./docker/Dockerfile.database .
@@ -1965,7 +1965,7 @@ jobs:
             pip install "pytest-asyncio==0.21.1"
             pip install "google-cloud-aiplatform==1.43.0"
             pip install aiohttp
-            pip install "openai==1.66.1"
+            pip install "openai==1.67.0"
             pip install "assemblyai==0.37.0"
             python -m pip install --upgrade pip
             pip install "pydantic==2.7.1"
@@ -2241,7 +2241,7 @@ jobs:
             pip install "pytest-retry==1.6.3"
             pip install "pytest-asyncio==0.21.1"
             pip install aiohttp
-            pip install "openai==1.66.1"
+            pip install "openai==1.67.0"
             python -m pip install --upgrade pip
             pip install "pydantic==2.7.1"
             pip install "pytest==7.3.1"
@@ -1,5 +1,5 @@
 # used by CI/CD testing
-openai==1.66.1
+openai==1.67.0
 python-dotenv
 tiktoken
 importlib_metadata
@@ -7,7 +7,6 @@ import httpx
 from litellm.types.llms.openai import (
     ResponseInputParam,
     ResponsesAPIOptionalRequestParams,
-    ResponsesAPIRequestParams,
     ResponsesAPIResponse,
     ResponsesAPIStreamingResponse,
 )
@@ -97,7 +96,7 @@ class BaseResponsesAPIConfig(ABC):
         response_api_optional_request_params: Dict,
         litellm_params: GenericLiteLLMParams,
         headers: dict,
-    ) -> ResponsesAPIRequestParams:
+    ) -> Dict:
         pass

     @abstractmethod
@@ -131,3 +130,12 @@ class BaseResponsesAPIConfig(ABC):
             message=error_message,
             headers=headers,
         )
+
+    def should_fake_stream(
+        self,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        """Returns True if litellm should fake a stream for the given model and stream value"""
+        return False
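The new `should_fake_stream` hook defaults to False; a provider config is expected to override it when a model cannot stream natively. A rough sketch under that assumption (the class and model check below are hypothetical; the real OpenAI override appears later in this diff):

from typing import Optional

from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig


class ExampleResponsesAPIConfig(BaseResponsesAPIConfig):  # hypothetical provider config
    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Fake the stream only when streaming was requested for a model that
        # cannot stream natively (hypothetical check).
        return stream is True and model is not None and "o1-pro" in model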
@@ -20,6 +20,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.responses.streaming_iterator import (
     BaseResponsesAPIStreamingIterator,
+    MockResponsesAPIStreamingIterator,
     ResponsesAPIStreamingIterator,
     SyncResponsesAPIStreamingIterator,
 )
@@ -978,6 +979,7 @@ class BaseLLMHTTPHandler:
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
         _is_async: bool = False,
+        fake_stream: bool = False,
     ) -> Union[
         ResponsesAPIResponse,
         BaseResponsesAPIStreamingIterator,
@@ -1003,6 +1005,7 @@ class BaseLLMHTTPHandler:
                 extra_body=extra_body,
                 timeout=timeout,
                 client=client if isinstance(client, AsyncHTTPHandler) else None,
+                fake_stream=fake_stream,
             )

         if client is None or not isinstance(client, HTTPHandler):
@@ -1051,14 +1054,27 @@ class BaseLLMHTTPHandler:
         try:
             if stream:
                 # For streaming, use stream=True in the request
+                if fake_stream is True:
+                    stream, data = self._prepare_fake_stream_request(
+                        stream=stream,
+                        data=data,
+                        fake_stream=fake_stream,
+                    )
                 response = sync_httpx_client.post(
                     url=api_base,
                     headers=headers,
                     data=json.dumps(data),
                     timeout=timeout
                     or response_api_optional_request_params.get("timeout"),
-                    stream=True,
+                    stream=stream,
                 )

+                if fake_stream is True:
+                    return MockResponsesAPIStreamingIterator(
+                        response=response,
+                        model=model,
+                        logging_obj=logging_obj,
+                        responses_api_provider_config=responses_api_provider_config,
+                    )
+
                 return SyncResponsesAPIStreamingIterator(
                     response=response,
@@ -1100,6 +1116,7 @@ class BaseLLMHTTPHandler:
         extra_body: Optional[Dict[str, Any]] = None,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        fake_stream: bool = False,
     ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
         """
         Async version of the responses API handler.
@@ -1145,22 +1162,36 @@ class BaseLLMHTTPHandler:
                 "headers": headers,
             },
         )

         # Check if streaming is requested
         stream = response_api_optional_request_params.get("stream", False)

         try:
             if stream:
                 # For streaming, we need to use stream=True in the request
+                if fake_stream is True:
+                    stream, data = self._prepare_fake_stream_request(
+                        stream=stream,
+                        data=data,
+                        fake_stream=fake_stream,
+                    )
+
                 response = await async_httpx_client.post(
                     url=api_base,
                     headers=headers,
                     data=json.dumps(data),
                     timeout=timeout
                     or response_api_optional_request_params.get("timeout"),
-                    stream=True,
+                    stream=stream,
                 )

+                if fake_stream is True:
+                    return MockResponsesAPIStreamingIterator(
+                        response=response,
+                        model=model,
+                        logging_obj=logging_obj,
+                        responses_api_provider_config=responses_api_provider_config,
+                    )
+
                 # Return the streaming iterator
                 return ResponsesAPIStreamingIterator(
                     response=response,
@@ -1177,6 +1208,7 @@ class BaseLLMHTTPHandler:
                     timeout=timeout
                     or response_api_optional_request_params.get("timeout"),
                 )
+
         except Exception as e:
             raise self._handle_error(
                 e=e,
@@ -1189,6 +1221,21 @@ class BaseLLMHTTPHandler:
                 logging_obj=logging_obj,
             )

+    def _prepare_fake_stream_request(
+        self,
+        stream: bool,
+        data: dict,
+        fake_stream: bool,
+    ) -> Tuple[bool, dict]:
+        """
+        Handles preparing a request when `fake_stream` is True.
+        """
+        if fake_stream is True:
+            stream = False
+            data.pop("stream", None)
+            return stream, data
+        return stream, data
+
     def _handle_error(
         self,
         e: Exception,
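Illustrative behavior of the new `_prepare_fake_stream_request` helper, matching the unit test added later in this PR (the request dict below is illustrative):

from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler

handler = BaseLLMHTTPHandler()
stream, data = handler._prepare_fake_stream_request(
    stream=True,
    data={"stream": True, "model": "o1-pro", "input": "Basic ping"},
    fake_stream=True,
)
# stream is now False and the "stream" key has been dropped, so the provider
# receives an ordinary non-streaming request.
print(stream, data)  # False {'model': 'o1-pro', 'input': 'Basic ping'}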
@@ -65,10 +65,12 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
         response_api_optional_request_params: Dict,
         litellm_params: GenericLiteLLMParams,
         headers: dict,
-    ) -> ResponsesAPIRequestParams:
+    ) -> Dict:
         """No transform applied since inputs are in OpenAI spec already"""
-        return ResponsesAPIRequestParams(
-            model=model, input=input, **response_api_optional_request_params
+        return dict(
+            ResponsesAPIRequestParams(
+                model=model, input=input, **response_api_optional_request_params
+            )
         )

     def transform_response_api_response(
@@ -188,3 +190,27 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
             raise ValueError(f"Unknown event type: {event_type}")

         return model_class
+
+    def should_fake_stream(
+        self,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        if stream is not True:
+            return False
+        if model is not None:
+            try:
+                if (
+                    litellm.utils.supports_native_streaming(
+                        model=model,
+                        custom_llm_provider=custom_llm_provider,
+                    )
+                    is False
+                ):
+                    return True
+            except Exception as e:
+                verbose_logger.debug(
+                    f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
+                )
+        return False
@@ -232,6 +232,9 @@ def responses(
         timeout=timeout or request_timeout,
         _is_async=_is_async,
         client=kwargs.get("client"),
+        fake_stream=responses_api_provider_config.should_fake_stream(
+            model=model, stream=stream, custom_llm_provider=custom_llm_provider
+        ),
     )

     return response
@@ -11,6 +11,7 @@ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
 from litellm.litellm_core_utils.thread_pool_executor import executor
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
 from litellm.types.llms.openai import (
+    ResponseCompletedEvent,
     ResponsesAPIStreamEvents,
     ResponsesAPIStreamingResponse,
 )
@@ -207,3 +208,63 @@ class SyncResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
                 start_time=self.start_time,
                 end_time=datetime.now(),
             )
+
+
+class MockResponsesAPIStreamingIterator(BaseResponsesAPIStreamingIterator):
+    """
+    mock iterator - some models like o1-pro do not support streaming, we need to fake a stream
+    """
+
+    def __init__(
+        self,
+        response: httpx.Response,
+        model: str,
+        responses_api_provider_config: BaseResponsesAPIConfig,
+        logging_obj: LiteLLMLoggingObj,
+    ):
+        self.raw_http_response = response
+        super().__init__(
+            response=response,
+            model=model,
+            responses_api_provider_config=responses_api_provider_config,
+            logging_obj=logging_obj,
+        )
+        self.is_done = False
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> ResponsesAPIStreamingResponse:
+        if self.is_done:
+            raise StopAsyncIteration
+        self.is_done = True
+        transformed_response = (
+            self.responses_api_provider_config.transform_response_api_response(
+                model=self.model,
+                raw_response=self.raw_http_response,
+                logging_obj=self.logging_obj,
+            )
+        )
+        return ResponseCompletedEvent(
+            type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
+            response=transformed_response,
+        )
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> ResponsesAPIStreamingResponse:
+        if self.is_done:
+            raise StopIteration
+        self.is_done = True
+        transformed_response = (
+            self.responses_api_provider_config.transform_response_api_response(
+                model=self.model,
+                raw_response=self.raw_http_response,
+                logging_obj=self.logging_obj,
+            )
+        )
+        return ResponseCompletedEvent(
+            type=ResponsesAPIStreamEvents.RESPONSE_COMPLETED,
+            response=transformed_response,
+        )
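A small consumption sketch for the mock iterator (the `stream_iterator` argument stands for whatever `litellm.aresponses(..., stream=True)` returns when streaming is faked; the assertions follow from the class above):

from litellm.types.llms.openai import ResponsesAPIStreamEvents


async def consume(stream_iterator):
    # The mock iterator yields exactly one response.completed event, built by
    # transform_response_api_response from the non-streaming HTTP response,
    # then raises StopAsyncIteration.
    events = [event async for event in stream_iterator]
    assert len(events) == 1
    assert events[0].type == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
    return events[0].response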
@@ -1975,6 +1975,39 @@ def supports_system_messages(model: str, custom_llm_provider: Optional[str]) ->
     )


+def supports_native_streaming(model: str, custom_llm_provider: Optional[str]) -> bool:
+    """
+    Check if the given model supports native streaming and return a boolean value.
+
+    Parameters:
+    model (str): The model name to be checked.
+    custom_llm_provider (str): The provider to be checked.
+
+    Returns:
+    bool: True if the model supports native streaming, False otherwise.
+
+    Raises:
+    Exception: If the given model is not found in model_prices_and_context_window.json.
+    """
+    try:
+        model, custom_llm_provider, _, _ = litellm.get_llm_provider(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+
+        model_info = _get_model_info_helper(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        supports_native_streaming = model_info.get("supports_native_streaming", True)
+        if supports_native_streaming is None:
+            supports_native_streaming = True
+        return supports_native_streaming
+    except Exception as e:
+        verbose_logger.debug(
+            f"Model not found or error in checking supports_native_streaming support. You passed model={model}, custom_llm_provider={custom_llm_provider}. Error: {str(e)}"
+        )
+        return False
+
+
 def supports_response_schema(
     model: str, custom_llm_provider: Optional[str] = None
 ) -> bool:
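A quick usage sketch for the new helper (whether a given registry entry sets `supports_native_streaming` to false is determined by model_prices_and_context_window.json; o1-pro is the case this PR targets):

import litellm

# Returns False when the model's registry entry marks native streaming as
# unsupported, or when the model cannot be resolved at all.
if not litellm.utils.supports_native_streaming(
    model="o1-pro", custom_llm_provider="openai"
):
    print("LiteLLM will fake the stream for this model")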
@@ -1,7 +1,7 @@
 # LITELLM PROXY DEPENDENCIES #
 anyio==4.4.0 # openai + http req.
 httpx==0.27.0 # Pin Httpx dependency
-openai==1.66.1 # openai req.
+openai==1.67.0 # openai req.
 fastapi==0.115.5 # server dep
 backoff==2.2.1 # server dep
 pyyaml==6.0.2 # server dep
tests/litellm/llms/custom_httpx/test_llm_http_handler.py (new file, 77 lines)

@@ -0,0 +1,77 @@
+import io
+import os
+import pathlib
+import ssl
+import sys
+from unittest.mock import MagicMock
+
+import pytest
+
+sys.path.insert(
+    0, os.path.abspath("../../../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
+
+
+def test_prepare_fake_stream_request():
+    # Initialize the BaseLLMHTTPHandler
+    handler = BaseLLMHTTPHandler()
+
+    # Test case 1: fake_stream is True
+    stream = True
+    data = {
+        "stream": True,
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+    }
+    fake_stream = True
+
+    result_stream, result_data = handler._prepare_fake_stream_request(
+        stream=stream, data=data, fake_stream=fake_stream
+    )
+
+    # Verify that stream is set to False
+    assert result_stream is False
+    # Verify that "stream" key is removed from data
+    assert "stream" not in result_data
+    # Verify other data remains unchanged
+    assert result_data["model"] == "gpt-4"
+    assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
+
+    # Test case 2: fake_stream is False
+    stream = True
+    data = {
+        "stream": True,
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+    }
+    fake_stream = False
+
+    result_stream, result_data = handler._prepare_fake_stream_request(
+        stream=stream, data=data, fake_stream=fake_stream
+    )
+
+    # Verify that stream remains True
+    assert result_stream is True
+    # Verify that data remains unchanged
+    assert "stream" in result_data
+    assert result_data["stream"] is True
+    assert result_data["model"] == "gpt-4"
+    assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
+
+    # Test case 3: data doesn't have stream key but fake_stream is True
+    stream = True
+    data = {"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}
+    fake_stream = True
+
+    result_stream, result_data = handler._prepare_fake_stream_request(
+        stream=stream, data=data, fake_stream=fake_stream
+    )
+
+    # Verify that stream is set to False
+    assert result_stream is False
+    # Verify that data remains unchanged (since there was no stream key to remove)
+    assert "stream" not in result_data
+    assert result_data["model"] == "gpt-4"
+    assert result_data["messages"] == [{"role": "user", "content": "Hello"}]
@@ -94,7 +94,7 @@ def validate_responses_api_response(response, final_chunk: bool = False):
 @pytest.mark.asyncio
 async def test_basic_openai_responses_api(sync_mode):
     litellm._turn_on_debug()
+    litellm.set_verbose = True
     if sync_mode:
         response = litellm.responses(
             model="gpt-4o", input="Basic ping", max_output_tokens=20
@@ -826,3 +826,219 @@ async def test_async_bad_request_bad_param_error():
         print(f"Exception details: {e.__dict__}")
     except Exception as e:
         pytest.fail(f"Unexpected exception raised: {e}")
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_openai_o1_pro_response_api(sync_mode):
+    """
+    Test that LiteLLM correctly handles an incomplete response from OpenAI's o1-pro model
+    due to reaching max_output_tokens limit.
+    """
+    # Mock response from o1-pro
+    mock_response = {
+        "id": "resp_67dc3dd77b388190822443a85252da5a0e13d8bdc0e28d88",
+        "object": "response",
+        "created_at": 1742486999,
+        "status": "incomplete",
+        "error": None,
+        "incomplete_details": {"reason": "max_output_tokens"},
+        "instructions": None,
+        "max_output_tokens": 20,
+        "model": "o1-pro-2025-03-19",
+        "output": [
+            {
+                "type": "reasoning",
+                "id": "rs_67dc3de50f64819097450ed50a33d5f90e13d8bdc0e28d88",
+                "summary": [],
+            }
+        ],
+        "parallel_tool_calls": True,
+        "previous_response_id": None,
+        "reasoning": {"effort": "medium", "generate_summary": None},
+        "store": True,
+        "temperature": 1.0,
+        "text": {"format": {"type": "text"}},
+        "tool_choice": "auto",
+        "tools": [],
+        "top_p": 1.0,
+        "truncation": "disabled",
+        "usage": {
+            "input_tokens": 73,
+            "input_tokens_details": {"cached_tokens": 0},
+            "output_tokens": 20,
+            "output_tokens_details": {"reasoning_tokens": 0},
+            "total_tokens": 93,
+        },
+        "user": None,
+        "metadata": {},
+    }
+
+    class MockResponse:
+        def __init__(self, json_data, status_code):
+            self._json_data = json_data
+            self.status_code = status_code
+            self.text = json.dumps(json_data)
+
+        def json(self):  # Changed from async to sync
+            return self._json_data
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        new_callable=AsyncMock,
+    ) as mock_post:
+        # Configure the mock to return our response
+        mock_post.return_value = MockResponse(mock_response, 200)
+
+        litellm._turn_on_debug()
+        litellm.set_verbose = True
+
+        # Call o1-pro with max_output_tokens=20
+        response = await litellm.aresponses(
+            model="openai/o1-pro",
+            input="Write a detailed essay about artificial intelligence and its impact on society",
+            max_output_tokens=20,
+        )
+
+        # Verify the request was made correctly
+        mock_post.assert_called_once()
+        request_body = json.loads(mock_post.call_args.kwargs["data"])
+        assert request_body["model"] == "o1-pro"
+        assert request_body["max_output_tokens"] == 20
+
+        # Validate the response
+        print("Response:", json.dumps(response, indent=4, default=str))
+
+        # Check that the response has the expected structure
+        assert response["id"] == mock_response["id"]
+        assert response["status"] == "incomplete"
+        assert response["incomplete_details"].reason == "max_output_tokens"
+        assert response["max_output_tokens"] == 20
+
+        # Validate usage information
+        assert response["usage"]["input_tokens"] == 73
+        assert response["usage"]["output_tokens"] == 20
+        assert response["usage"]["total_tokens"] == 93
+
+        # Validate that the response is properly identified as incomplete
+        validate_responses_api_response(response, final_chunk=True)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_openai_o1_pro_response_api_streaming(sync_mode):
+    """
+    Test that LiteLLM correctly handles an incomplete response from OpenAI's o1-pro model
+    due to reaching max_output_tokens limit in both sync and async streaming modes.
+    """
+    # Mock response from o1-pro
+    mock_response = {
+        "id": "resp_67dc3dd77b388190822443a85252da5a0e13d8bdc0e28d88",
+        "object": "response",
+        "created_at": 1742486999,
+        "status": "incomplete",
+        "error": None,
+        "incomplete_details": {"reason": "max_output_tokens"},
+        "instructions": None,
+        "max_output_tokens": 20,
+        "model": "o1-pro-2025-03-19",
+        "output": [
+            {
+                "type": "reasoning",
+                "id": "rs_67dc3de50f64819097450ed50a33d5f90e13d8bdc0e28d88",
+                "summary": [],
+            }
+        ],
+        "parallel_tool_calls": True,
+        "previous_response_id": None,
+        "reasoning": {"effort": "medium", "generate_summary": None},
+        "store": True,
+        "temperature": 1.0,
+        "text": {"format": {"type": "text"}},
+        "tool_choice": "auto",
+        "tools": [],
+        "top_p": 1.0,
+        "truncation": "disabled",
+        "usage": {
+            "input_tokens": 73,
+            "input_tokens_details": {"cached_tokens": 0},
+            "output_tokens": 20,
+            "output_tokens_details": {"reasoning_tokens": 0},
+            "total_tokens": 93,
+        },
+        "user": None,
+        "metadata": {},
+    }
+
+    class MockResponse:
+        def __init__(self, json_data, status_code):
+            self._json_data = json_data
+            self.status_code = status_code
+            self.text = json.dumps(json_data)
+
+        def json(self):
+            return self._json_data
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        new_callable=AsyncMock,
+    ) as mock_post:
+        # Configure the mock to return our response
+        mock_post.return_value = MockResponse(mock_response, 200)
+
+        litellm._turn_on_debug()
+        litellm.set_verbose = True
+
+        # Verify the request was made correctly
+        if sync_mode:
+            # For sync mode, we need to patch the sync HTTP handler
+            with patch(
+                "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+                return_value=MockResponse(mock_response, 200),
+            ) as mock_sync_post:
+                response = litellm.responses(
+                    model="openai/o1-pro",
+                    input="Write a detailed essay about artificial intelligence and its impact on society",
+                    max_output_tokens=20,
+                    stream=True,
+                )
+
+                # Process the sync stream
+                event_count = 0
+                for event in response:
+                    print(
+                        f"Sync litellm response #{event_count}:",
+                        json.dumps(event, indent=4, default=str),
+                    )
+                    event_count += 1
+
+                # Verify the sync request was made correctly
+                mock_sync_post.assert_called_once()
+                request_body = json.loads(mock_sync_post.call_args.kwargs["data"])
+                assert request_body["model"] == "o1-pro"
+                assert request_body["max_output_tokens"] == 20
+                assert "stream" not in request_body
+        else:
+            # For async mode
+            response = await litellm.aresponses(
+                model="openai/o1-pro",
+                input="Write a detailed essay about artificial intelligence and its impact on society",
+                max_output_tokens=20,
+                stream=True,
+            )
+
+            # Process the async stream
+            event_count = 0
+            async for event in response:
+                print(
+                    f"Async litellm response #{event_count}:",
+                    json.dumps(event, indent=4, default=str),
+                )
+                event_count += 1
+
+            # Verify the async request was made correctly
+            mock_post.assert_called_once()
+            request_body = json.loads(mock_post.call_args.kwargs["data"])
+            assert request_body["model"] == "o1-pro"
+            assert request_body["max_output_tokens"] == 20
+            assert "stream" not in request_body