From d3e04eac7f961bedc99ee4007f4bfb053c1944e2 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 17 Apr 2025 16:47:59 -0700 Subject: [PATCH] [Feat] Unified Responses API - Add Azure Responses API support (#10116) * initial commit for azure responses api support * update get complete url * fixes for responses API * working azure responses API * working responses API * test suite for responses API * azure responses API test suite * fix test with complete url * fix test refactor * test fix metadata checks * fix code quality check --- litellm/__init__.py | 99 ++++++----- .../llms/azure/responses/transformation.py | 94 +++++++++++ .../llms/base_llm/responses/transformation.py | 3 + litellm/llms/custom_httpx/llm_http_handler.py | 23 ++- .../llms/openai/responses/transformation.py | 3 + litellm/responses/utils.py | 8 +- litellm/utils.py | 40 ++--- .../test_openai_responses_transformation.py | 40 ++++- .../base_responses_api.py | 158 ++++++++++++++++++ .../test_azure_responses_api.py | 31 ++++ .../test_openai_responses_api.py | 120 +------------ 11 files changed, 428 insertions(+), 191 deletions(-) create mode 100644 litellm/llms/azure/responses/transformation.py create mode 100644 tests/llm_responses_api_testing/base_responses_api.py create mode 100644 tests/llm_responses_api_testing/test_azure_responses_api.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 7327c0e0ac..e9dadbfaf6 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -128,19 +128,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False require_auth_for_metrics_endpoint: Optional[bool] = False argilla_batch_size: Optional[int] = None datadog_use_v1: Optional[bool] = False # if you want to use v1 datadog logged payload -gcs_pub_sub_use_v1: Optional[ - bool -] = False # if you want to use v1 gcs pubsub logged payload +gcs_pub_sub_use_v1: Optional[bool] = ( + False # if you want to use v1 gcs pubsub logged payload +) argilla_transformation_object: Optional[Dict[str, Any]] = None -_async_input_callback: List[ - Union[str, Callable, CustomLogger] -] = [] # internal variable - async custom callbacks are routed here. -_async_success_callback: List[ - Union[str, Callable, CustomLogger] -] = [] # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[ - Union[str, Callable, CustomLogger] -] = [] # internal variable - async custom callbacks are routed here. +_async_input_callback: List[Union[str, Callable, CustomLogger]] = ( + [] +) # internal variable - async custom callbacks are routed here. +_async_success_callback: List[Union[str, Callable, CustomLogger]] = ( + [] +) # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[Union[str, Callable, CustomLogger]] = ( + [] +) # internal variable - async custom callbacks are routed here. 
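Note: most of the litellm/__init__.py churn in this and the following hunks is mechanical formatter rewrapping, not a behavior change: annotated module-level assignments are rewrapped so the parentheses sit around the right-hand side and the trailing comment stays attached to the closing line, instead of splitting the type subscript across lines. A minimal sketch of the two equivalent spellings (identifiers shortened here for illustration; CustomLogger omitted so the snippet stands alone):

```python
from typing import Callable, List, Union

# Old wrapping (the removed lines): the annotation's subscript is split.
_async_success_callback_old: List[
    Union[str, Callable]
] = []  # internal variable - async custom callbacks are routed here.

# New wrapping (the added lines): the value is parenthesized instead, keeping
# the annotation on one line and the comment attached to the same statement.
_async_success_callback_new: List[Union[str, Callable]] = (
    []
)  # internal variable - async custom callbacks are routed here.

assert _async_success_callback_old == _async_success_callback_new == []
```

Either form binds the same value; only the line wrapping differs.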
pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False @@ -148,18 +148,18 @@ log_raw_request_response: bool = False redact_messages_in_exceptions: Optional[bool] = False redact_user_api_key_info: Optional[bool] = False filter_invalid_headers: Optional[bool] = False -add_user_information_to_llm_headers: Optional[ - bool -] = None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers +add_user_information_to_llm_headers: Optional[bool] = ( + None # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers +) store_audit_logs = False # Enterprise feature, allow users to see audit logs ### end of callbacks ############# -email: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -token: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +email: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +token: Optional[str] = ( + None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) telemetry = True max_tokens: int = DEFAULT_MAX_TOKENS # OpenAI Defaults drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) @@ -235,20 +235,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None enable_caching_on_provider_specific_optional_params: bool = ( False # feature-flag for caching on optional params - e.g. 'top_k' ) -caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -cache: Optional[ - Cache -] = None # cache object <- use this - https://docs.litellm.ai/docs/caching +caching: bool = ( + False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +caching_with_models: bool = ( + False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +) +cache: Optional[Cache] = ( + None # cache object <- use this - https://docs.litellm.ai/docs/caching +) default_in_memory_ttl: Optional[float] = None default_redis_ttl: Optional[float] = None default_redis_batch_cache_expiry: Optional[float] = None model_alias_map: Dict[str, str] = {} model_group_alias_map: Dict[str, str] = {} max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[ - str -] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +budget_duration: Optional[str] = ( + None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). 
+) default_soft_budget: float = ( DEFAULT_SOFT_BUDGET # by default all litellm proxy keys have a soft budget of 50.0 ) @@ -257,11 +261,15 @@ forward_traceparent_to_llm_provider: bool = False _current_cost = 0.0 # private variable, used if max budget is set error_logs: Dict = {} -add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt +add_function_to_prompt: bool = ( + False # if function calling not supported by api, append function call details to system prompt +) client_session: Optional[httpx.Client] = None aclient_session: Optional[httpx.AsyncClient] = None model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' -model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +model_cost_map_url: str = ( + "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +) suppress_debug_info = False dynamodb_table_name: Optional[str] = None s3_callback_params: Optional[Dict] = None @@ -284,7 +292,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None custom_prometheus_metadata_labels: List[str] = [] #### REQUEST PRIORITIZATION #### priority_reservation: Optional[Dict[str, float]] = None -force_ipv4: bool = False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. +force_ipv4: bool = ( + False # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6. +) module_level_aclient = AsyncHTTPHandler( timeout=request_timeout, client_alias="module level aclient" ) @@ -298,13 +308,13 @@ fallbacks: Optional[List] = None context_window_fallbacks: Optional[List] = None content_policy_fallbacks: Optional[List] = None allowed_fails: int = 3 -num_retries_per_request: Optional[ - int -] = None # for the request overall (incl. fallbacks + model retries) +num_retries_per_request: Optional[int] = ( + None # for the request overall (incl. fallbacks + model retries) +) ####### SECRET MANAGERS ##################### -secret_manager_client: Optional[ - Any -] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +secret_manager_client: Optional[Any] = ( + None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +) _google_kms_resource_name: Optional[str] = None _key_management_system: Optional[KeyManagementSystem] = None _key_management_settings: KeyManagementSettings = KeyManagementSettings() @@ -939,6 +949,7 @@ from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig from .llms.azure_ai.chat.transformation import AzureAIStudioConfig from .llms.mistral.mistral_chat_transformation import MistralConfig from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig +from .llms.azure.responses.transformation import AzureOpenAIResponsesAPIConfig from .llms.openai.chat.o_series_transformation import ( OpenAIOSeriesConfig as OpenAIO1Config, # maintain backwards compatibility OpenAIOSeriesConfig, @@ -1055,10 +1066,10 @@ from .types.llms.custom_llm import CustomLLMItem from .types.utils import GenericStreamingChunk custom_provider_map: List[CustomLLMItem] = [] -_custom_providers: List[ - str -] = [] # internal helper util, used to track names of custom providers -disable_hf_tokenizer_download: Optional[ - bool -] = None # disable huggingface tokenizer download. 
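Note: the one substantive change in this file is the import added in the hunk above, `from .llms.azure.responses.transformation import AzureOpenAIResponsesAPIConfig`, which exposes the Azure config at the package root next to the existing `OpenAIResponsesAPIConfig`. A small sketch of what that export gives callers (the subclass relationship is defined in the new module later in this patch):

```python
import litellm

# Both responses-API configs are now importable from the package root.
azure_cfg = litellm.AzureOpenAIResponsesAPIConfig()
openai_cfg = litellm.OpenAIResponsesAPIConfig()

# The Azure config specializes the OpenAI one, overriding only
# validate_environment() and get_complete_url() (see the new file below).
assert isinstance(azure_cfg, litellm.OpenAIResponsesAPIConfig)
print(type(azure_cfg).__mro__[1].__name__)  # OpenAIResponsesAPIConfig
```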
Defaults to openai clk100 +_custom_providers: List[str] = ( + [] +) # internal helper util, used to track names of custom providers +disable_hf_tokenizer_download: Optional[bool] = ( + None # disable huggingface tokenizer download. Defaults to openai clk100 +) global_disable_no_log_param: bool = False diff --git a/litellm/llms/azure/responses/transformation.py b/litellm/llms/azure/responses/transformation.py new file mode 100644 index 0000000000..a85ba73bec --- /dev/null +++ b/litellm/llms/azure/responses/transformation.py @@ -0,0 +1,94 @@ +from typing import TYPE_CHECKING, Any, Optional, cast + +import httpx + +import litellm +from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import * +from litellm.utils import _add_path_to_api_base + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig): + def validate_environment( + self, + headers: dict, + model: str, + api_key: Optional[str] = None, + ) -> dict: + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret_str("AZURE_OPENAI_API_KEY") + or get_secret_str("AZURE_API_KEY") + ) + + headers.update( + { + "Authorization": f"Bearer {api_key}", + } + ) + return headers + + def get_complete_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + """ + Constructs a complete URL for the API request. + + Args: + - api_base: Base URL, e.g., + "https://litellm8397336933.openai.azure.com" + OR + "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview" + - model: Model name. + - optional_params: Additional query parameters, including "api_version". + - stream: If streaming is required (optional). + + Returns: + - A complete URL string, e.g., + "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview" + """ + api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") + if api_base is None: + raise ValueError( + f"api_base is required for Azure AI Studio. Please set the api_base parameter. 
Passed `api_base={api_base}`" + ) + original_url = httpx.URL(api_base) + + # Extract api_version or use default + api_version = cast(Optional[str], litellm_params.get("api_version")) + + # Create a new dictionary with existing params + query_params = dict(original_url.params) + + # Add api_version if needed + if "api-version" not in query_params and api_version: + query_params["api-version"] = api_version + + # Add the path to the base URL + if "/openai/responses" not in api_base: + new_url = _add_path_to_api_base( + api_base=api_base, ending_path="/openai/responses" + ) + else: + new_url = api_base + + # Use the new query_params dictionary + final_url = httpx.URL(new_url).copy_with(params=query_params) + + return str(final_url) diff --git a/litellm/llms/base_llm/responses/transformation.py b/litellm/llms/base_llm/responses/transformation.py index e98a579845..649b91226f 100644 --- a/litellm/llms/base_llm/responses/transformation.py +++ b/litellm/llms/base_llm/responses/transformation.py @@ -73,7 +73,10 @@ class BaseResponsesAPIConfig(ABC): def get_complete_url( self, api_base: Optional[str], + api_key: Optional[str], model: str, + optional_params: dict, + litellm_params: dict, stream: Optional[bool] = None, ) -> str: """ diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 1ab8a94adf..6dd47cc223 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -462,7 +462,7 @@ class BaseLLMHTTPHandler: ) if fake_stream is True: - model_response: (ModelResponse) = provider_config.transform_response( + model_response: ModelResponse = provider_config.transform_response( model=model, raw_response=response, model_response=litellm.ModelResponse(), @@ -595,7 +595,7 @@ class BaseLLMHTTPHandler: ) if fake_stream is True: - model_response: (ModelResponse) = provider_config.transform_response( + model_response: ModelResponse = provider_config.transform_response( model=model, raw_response=response, model_response=litellm.ModelResponse(), @@ -1055,9 +1055,16 @@ class BaseLLMHTTPHandler: if extra_headers: headers.update(extra_headers) + # Check if streaming is requested + stream = response_api_optional_request_params.get("stream", False) + api_base = responses_api_provider_config.get_complete_url( api_base=litellm_params.api_base, + api_key=litellm_params.api_key, model=model, + optional_params=response_api_optional_request_params, + litellm_params=dict(litellm_params), + stream=stream, ) data = responses_api_provider_config.transform_responses_api_request( @@ -1079,9 +1086,6 @@ class BaseLLMHTTPHandler: }, ) - # Check if streaming is requested - stream = response_api_optional_request_params.get("stream", False) - try: if stream: # For streaming, use stream=True in the request @@ -1170,9 +1174,16 @@ class BaseLLMHTTPHandler: if extra_headers: headers.update(extra_headers) + # Check if streaming is requested + stream = response_api_optional_request_params.get("stream", False) + api_base = responses_api_provider_config.get_complete_url( api_base=litellm_params.api_base, + api_key=litellm_params.api_key, model=model, + optional_params=response_api_optional_request_params, + litellm_params=dict(litellm_params), + stream=stream, ) data = responses_api_provider_config.transform_responses_api_request( @@ -1193,8 +1204,6 @@ class BaseLLMHTTPHandler: "headers": headers, }, ) - # Check if streaming is requested - stream = response_api_optional_request_params.get("stream", False) try: if stream: diff --git 
a/litellm/llms/openai/responses/transformation.py b/litellm/llms/openai/responses/transformation.py index e062c0c9fa..047572657c 100644 --- a/litellm/llms/openai/responses/transformation.py +++ b/litellm/llms/openai/responses/transformation.py @@ -110,7 +110,10 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig): def get_complete_url( self, api_base: Optional[str], + api_key: Optional[str], model: str, + optional_params: dict, + litellm_params: dict, stream: Optional[bool] = None, ) -> str: """ diff --git a/litellm/responses/utils.py b/litellm/responses/utils.py index 679b9e16c6..fbb3527d22 100644 --- a/litellm/responses/utils.py +++ b/litellm/responses/utils.py @@ -60,7 +60,7 @@ class ResponsesAPIRequestUtils: @staticmethod def get_requested_response_api_optional_param( - params: Dict[str, Any] + params: Dict[str, Any], ) -> ResponsesAPIOptionalRequestParams: """ Filter parameters to only include those defined in ResponsesAPIOptionalRequestParams. @@ -72,7 +72,9 @@ class ResponsesAPIRequestUtils: ResponsesAPIOptionalRequestParams instance with only the valid parameters """ valid_keys = get_type_hints(ResponsesAPIOptionalRequestParams).keys() - filtered_params = {k: v for k, v in params.items() if k in valid_keys} + filtered_params = { + k: v for k, v in params.items() if k in valid_keys and v is not None + } return cast(ResponsesAPIOptionalRequestParams, filtered_params) @@ -88,7 +90,7 @@ class ResponseAPILoggingUtils: @staticmethod def _transform_response_api_usage_to_chat_usage( - usage: Union[dict, ResponseAPIUsage] + usage: Union[dict, ResponseAPIUsage], ) -> Usage: """Tranforms the ResponseAPIUsage object to a Usage object""" response_api_usage: ResponseAPIUsage = ( diff --git a/litellm/utils.py b/litellm/utils.py index 70a48a9360..141eadf624 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -516,9 +516,9 @@ def function_setup( # noqa: PLR0915 function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None ## DYNAMIC CALLBACKS ## - dynamic_callbacks: Optional[ - List[Union[str, Callable, CustomLogger]] - ] = kwargs.pop("callbacks", None) + dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = ( + kwargs.pop("callbacks", None) + ) all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks) if len(all_callbacks) > 0: @@ -1202,9 +1202,9 @@ def client(original_function): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs[ - "retry_policy" - ] = reset_retry_policy() # prevent infinite loops + kwargs["retry_policy"] = ( + reset_retry_policy() + ) # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -3013,16 +3013,16 @@ def get_optional_params( # noqa: PLR0915 True # so that main.py adds the function call to the prompt ) if "tools" in non_default_params: - optional_params[ - "functions_unsupported_model" - ] = non_default_params.pop("tools") + optional_params["functions_unsupported_model"] = ( + non_default_params.pop("tools") + ) non_default_params.pop( "tool_choice", None ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params[ - "functions_unsupported_model" - ] = non_default_params.pop("functions") + optional_params["functions_unsupported_model"] = ( + non_default_params.pop("functions") + ) elif ( litellm.add_function_to_prompt ): # if user opts to add it to prompt instead @@ -3045,10 +3045,10 @@ def get_optional_params( # noqa: PLR0915 if "response_format" in non_default_params: if provider_config is not None: - 
non_default_params[ - "response_format" - ] = provider_config.get_json_schema_from_pydantic_object( - response_format=non_default_params["response_format"] + non_default_params["response_format"] = ( + provider_config.get_json_schema_from_pydantic_object( + response_format=non_default_params["response_format"] + ) ) else: non_default_params["response_format"] = type_to_response_format_param( @@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int: def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str: - _choices: Union[ - List[Union[Choices, StreamingChoices]], List[StreamingChoices] - ] = response_obj.choices + _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = ( + response_obj.choices + ) response_str = "" for choice in _choices: @@ -6602,6 +6602,8 @@ class ProviderConfigManager: ) -> Optional[BaseResponsesAPIConfig]: if litellm.LlmProviders.OPENAI == provider: return litellm.OpenAIResponsesAPIConfig() + elif litellm.LlmProviders.AZURE == provider: + return litellm.AzureOpenAIResponsesAPIConfig() return None @staticmethod diff --git a/tests/litellm/llms/openai/responses/test_openai_responses_transformation.py b/tests/litellm/llms/openai/responses/test_openai_responses_transformation.py index b4a6cd974e..202d0aea23 100644 --- a/tests/litellm/llms/openai/responses/test_openai_responses_transformation.py +++ b/tests/litellm/llms/openai/responses/test_openai_responses_transformation.py @@ -201,13 +201,25 @@ class TestOpenAIResponsesAPIConfig: # Test with provided API base api_base = "https://custom-openai.example.com/v1" - result = self.config.get_complete_url(api_base=api_base, model=self.model) + result = self.config.get_complete_url( + api_base=api_base, + model=self.model, + api_key="test_api_key", + optional_params={}, + litellm_params={}, + ) assert result == "https://custom-openai.example.com/v1/responses" # Test with litellm.api_base with patch("litellm.api_base", "https://litellm-api-base.example.com/v1"): - result = self.config.get_complete_url(api_base=None, model=self.model) + result = self.config.get_complete_url( + api_base=None, + model=self.model, + api_key="test_api_key", + optional_params={}, + litellm_params={}, + ) assert result == "https://litellm-api-base.example.com/v1/responses" @@ -217,7 +229,13 @@ class TestOpenAIResponsesAPIConfig: "litellm.llms.openai.responses.transformation.get_secret_str", return_value="https://env-api-base.example.com/v1", ): - result = self.config.get_complete_url(api_base=None, model=self.model) + result = self.config.get_complete_url( + api_base=None, + model=self.model, + api_key="test_api_key", + optional_params={}, + litellm_params={}, + ) assert result == "https://env-api-base.example.com/v1/responses" @@ -227,13 +245,25 @@ class TestOpenAIResponsesAPIConfig: "litellm.llms.openai.responses.transformation.get_secret_str", return_value=None, ): - result = self.config.get_complete_url(api_base=None, model=self.model) + result = self.config.get_complete_url( + api_base=None, + model=self.model, + api_key="test_api_key", + optional_params={}, + litellm_params={}, + ) assert result == "https://api.openai.com/v1/responses" # Test with trailing slash in API base api_base = "https://custom-openai.example.com/v1/" - result = self.config.get_complete_url(api_base=api_base, model=self.model) + result = self.config.get_complete_url( + api_base=api_base, + model=self.model, + api_key="test_api_key", + optional_params={}, + litellm_params={}, + ) assert result == 
"https://custom-openai.example.com/v1/responses" diff --git a/tests/llm_responses_api_testing/base_responses_api.py b/tests/llm_responses_api_testing/base_responses_api.py new file mode 100644 index 0000000000..356fe5e78e --- /dev/null +++ b/tests/llm_responses_api_testing/base_responses_api.py @@ -0,0 +1,158 @@ + +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch +import os +import uuid +import time +import base64 + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from abc import ABC, abstractmethod + +from litellm.integrations.custom_logger import CustomLogger +import json +from litellm.types.utils import StandardLoggingPayload +from litellm.types.llms.openai import ( + ResponseCompletedEvent, + ResponsesAPIResponse, + ResponseTextConfig, + ResponseAPIUsage, + IncompleteDetails, +) +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler + + +def validate_responses_api_response(response, final_chunk: bool = False): + """ + Validate that a response from litellm.responses() or litellm.aresponses() + conforms to the expected ResponsesAPIResponse structure. + + Args: + response: The response object to validate + + Raises: + AssertionError: If the response doesn't match the expected structure + """ + # Validate response structure + print("response=", json.dumps(response, indent=4, default=str)) + assert isinstance( + response, ResponsesAPIResponse + ), "Response should be an instance of ResponsesAPIResponse" + + # Required fields + assert "id" in response and isinstance( + response["id"], str + ), "Response should have a string 'id' field" + assert "created_at" in response and isinstance( + response["created_at"], (int, float) + ), "Response should have a numeric 'created_at' field" + assert "output" in response and isinstance( + response["output"], list + ), "Response should have a list 'output' field" + assert "parallel_tool_calls" in response and isinstance( + response["parallel_tool_calls"], bool + ), "Response should have a boolean 'parallel_tool_calls' field" + + # Optional fields with their expected types + optional_fields = { + "error": (dict, type(None)), # error can be dict or None + "incomplete_details": (IncompleteDetails, type(None)), + "instructions": (str, type(None)), + "metadata": dict, + "model": str, + "object": str, + "temperature": (int, float), + "tool_choice": (dict, str), + "tools": list, + "top_p": (int, float), + "max_output_tokens": (int, type(None)), + "previous_response_id": (str, type(None)), + "reasoning": dict, + "status": str, + "text": ResponseTextConfig, + "truncation": str, + "usage": ResponseAPIUsage, + "user": (str, type(None)), + } + if final_chunk is False: + optional_fields["usage"] = type(None) + + for field, expected_type in optional_fields.items(): + if field in response: + assert isinstance( + response[field], expected_type + ), f"Field '{field}' should be of type {expected_type}, but got {type(response[field])}" + + # Check if output has at least one item + if final_chunk is True: + assert ( + len(response["output"]) > 0 + ), "Response 'output' field should have at least one item" + + return True # Return True if validation passes + + + +class BaseResponsesAPITest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. 
+ """ + @abstractmethod + def get_base_completion_call_args(self) -> dict: + """Must return the base completion call args""" + pass + + + @pytest.mark.parametrize("sync_mode", [True, False]) + @pytest.mark.asyncio + async def test_basic_openai_responses_api(self, sync_mode): + litellm._turn_on_debug() + litellm.set_verbose = True + base_completion_call_args = self.get_base_completion_call_args() + if sync_mode: + response = litellm.responses( + input="Basic ping", max_output_tokens=20, + **base_completion_call_args + ) + else: + response = await litellm.aresponses( + input="Basic ping", max_output_tokens=20, + **base_completion_call_args + ) + + print("litellm response=", json.dumps(response, indent=4, default=str)) + + # Use the helper function to validate the response + validate_responses_api_response(response, final_chunk=True) + + + @pytest.mark.parametrize("sync_mode", [True]) + @pytest.mark.asyncio + async def test_basic_openai_responses_api_streaming(self, sync_mode): + litellm._turn_on_debug() + base_completion_call_args = self.get_base_completion_call_args() + if sync_mode: + response = litellm.responses( + input="Basic ping", + stream=True, + **base_completion_call_args + ) + for event in response: + print("litellm response=", json.dumps(event, indent=4, default=str)) + else: + response = await litellm.aresponses( + input="Basic ping", + stream=True, + **base_completion_call_args + ) + async for event in response: + print("litellm response=", json.dumps(event, indent=4, default=str)) + + diff --git a/tests/llm_responses_api_testing/test_azure_responses_api.py b/tests/llm_responses_api_testing/test_azure_responses_api.py new file mode 100644 index 0000000000..725a6efd5a --- /dev/null +++ b/tests/llm_responses_api_testing/test_azure_responses_api.py @@ -0,0 +1,31 @@ +import os +import sys +import pytest +import asyncio +from typing import Optional +from unittest.mock import patch, AsyncMock + +sys.path.insert(0, os.path.abspath("../..")) +import litellm +from litellm.integrations.custom_logger import CustomLogger +import json +from litellm.types.utils import StandardLoggingPayload +from litellm.types.llms.openai import ( + ResponseCompletedEvent, + ResponsesAPIResponse, + ResponseTextConfig, + ResponseAPIUsage, + IncompleteDetails, +) +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from base_responses_api import BaseResponsesAPITest + +class TestAzureResponsesAPITest(BaseResponsesAPITest): + def get_base_completion_call_args(self): + return { + "model": "azure/computer-use-preview", + "truncation": "auto", + "api_base": os.getenv("AZURE_RESPONSES_OPENAI_ENDPOINT"), + "api_key": os.getenv("AZURE_RESPONSES_OPENAI_API_KEY"), + "api_version": os.getenv("AZURE_RESPONSES_OPENAI_API_VERSION"), + } diff --git a/tests/llm_responses_api_testing/test_openai_responses_api.py b/tests/llm_responses_api_testing/test_openai_responses_api.py index 677e13b08a..a1b636c657 100644 --- a/tests/llm_responses_api_testing/test_openai_responses_api.py +++ b/tests/llm_responses_api_testing/test_openai_responses_api.py @@ -18,119 +18,13 @@ from litellm.types.llms.openai import ( IncompleteDetails, ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from base_responses_api import BaseResponsesAPITest, validate_responses_api_response - -def validate_responses_api_response(response, final_chunk: bool = False): - """ - Validate that a response from litellm.responses() or litellm.aresponses() - conforms to the expected ResponsesAPIResponse structure. 
- - Args: - response: The response object to validate - - Raises: - AssertionError: If the response doesn't match the expected structure - """ - # Validate response structure - print("response=", json.dumps(response, indent=4, default=str)) - assert isinstance( - response, ResponsesAPIResponse - ), "Response should be an instance of ResponsesAPIResponse" - - # Required fields - assert "id" in response and isinstance( - response["id"], str - ), "Response should have a string 'id' field" - assert "created_at" in response and isinstance( - response["created_at"], (int, float) - ), "Response should have a numeric 'created_at' field" - assert "output" in response and isinstance( - response["output"], list - ), "Response should have a list 'output' field" - assert "parallel_tool_calls" in response and isinstance( - response["parallel_tool_calls"], bool - ), "Response should have a boolean 'parallel_tool_calls' field" - - # Optional fields with their expected types - optional_fields = { - "error": (dict, type(None)), # error can be dict or None - "incomplete_details": (IncompleteDetails, type(None)), - "instructions": (str, type(None)), - "metadata": dict, - "model": str, - "object": str, - "temperature": (int, float), - "tool_choice": (dict, str), - "tools": list, - "top_p": (int, float), - "max_output_tokens": (int, type(None)), - "previous_response_id": (str, type(None)), - "reasoning": dict, - "status": str, - "text": ResponseTextConfig, - "truncation": str, - "usage": ResponseAPIUsage, - "user": (str, type(None)), - } - if final_chunk is False: - optional_fields["usage"] = type(None) - - for field, expected_type in optional_fields.items(): - if field in response: - assert isinstance( - response[field], expected_type - ), f"Field '{field}' should be of type {expected_type}, but got {type(response[field])}" - - # Check if output has at least one item - if final_chunk is True: - assert ( - len(response["output"]) > 0 - ), "Response 'output' field should have at least one item" - - return True # Return True if validation passes - - -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.asyncio -async def test_basic_openai_responses_api(sync_mode): - litellm._turn_on_debug() - litellm.set_verbose = True - if sync_mode: - response = litellm.responses( - model="gpt-4o", input="Basic ping", max_output_tokens=20 - ) - else: - response = await litellm.aresponses( - model="gpt-4o", input="Basic ping", max_output_tokens=20 - ) - - print("litellm response=", json.dumps(response, indent=4, default=str)) - - # Use the helper function to validate the response - validate_responses_api_response(response, final_chunk=True) - - -@pytest.mark.parametrize("sync_mode", [True]) -@pytest.mark.asyncio -async def test_basic_openai_responses_api_streaming(sync_mode): - litellm._turn_on_debug() - - if sync_mode: - response = litellm.responses( - model="gpt-4o", - input="Basic ping", - stream=True, - ) - for event in response: - print("litellm response=", json.dumps(event, indent=4, default=str)) - else: - response = await litellm.aresponses( - model="gpt-4o", - input="Basic ping", - stream=True, - ) - async for event in response: - print("litellm response=", json.dumps(event, indent=4, default=str)) +class TestOpenAIResponsesAPITest(BaseResponsesAPITest): + def get_base_completion_call_args(self): + return { + "model": "openai/gpt-4o", + } class TestCustomLogger(CustomLogger): @@ -693,7 +587,7 @@ async def test_openai_responses_litellm_router_no_metadata(): # Assert metadata is not in the request assert ( - 
loaded_request_body["metadata"] == None + "metadata" not in loaded_request_body ), "metadata should not be in the request body" mock_post.assert_called_once()
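Note: for completeness, the streaming path exercised by test_basic_openai_responses_api_streaming can be sketched the same way. With stream=True the call yields response events that the new base test simply iterates and prints; the model value here is illustrative and any provider wired into the unified Responses API (e.g. azure/computer-use-preview) works the same way.

```python
import json
import litellm

# Sync streaming: iterate events as they arrive (mirrors the new base test).
stream = litellm.responses(
    model="openai/gpt-4o",
    input="Basic ping",
    stream=True,
)
for event in stream:
    print(json.dumps(event, indent=4, default=str))
```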