diff --git a/litellm/litellm_core_utils/get_litellm_params.py b/litellm/litellm_core_utils/get_litellm_params.py index cf62375f33..fcf83d17a2 100644 --- a/litellm/litellm_core_utils/get_litellm_params.py +++ b/litellm/litellm_core_utils/get_litellm_params.py @@ -58,6 +58,7 @@ def get_litellm_params( async_call: Optional[bool] = None, ssl_verify: Optional[bool] = None, merge_reasoning_content_in_choices: Optional[bool] = None, + api_version: Optional[str] = None, **kwargs, ) -> dict: litellm_params = { @@ -99,5 +100,6 @@ def get_litellm_params( "async_call": async_call, "ssl_verify": ssl_verify, "merge_reasoning_content_in_choices": merge_reasoning_content_in_choices, + "api_version": api_version, } return litellm_params diff --git a/litellm/llms/azure_ai/chat/transformation.py b/litellm/llms/azure_ai/chat/transformation.py index 2ef5285ac6..154f345537 100644 --- a/litellm/llms/azure_ai/chat/transformation.py +++ b/litellm/llms/azure_ai/chat/transformation.py @@ -67,6 +67,7 @@ class AzureAIStudioConfig(OpenAIConfig): api_base: Optional[str], model: str, optional_params: dict, + litellm_params: dict, stream: Optional[bool] = None, ) -> str: """ @@ -92,12 +93,14 @@ class AzureAIStudioConfig(OpenAIConfig): original_url = httpx.URL(api_base) # Extract api_version or use default - api_version = cast(Optional[str], optional_params.get("api_version")) + api_version = cast(Optional[str], litellm_params.get("api_version")) - # Check if 'api-version' is already present - if "api-version" not in original_url.params and api_version: - # Add api_version to optional_params - original_url.params["api-version"] = api_version + # Create a new dictionary with existing params + query_params = dict(original_url.params) + + # Add api_version if needed + if "api-version" not in query_params and api_version: + query_params["api-version"] = api_version # Add the path to the base URL if "services.ai.azure.com" in api_base: @@ -109,8 +112,7 @@ class AzureAIStudioConfig(OpenAIConfig): api_base=api_base, ending_path="/chat/completions" ) - # Convert optional_params to query parameters - query_params = original_url.params + # Use the new query_params dictionary final_url = httpx.URL(new_url).copy_with(params=query_params) return str(final_url) diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index 8327a10464..1b5a6bc58e 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -270,6 +270,7 @@ class BaseConfig(ABC): api_base: Optional[str], model: str, optional_params: dict, + litellm_params: dict, stream: Optional[bool] = None, ) -> str: """ diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 9d67fd1a85..f3600923c6 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -234,6 +234,7 @@ class BaseLLMHTTPHandler: model=model, optional_params=optional_params, stream=stream, + litellm_params=litellm_params, ) data = provider_config.transform_request( diff --git a/litellm/main.py b/litellm/main.py index b90030a6bb..a6a1a7c7d1 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1162,6 +1162,7 @@ def completion( # type: ignore # noqa: PLR0915 merge_reasoning_content_in_choices=kwargs.get( "merge_reasoning_content_in_choices", None ), + api_version=api_version, ) logging.update_environment_variables( model=model, diff --git a/tests/llm_translation/test_azure_ai.py b/tests/llm_translation/test_azure_ai.py index c22c9edafa..6d4284cd86 100644 --- a/tests/llm_translation/test_azure_ai.py +++ b/tests/llm_translation/test_azure_ai.py @@ -159,6 +159,32 @@ def test_azure_ai_services_handler(api_base, expected_url): assert mock_client.call_args.kwargs["url"] == expected_url +def test_azure_ai_services_with_api_version(): + from litellm.llms.custom_httpx.http_handler import HTTPHandler, AsyncHTTPHandler + + client = HTTPHandler() + + with patch.object(client, "post") as mock_client: + try: + response = litellm.completion( + model="azure_ai/Meta-Llama-3.1-70B-Instruct", + messages=[{"role": "user", "content": "Hello, how are you?"}], + api_key="my-fake-api-key", + api_version="2024-05-01-preview", + api_base="https://litellm8397336933.services.ai.azure.com/models", + client=client, + ) + except Exception as e: + print(f"Error: {e}") + + mock_client.assert_called_once() + assert mock_client.call_args.kwargs["headers"]["api-key"] == "my-fake-api-key" + assert ( + mock_client.call_args.kwargs["url"] + == "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" + ) + + def test_completion_azure_ai_command_r(): try: import os