Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 10:14:26 +00:00
[Feat] Unified Responses API - Add Azure Responses API support (#10116)
* initial commit for azure responses api support
* update get complete url
* fixes for responses API
* working azure responses API
* working responses API
* test suite for responses API
* azure responses API test suite
* fix test with complete url
* fix test refactor
* test fix metadata checks
* fix code quality check
This commit is contained in: parent 8be8022914, commit d3e04eac7f
11 changed files with 428 additions and 191 deletions
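In practice, these changes let the unified litellm.responses() / litellm.aresponses() entry points target an Azure OpenAI deployment. A minimal sketch of the new call path, mirroring the arguments used by the Azure test suite added at the end of this diff (the deployment name and the AZURE_RESPONSES_* environment variable names are illustrative values taken from the tests, not hard requirements):

import os
import litellm

# Sketch only: model is "azure/<deployment-name>"; the env var names below are the
# ones the new test suite reads, not values required by litellm itself.
response = litellm.responses(
    model="azure/computer-use-preview",
    input="Basic ping",
    max_output_tokens=20,
    truncation="auto",
    api_base=os.getenv("AZURE_RESPONSES_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_RESPONSES_OPENAI_API_KEY"),
    api_version=os.getenv("AZURE_RESPONSES_OPENAI_API_VERSION"),
)
print(response)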
@@ -128,19 +128,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
 require_auth_for_metrics_endpoint: Optional[bool] = False
 argilla_batch_size: Optional[int] = None
 datadog_use_v1: Optional[bool] = False  # if you want to use v1 datadog logged payload
-gcs_pub_sub_use_v1: Optional[
-    bool
-] = False  # if you want to use v1 gcs pubsub logged payload
+gcs_pub_sub_use_v1: Optional[bool] = (
+    False  # if you want to use v1 gcs pubsub logged payload
+)
 argilla_transformation_object: Optional[Dict[str, Any]] = None
-_async_input_callback: List[
-    Union[str, Callable, CustomLogger]
-] = []  # internal variable - async custom callbacks are routed here.
-_async_success_callback: List[
-    Union[str, Callable, CustomLogger]
-] = []  # internal variable - async custom callbacks are routed here.
-_async_failure_callback: List[
-    Union[str, Callable, CustomLogger]
-] = []  # internal variable - async custom callbacks are routed here.
+_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
+    []
+)  # internal variable - async custom callbacks are routed here.
+_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
+    []
+)  # internal variable - async custom callbacks are routed here.
+_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
+    []
+)  # internal variable - async custom callbacks are routed here.
 pre_call_rules: List[Callable] = []
 post_call_rules: List[Callable] = []
 turn_off_message_logging: Optional[bool] = False
@@ -148,18 +148,18 @@ log_raw_request_response: bool = False
 redact_messages_in_exceptions: Optional[bool] = False
 redact_user_api_key_info: Optional[bool] = False
 filter_invalid_headers: Optional[bool] = False
-add_user_information_to_llm_headers: Optional[
-    bool
-] = None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+add_user_information_to_llm_headers: Optional[bool] = (
+    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
+)
 store_audit_logs = False  # Enterprise feature, allow users to see audit logs
 ### end of callbacks #############

-email: Optional[
-    str
-] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-token: Optional[
-    str
-] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+email: Optional[str] = (
+    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+token: Optional[str] = (
+    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
 telemetry = True
 max_tokens: int = DEFAULT_MAX_TOKENS  # OpenAI Defaults
 drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))
@@ -235,20 +235,24 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
 enable_caching_on_provider_specific_optional_params: bool = (
     False  # feature-flag for caching on optional params - e.g. 'top_k'
 )
-caching: bool = False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-caching_with_models: bool = False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
-cache: Optional[
-    Cache
-] = None  # cache object <- use this - https://docs.litellm.ai/docs/caching
+caching: bool = (
+    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+caching_with_models: bool = (
+    False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
+)
+cache: Optional[Cache] = (
+    None  # cache object <- use this - https://docs.litellm.ai/docs/caching
+)
 default_in_memory_ttl: Optional[float] = None
 default_redis_ttl: Optional[float] = None
 default_redis_batch_cache_expiry: Optional[float] = None
 model_alias_map: Dict[str, str] = {}
 model_group_alias_map: Dict[str, str] = {}
 max_budget: float = 0.0  # set the max budget across all providers
-budget_duration: Optional[
-    str
-] = None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+budget_duration: Optional[str] = (
+    None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
+)
 default_soft_budget: float = (
     DEFAULT_SOFT_BUDGET  # by default all litellm proxy keys have a soft budget of 50.0
 )
@@ -257,11 +261,15 @@ forward_traceparent_to_llm_provider: bool = False

 _current_cost = 0.0  # private variable, used if max budget is set
 error_logs: Dict = {}
-add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
+add_function_to_prompt: bool = (
+    False  # if function calling not supported by api, append function call details to system prompt
+)
 client_session: Optional[httpx.Client] = None
 aclient_session: Optional[httpx.AsyncClient] = None
 model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
-model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+model_cost_map_url: str = (
+    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
+)
 suppress_debug_info = False
 dynamodb_table_name: Optional[str] = None
 s3_callback_params: Optional[Dict] = None
@@ -284,7 +292,9 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
 custom_prometheus_metadata_labels: List[str] = []
 #### REQUEST PRIORITIZATION ####
 priority_reservation: Optional[Dict[str, float]] = None
-force_ipv4: bool = False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+force_ipv4: bool = (
+    False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
+)
 module_level_aclient = AsyncHTTPHandler(
     timeout=request_timeout, client_alias="module level aclient"
 )
@@ -298,13 +308,13 @@ fallbacks: Optional[List] = None
 context_window_fallbacks: Optional[List] = None
 content_policy_fallbacks: Optional[List] = None
 allowed_fails: int = 3
-num_retries_per_request: Optional[
-    int
-] = None  # for the request overall (incl. fallbacks + model retries)
+num_retries_per_request: Optional[int] = (
+    None  # for the request overall (incl. fallbacks + model retries)
+)
 ####### SECRET MANAGERS #####################
-secret_manager_client: Optional[
-    Any
-] = None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+secret_manager_client: Optional[Any] = (
+    None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
+)
 _google_kms_resource_name: Optional[str] = None
 _key_management_system: Optional[KeyManagementSystem] = None
 _key_management_settings: KeyManagementSettings = KeyManagementSettings()
@@ -939,6 +949,7 @@ from .llms.voyage.embedding.transformation import VoyageEmbeddingConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
 from .llms.openai.responses.transformation import OpenAIResponsesAPIConfig
+from .llms.azure.responses.transformation import AzureOpenAIResponsesAPIConfig
 from .llms.openai.chat.o_series_transformation import (
     OpenAIOSeriesConfig as OpenAIO1Config,  # maintain backwards compatibility
     OpenAIOSeriesConfig,
@@ -1055,10 +1066,10 @@ from .types.llms.custom_llm import CustomLLMItem
 from .types.utils import GenericStreamingChunk

 custom_provider_map: List[CustomLLMItem] = []
-_custom_providers: List[
-    str
-] = []  # internal helper util, used to track names of custom providers
-disable_hf_tokenizer_download: Optional[
-    bool
-] = None  # disable huggingface tokenizer download. Defaults to openai clk100
+_custom_providers: List[str] = (
+    []
+)  # internal helper util, used to track names of custom providers
+disable_hf_tokenizer_download: Optional[bool] = (
+    None  # disable huggingface tokenizer download. Defaults to openai clk100
+)
 global_disable_no_log_param: bool = False
litellm/llms/azure/responses/transformation.py (new file, +94 lines)
@@ -0,0 +1,94 @@
from typing import TYPE_CHECKING, Any, Optional, cast

import httpx

import litellm
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import *
from litellm.utils import _add_path_to_api_base

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        api_key = (
            api_key
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_OPENAI_API_KEY")
            or get_secret_str("AZURE_API_KEY")
        )

        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.openai.azure.com"
            OR
            "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
        - model: Model name.
        - optional_params: Additional query parameters, including "api_version".
        - stream: If streaming is required (optional).

        Returns:
        - A complete URL string, e.g.,
            "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
        """
        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)

        # Extract api_version or use default
        api_version = cast(Optional[str], litellm_params.get("api_version"))

        # Create a new dictionary with existing params
        query_params = dict(original_url.params)

        # Add api_version if needed
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # Add the path to the base URL
        if "/openai/responses" not in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/openai/responses"
            )
        else:
            new_url = api_base

        # Use the new query_params dictionary
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)
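A quick sketch of what get_complete_url above produces. The config is instantiated directly here only for illustration; inside litellm it is resolved through ProviderConfigManager (see the utils.py hunk further down), and the endpoint and api-version values below are made up:

from litellm.llms.azure.responses.transformation import AzureOpenAIResponsesAPIConfig

config = AzureOpenAIResponsesAPIConfig()

# Bare resource endpoint: the "/openai/responses" path is appended and the
# api_version from litellm_params becomes an "api-version" query parameter.
print(
    config.get_complete_url(
        api_base="https://my-resource.openai.azure.com",  # illustrative endpoint
        api_key=None,
        model="computer-use-preview",
        optional_params={},
        litellm_params={"api_version": "2024-05-01-preview"},
    )
)
# -> https://my-resource.openai.azure.com/openai/responses?api-version=2024-05-01-preview

# If the caller already passed the full path with an api-version, both are preserved as-is.
print(
    config.get_complete_url(
        api_base="https://my-resource.openai.azure.com/openai/responses?api-version=2024-05-01-preview",
        api_key=None,
        model="computer-use-preview",
        optional_params={},
        litellm_params={},
    )
)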
@@ -73,7 +73,10 @@ class BaseResponsesAPIConfig(ABC):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
+        optional_params: dict,
+        litellm_params: dict,
         stream: Optional[bool] = None,
     ) -> str:
         """
@@ -462,7 +462,7 @@ class BaseLLMHTTPHandler:
         )

         if fake_stream is True:
-            model_response: (ModelResponse) = provider_config.transform_response(
+            model_response: ModelResponse = provider_config.transform_response(
                 model=model,
                 raw_response=response,
                 model_response=litellm.ModelResponse(),
@@ -595,7 +595,7 @@ class BaseLLMHTTPHandler:
         )

         if fake_stream is True:
-            model_response: (ModelResponse) = provider_config.transform_response(
+            model_response: ModelResponse = provider_config.transform_response(
                 model=model,
                 raw_response=response,
                 model_response=litellm.ModelResponse(),
@@ -1055,9 +1055,16 @@ class BaseLLMHTTPHandler:
         if extra_headers:
             headers.update(extra_headers)

+        # Check if streaming is requested
+        stream = response_api_optional_request_params.get("stream", False)
+
         api_base = responses_api_provider_config.get_complete_url(
             api_base=litellm_params.api_base,
+            api_key=litellm_params.api_key,
             model=model,
+            optional_params=response_api_optional_request_params,
+            litellm_params=dict(litellm_params),
+            stream=stream,
         )

         data = responses_api_provider_config.transform_responses_api_request(
@@ -1079,9 +1086,6 @@ class BaseLLMHTTPHandler:
             },
         )

-        # Check if streaming is requested
-        stream = response_api_optional_request_params.get("stream", False)
-
         try:
             if stream:
                 # For streaming, use stream=True in the request
@@ -1170,9 +1174,16 @@ class BaseLLMHTTPHandler:
         if extra_headers:
             headers.update(extra_headers)

+        # Check if streaming is requested
+        stream = response_api_optional_request_params.get("stream", False)
+
         api_base = responses_api_provider_config.get_complete_url(
             api_base=litellm_params.api_base,
+            api_key=litellm_params.api_key,
             model=model,
+            optional_params=response_api_optional_request_params,
+            litellm_params=dict(litellm_params),
+            stream=stream,
         )

         data = responses_api_provider_config.transform_responses_api_request(
@@ -1193,8 +1204,6 @@ class BaseLLMHTTPHandler:
                 "headers": headers,
             },
         )
-        # Check if streaming is requested
-        stream = response_api_optional_request_params.get("stream", False)

         try:
             if stream:
@@ -110,7 +110,10 @@ class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
     def get_complete_url(
         self,
         api_base: Optional[str],
+        api_key: Optional[str],
         model: str,
+        optional_params: dict,
+        litellm_params: dict,
         stream: Optional[bool] = None,
     ) -> str:
         """
@@ -60,7 +60,7 @@ class ResponsesAPIRequestUtils:

     @staticmethod
     def get_requested_response_api_optional_param(
-        params: Dict[str, Any]
+        params: Dict[str, Any],
     ) -> ResponsesAPIOptionalRequestParams:
         """
         Filter parameters to only include those defined in ResponsesAPIOptionalRequestParams.
@@ -72,7 +72,9 @@ class ResponsesAPIRequestUtils:
             ResponsesAPIOptionalRequestParams instance with only the valid parameters
         """
         valid_keys = get_type_hints(ResponsesAPIOptionalRequestParams).keys()
-        filtered_params = {k: v for k, v in params.items() if k in valid_keys}
+        filtered_params = {
+            k: v for k, v in params.items() if k in valid_keys and v is not None
+        }
         return cast(ResponsesAPIOptionalRequestParams, filtered_params)
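The behavioral piece of this hunk is the added "and v is not None": explicitly passed None values are now dropped before the request body is built, which is what the updated router test at the bottom of this diff asserts (metadata must be absent from the request, not null). A standalone illustration, using a trimmed-down stand-in for the real ResponsesAPIOptionalRequestParams TypedDict:

from typing import Any, Dict, Optional, TypedDict, get_type_hints


class ResponsesAPIOptionalRequestParams(TypedDict, total=False):
    # Trimmed stand-in for illustration; the real TypedDict has many more fields.
    metadata: Optional[Dict[str, Any]]
    max_output_tokens: Optional[int]
    stream: Optional[bool]


params = {"metadata": None, "max_output_tokens": 20, "unknown_key": "x"}
valid_keys = get_type_hints(ResponsesAPIOptionalRequestParams).keys()

old = {k: v for k, v in params.items() if k in valid_keys}  # keeps metadata=None
new = {k: v for k, v in params.items() if k in valid_keys and v is not None}  # drops it
print(old)  # {'metadata': None, 'max_output_tokens': 20}
print(new)  # {'max_output_tokens': 20}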
@@ -88,7 +90,7 @@ class ResponseAPILoggingUtils:

     @staticmethod
     def _transform_response_api_usage_to_chat_usage(
-        usage: Union[dict, ResponseAPIUsage]
+        usage: Union[dict, ResponseAPIUsage],
     ) -> Usage:
         """Tranforms the ResponseAPIUsage object to a Usage object"""
         response_api_usage: ResponseAPIUsage = (
@@ -516,9 +516,9 @@ def function_setup(  # noqa: PLR0915
     function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

     ## DYNAMIC CALLBACKS ##
-    dynamic_callbacks: Optional[
-        List[Union[str, Callable, CustomLogger]]
-    ] = kwargs.pop("callbacks", None)
+    dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
+        kwargs.pop("callbacks", None)
+    )
     all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)

     if len(all_callbacks) > 0:
@@ -1202,9 +1202,9 @@ def client(original_function):  # noqa: PLR0915
                         exception=e,
                         retry_policy=kwargs.get("retry_policy"),
                     )
-                    kwargs[
-                        "retry_policy"
-                    ] = reset_retry_policy()  # prevent infinite loops
+                    kwargs["retry_policy"] = (
+                        reset_retry_policy()
+                    )  # prevent infinite loops
                     litellm.num_retries = (
                         None  # set retries to None to prevent infinite loops
                     )
@@ -3013,16 +3013,16 @@ def get_optional_params(  # noqa: PLR0915
             True  # so that main.py adds the function call to the prompt
         )
         if "tools" in non_default_params:
-            optional_params[
-                "functions_unsupported_model"
-            ] = non_default_params.pop("tools")
+            optional_params["functions_unsupported_model"] = (
+                non_default_params.pop("tools")
+            )
             non_default_params.pop(
                 "tool_choice", None
             )  # causes ollama requests to hang
         elif "functions" in non_default_params:
-            optional_params[
-                "functions_unsupported_model"
-            ] = non_default_params.pop("functions")
+            optional_params["functions_unsupported_model"] = (
+                non_default_params.pop("functions")
+            )
         elif (
             litellm.add_function_to_prompt
         ):  # if user opts to add it to prompt instead
@@ -3045,10 +3045,10 @@ def get_optional_params(  # noqa: PLR0915

     if "response_format" in non_default_params:
         if provider_config is not None:
-            non_default_params[
-                "response_format"
-            ] = provider_config.get_json_schema_from_pydantic_object(
-                response_format=non_default_params["response_format"]
-            )
+            non_default_params["response_format"] = (
+                provider_config.get_json_schema_from_pydantic_object(
+                    response_format=non_default_params["response_format"]
+                )
+            )
         else:
             non_default_params["response_format"] = type_to_response_format_param(
@@ -4064,9 +4064,9 @@ def _count_characters(text: str) -> int:


 def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
-    _choices: Union[
-        List[Union[Choices, StreamingChoices]], List[StreamingChoices]
-    ] = response_obj.choices
+    _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
+        response_obj.choices
+    )

     response_str = ""
     for choice in _choices:
@@ -6602,6 +6602,8 @@ class ProviderConfigManager:
     ) -> Optional[BaseResponsesAPIConfig]:
         if litellm.LlmProviders.OPENAI == provider:
             return litellm.OpenAIResponsesAPIConfig()
+        elif litellm.LlmProviders.AZURE == provider:
+            return litellm.AzureOpenAIResponsesAPIConfig()
         return None

     @staticmethod
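Those two added lines are the only routing change needed: ProviderConfigManager now returns the Azure config when the resolved provider is Azure, and None for providers without Responses API support. The enclosing method's name is not visible in this hunk, so the wrapper below is purely an illustrative stand-in for the same dispatch:

from typing import Optional

import litellm


def pick_responses_api_config(provider) -> Optional[object]:
    # Illustrative stand-in for the ProviderConfigManager method shown above.
    if litellm.LlmProviders.OPENAI == provider:
        return litellm.OpenAIResponsesAPIConfig()
    elif litellm.LlmProviders.AZURE == provider:
        return litellm.AzureOpenAIResponsesAPIConfig()
    return None


# pick_responses_api_config(litellm.LlmProviders.AZURE) -> AzureOpenAIResponsesAPIConfig instance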
@@ -201,13 +201,25 @@ class TestOpenAIResponsesAPIConfig:
         # Test with provided API base
         api_base = "https://custom-openai.example.com/v1"

-        result = self.config.get_complete_url(api_base=api_base, model=self.model)
+        result = self.config.get_complete_url(
+            api_base=api_base,
+            model=self.model,
+            api_key="test_api_key",
+            optional_params={},
+            litellm_params={},
+        )

         assert result == "https://custom-openai.example.com/v1/responses"

         # Test with litellm.api_base
         with patch("litellm.api_base", "https://litellm-api-base.example.com/v1"):
-            result = self.config.get_complete_url(api_base=None, model=self.model)
+            result = self.config.get_complete_url(
+                api_base=None,
+                model=self.model,
+                api_key="test_api_key",
+                optional_params={},
+                litellm_params={},
+            )

             assert result == "https://litellm-api-base.example.com/v1/responses"

@@ -217,7 +229,13 @@ class TestOpenAIResponsesAPIConfig:
             "litellm.llms.openai.responses.transformation.get_secret_str",
             return_value="https://env-api-base.example.com/v1",
         ):
-            result = self.config.get_complete_url(api_base=None, model=self.model)
+            result = self.config.get_complete_url(
+                api_base=None,
+                model=self.model,
+                api_key="test_api_key",
+                optional_params={},
+                litellm_params={},
+            )

             assert result == "https://env-api-base.example.com/v1/responses"

@@ -227,13 +245,25 @@ class TestOpenAIResponsesAPIConfig:
             "litellm.llms.openai.responses.transformation.get_secret_str",
             return_value=None,
         ):
-            result = self.config.get_complete_url(api_base=None, model=self.model)
+            result = self.config.get_complete_url(
+                api_base=None,
+                model=self.model,
+                api_key="test_api_key",
+                optional_params={},
+                litellm_params={},
+            )

             assert result == "https://api.openai.com/v1/responses"

         # Test with trailing slash in API base
         api_base = "https://custom-openai.example.com/v1/"

-        result = self.config.get_complete_url(api_base=api_base, model=self.model)
+        result = self.config.get_complete_url(
+            api_base=api_base,
+            model=self.model,
+            api_key="test_api_key",
+            optional_params={},
+            litellm_params={},
+        )

         assert result == "https://custom-openai.example.com/v1/responses"
tests/llm_responses_api_testing/base_responses_api.py (new file, +158 lines)
@@ -0,0 +1,158 @@

import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import os
import uuid
import time
import base64

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from abc import ABC, abstractmethod

from litellm.integrations.custom_logger import CustomLogger
import json
from litellm.types.utils import StandardLoggingPayload
from litellm.types.llms.openai import (
    ResponseCompletedEvent,
    ResponsesAPIResponse,
    ResponseTextConfig,
    ResponseAPIUsage,
    IncompleteDetails,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler


def validate_responses_api_response(response, final_chunk: bool = False):
    """
    Validate that a response from litellm.responses() or litellm.aresponses()
    conforms to the expected ResponsesAPIResponse structure.

    Args:
        response: The response object to validate

    Raises:
        AssertionError: If the response doesn't match the expected structure
    """
    # Validate response structure
    print("response=", json.dumps(response, indent=4, default=str))
    assert isinstance(
        response, ResponsesAPIResponse
    ), "Response should be an instance of ResponsesAPIResponse"

    # Required fields
    assert "id" in response and isinstance(
        response["id"], str
    ), "Response should have a string 'id' field"
    assert "created_at" in response and isinstance(
        response["created_at"], (int, float)
    ), "Response should have a numeric 'created_at' field"
    assert "output" in response and isinstance(
        response["output"], list
    ), "Response should have a list 'output' field"
    assert "parallel_tool_calls" in response and isinstance(
        response["parallel_tool_calls"], bool
    ), "Response should have a boolean 'parallel_tool_calls' field"

    # Optional fields with their expected types
    optional_fields = {
        "error": (dict, type(None)),  # error can be dict or None
        "incomplete_details": (IncompleteDetails, type(None)),
        "instructions": (str, type(None)),
        "metadata": dict,
        "model": str,
        "object": str,
        "temperature": (int, float),
        "tool_choice": (dict, str),
        "tools": list,
        "top_p": (int, float),
        "max_output_tokens": (int, type(None)),
        "previous_response_id": (str, type(None)),
        "reasoning": dict,
        "status": str,
        "text": ResponseTextConfig,
        "truncation": str,
        "usage": ResponseAPIUsage,
        "user": (str, type(None)),
    }
    if final_chunk is False:
        optional_fields["usage"] = type(None)

    for field, expected_type in optional_fields.items():
        if field in response:
            assert isinstance(
                response[field], expected_type
            ), f"Field '{field}' should be of type {expected_type}, but got {type(response[field])}"

    # Check if output has at least one item
    if final_chunk is True:
        assert (
            len(response["output"]) > 0
        ), "Response 'output' field should have at least one item"

    return True  # Return True if validation passes


class BaseResponsesAPITest(ABC):
    """
    Abstract base test class that enforces a common test across all test classes.
    """

    @abstractmethod
    def get_base_completion_call_args(self) -> dict:
        """Must return the base completion call args"""
        pass

    @pytest.mark.parametrize("sync_mode", [True, False])
    @pytest.mark.asyncio
    async def test_basic_openai_responses_api(self, sync_mode):
        litellm._turn_on_debug()
        litellm.set_verbose = True
        base_completion_call_args = self.get_base_completion_call_args()
        if sync_mode:
            response = litellm.responses(
                input="Basic ping", max_output_tokens=20,
                **base_completion_call_args
            )
        else:
            response = await litellm.aresponses(
                input="Basic ping", max_output_tokens=20,
                **base_completion_call_args
            )

        print("litellm response=", json.dumps(response, indent=4, default=str))

        # Use the helper function to validate the response
        validate_responses_api_response(response, final_chunk=True)

    @pytest.mark.parametrize("sync_mode", [True])
    @pytest.mark.asyncio
    async def test_basic_openai_responses_api_streaming(self, sync_mode):
        litellm._turn_on_debug()
        base_completion_call_args = self.get_base_completion_call_args()
        if sync_mode:
            response = litellm.responses(
                input="Basic ping",
                stream=True,
                **base_completion_call_args
            )
            for event in response:
                print("litellm response=", json.dumps(event, indent=4, default=str))
        else:
            response = await litellm.aresponses(
                input="Basic ping",
                stream=True,
                **base_completion_call_args
            )
            async for event in response:
                print("litellm response=", json.dumps(event, indent=4, default=str))
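The streaming test above simply prints every event it receives. A caller that only wants the final aggregated response could filter the stream for the completed event; a sketch, assuming litellm's ResponseCompletedEvent mirrors the OpenAI event shape and exposes the finished response on a .response attribute:

import litellm
from litellm.types.llms.openai import ResponseCompletedEvent

final_response = None
for event in litellm.responses(model="openai/gpt-4o", input="Basic ping", stream=True):
    if isinstance(event, ResponseCompletedEvent):
        final_response = event.response  # assumed attribute, mirroring OpenAI's event type
print(final_response)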
tests/llm_responses_api_testing/test_azure_responses_api.py (new file, +31 lines)
@@ -0,0 +1,31 @@
import os
import sys
import pytest
import asyncio
from typing import Optional
from unittest.mock import patch, AsyncMock

sys.path.insert(0, os.path.abspath("../.."))
import litellm
from litellm.integrations.custom_logger import CustomLogger
import json
from litellm.types.utils import StandardLoggingPayload
from litellm.types.llms.openai import (
    ResponseCompletedEvent,
    ResponsesAPIResponse,
    ResponseTextConfig,
    ResponseAPIUsage,
    IncompleteDetails,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from base_responses_api import BaseResponsesAPITest


class TestAzureResponsesAPITest(BaseResponsesAPITest):
    def get_base_completion_call_args(self):
        return {
            "model": "azure/computer-use-preview",
            "truncation": "auto",
            "api_base": os.getenv("AZURE_RESPONSES_OPENAI_ENDPOINT"),
            "api_key": os.getenv("AZURE_RESPONSES_OPENAI_API_KEY"),
            "api_version": os.getenv("AZURE_RESPONSES_OPENAI_API_VERSION"),
        }
@@ -18,119 +18,13 @@ from litellm.types.llms.openai import (
     IncompleteDetails,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from base_responses_api import BaseResponsesAPITest, validate_responses_api_response


-def validate_responses_api_response(response, final_chunk: bool = False):
-    """
-    Validate that a response from litellm.responses() or litellm.aresponses()
-    conforms to the expected ResponsesAPIResponse structure.
-
-    Args:
-        response: The response object to validate
-
-    Raises:
-        AssertionError: If the response doesn't match the expected structure
-    """
-    # Validate response structure
-    print("response=", json.dumps(response, indent=4, default=str))
-    assert isinstance(
-        response, ResponsesAPIResponse
-    ), "Response should be an instance of ResponsesAPIResponse"
-
-    # Required fields
-    assert "id" in response and isinstance(
-        response["id"], str
-    ), "Response should have a string 'id' field"
-    assert "created_at" in response and isinstance(
-        response["created_at"], (int, float)
-    ), "Response should have a numeric 'created_at' field"
-    assert "output" in response and isinstance(
-        response["output"], list
-    ), "Response should have a list 'output' field"
-    assert "parallel_tool_calls" in response and isinstance(
-        response["parallel_tool_calls"], bool
-    ), "Response should have a boolean 'parallel_tool_calls' field"
-
-    # Optional fields with their expected types
-    optional_fields = {
-        "error": (dict, type(None)),  # error can be dict or None
-        "incomplete_details": (IncompleteDetails, type(None)),
-        "instructions": (str, type(None)),
-        "metadata": dict,
-        "model": str,
-        "object": str,
-        "temperature": (int, float),
-        "tool_choice": (dict, str),
-        "tools": list,
-        "top_p": (int, float),
-        "max_output_tokens": (int, type(None)),
-        "previous_response_id": (str, type(None)),
-        "reasoning": dict,
-        "status": str,
-        "text": ResponseTextConfig,
-        "truncation": str,
-        "usage": ResponseAPIUsage,
-        "user": (str, type(None)),
-    }
-    if final_chunk is False:
-        optional_fields["usage"] = type(None)
-
-    for field, expected_type in optional_fields.items():
-        if field in response:
-            assert isinstance(
-                response[field], expected_type
-            ), f"Field '{field}' should be of type {expected_type}, but got {type(response[field])}"
-
-    # Check if output has at least one item
-    if final_chunk is True:
-        assert (
-            len(response["output"]) > 0
-        ), "Response 'output' field should have at least one item"
-
-    return True  # Return True if validation passes
-
-
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_basic_openai_responses_api(sync_mode):
-    litellm._turn_on_debug()
-    litellm.set_verbose = True
-    if sync_mode:
-        response = litellm.responses(
-            model="gpt-4o", input="Basic ping", max_output_tokens=20
-        )
-    else:
-        response = await litellm.aresponses(
-            model="gpt-4o", input="Basic ping", max_output_tokens=20
-        )
-
-    print("litellm response=", json.dumps(response, indent=4, default=str))
-
-    # Use the helper function to validate the response
-    validate_responses_api_response(response, final_chunk=True)
-
-
-@pytest.mark.parametrize("sync_mode", [True])
-@pytest.mark.asyncio
-async def test_basic_openai_responses_api_streaming(sync_mode):
-    litellm._turn_on_debug()
-
-    if sync_mode:
-        response = litellm.responses(
-            model="gpt-4o",
-            input="Basic ping",
-            stream=True,
-        )
-        for event in response:
-            print("litellm response=", json.dumps(event, indent=4, default=str))
-    else:
-        response = await litellm.aresponses(
-            model="gpt-4o",
-            input="Basic ping",
-            stream=True,
-        )
-        async for event in response:
-            print("litellm response=", json.dumps(event, indent=4, default=str))
+class TestOpenAIResponsesAPITest(BaseResponsesAPITest):
+    def get_base_completion_call_args(self):
+        return {
+            "model": "openai/gpt-4o",
+        }


 class TestCustomLogger(CustomLogger):
@@ -693,7 +587,7 @@ async def test_openai_responses_litellm_router_no_metadata():

         # Assert metadata is not in the request
         assert (
-            loaded_request_body["metadata"] == None
+            "metadata" not in loaded_request_body
         ), "metadata should not be in the request body"
         mock_post.assert_called_once()