diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 876d3b899..a1203e6f1 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -403,6 +403,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
     return azure_ad_token_access_token
 
 
+def _check_dynamic_azure_params(
+    azure_client_params: dict,
+    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
+) -> bool:
+    """
+    Returns True if user passed in client params != initialized azure client
+
+    Currently only implemented for api version
+    """
+    if azure_client is None:
+        return True
+
+    dynamic_params = ["api_version"]
+    for k, v in azure_client_params.items():
+        if k in dynamic_params and k == "api_version":
+            if v is not None and v != azure_client._custom_query["api-version"]:
+                return True
+
+    return False
+
+
 class AzureChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
@@ -462,6 +483,28 @@ class AzureChatCompletion(BaseLLM):
 
         return azure_client
 
+    def make_sync_azure_openai_chat_completion_request(
+        self,
+        azure_client: AzureOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to call chat.completions.with_raw_response.create,
+        returning the raw response headers alongside the parsed
+        response object.
+        """
+        try:
+            raw_response = azure_client.chat.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
+        except Exception as e:
+            raise e
+
     async def make_azure_openai_chat_completion_request(
         self,
         azure_client: AsyncAzureOpenAI,
@@ -494,6 +537,7 @@ class AzureChatCompletion(BaseLLM):
         api_version: str,
         api_type: str,
         azure_ad_token: str,
+        dynamic_params: bool,
         print_verbose: Callable,
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
@@ -558,6 +602,7 @@ class AzureChatCompletion(BaseLLM):
                     return self.async_streaming(
                         logging_obj=logging_obj,
                         api_base=api_base,
+                        dynamic_params=dynamic_params,
                         data=data,
                         model=model,
                         api_key=api_key,
@@ -575,6 +620,7 @@ class AzureChatCompletion(BaseLLM):
                         api_version=api_version,
                         model=model,
                         azure_ad_token=azure_ad_token,
+                        dynamic_params=dynamic_params,
                         timeout=timeout,
                         client=client,
                         logging_obj=logging_obj,
@@ -583,6 +629,7 @@ class AzureChatCompletion(BaseLLM):
                 return self.streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    dynamic_params=dynamic_params,
                     data=data,
                     model=model,
                     api_key=api_key,
@@ -628,7 +675,8 @@ class AzureChatCompletion(BaseLLM):
                     if azure_ad_token.startswith("oidc/"):
                         azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                     azure_client_params["azure_ad_token"] = azure_ad_token
-                if client is None:
+
+                if client is None or dynamic_params:
                     azure_client = AzureOpenAI(**azure_client_params)
                 else:
                     azure_client = client
@@ -640,7 +688,9 @@ class AzureChatCompletion(BaseLLM):
                         "api-version", api_version
                     )
 
-                response = azure_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
+                headers, response = self.make_sync_azure_openai_chat_completion_request(
+                    azure_client=azure_client, data=data, timeout=timeout
+                )
                 stringified_response = response.model_dump()
                 ## LOGGING
                 logging_obj.post_call(
@@ -674,6 +724,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         data: dict,
         timeout: Any,
+        dynamic_params: bool,
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
         azure_ad_token: Optional[str] = None,
@@ -707,15 +758,11 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["azure_ad_token"] = azure_ad_token
 
             # setting Azure client
-            if client is None:
+            if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
             else:
                 azure_client = client
-            if api_version is not None and isinstance(
-                azure_client._custom_query, dict
-            ):
-                # set api_version to version passed by user
-                azure_client._custom_query.setdefault("api-version", api_version)
+
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
@@ -786,6 +833,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         api_key: str,
         api_version: str,
+        dynamic_params: bool,
         data: dict,
         model: str,
         timeout: Any,
@@ -815,13 +863,11 @@ class AzureChatCompletion(BaseLLM):
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
             azure_client_params["azure_ad_token"] = azure_ad_token
-        if client is None:
+
+        if client is None or dynamic_params:
             azure_client = AzureOpenAI(**azure_client_params)
         else:
             azure_client = client
-        if api_version is not None and isinstance(azure_client._custom_query, dict):
-            # set api_version to version passed by user
-            azure_client._custom_query.setdefault("api-version", api_version)
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
@@ -833,7 +879,9 @@ class AzureChatCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.chat.completions.create(**data, timeout=timeout)
+        headers, response = self.make_sync_azure_openai_chat_completion_request(
+            azure_client=azure_client, data=data, timeout=timeout
+        )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -848,6 +896,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         api_key: str,
         api_version: str,
+        dynamic_params: bool,
         data: dict,
         model: str,
         timeout: Any,
@@ -873,15 +922,10 @@ class AzureChatCompletion(BaseLLM):
                 if azure_ad_token.startswith("oidc/"):
                     azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
-            if client is None:
+            if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
             else:
                 azure_client = client
-            if api_version is not None and isinstance(
-                azure_client._custom_query, dict
-            ):
-                # set api_version to version passed by user
-                azure_client._custom_query.setdefault("api-version", api_version)
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
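Reviewer note: `_check_dynamic_azure_params` leans on the private `_custom_query` attribute, where the openai SDK's Azure clients record the constructor's `api_version` (this is the same attribute the deleted `setdefault` blocks below used to mutate). A minimal standalone sketch of the decision, using made-up placeholder credentials and no network calls:

```python
# Sketch of the reuse check above; key, endpoint, and versions are placeholders.
from openai import AzureOpenAI

from litellm.llms.azure import _check_dynamic_azure_params

client = AzureOpenAI(
    api_key="placeholder-key",
    azure_endpoint="https://example.openai.azure.com",
    api_version="2024-02-01",
)

# The Azure clients stash the constructor's api_version in _custom_query,
# which is what the helper compares against.
assert client._custom_query["api-version"] == "2024-02-01"

# Same version (or no version passed) -> the cached client can be reused.
assert _check_dynamic_azure_params({"api_version": "2024-02-01"}, client) is False
assert _check_dynamic_azure_params({"api_version": None}, client) is False

# Different version -> litellm builds a fresh client for this call.
assert _check_dynamic_azure_params({"api_version": "2024-06-01"}, client) is True
```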
diff --git a/litellm/main.py b/litellm/main.py
index 24ae12631..12f8cceb5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -104,7 +104,7 @@ from .llms import (
 )
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
-from .llms.azure import AzureChatCompletion
+from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -967,6 +967,17 @@ def completion(
 
         if custom_llm_provider == "azure":
             # azure configs
+            ## check dynamic params ##
+            dynamic_params = False
+            if client is not None and (
+                isinstance(client, openai.AzureOpenAI)
+                or isinstance(client, openai.AsyncAzureOpenAI)
+            ):
+                dynamic_params = _check_dynamic_azure_params(
+                    azure_client_params={"api_version": api_version},
+                    azure_client=client,
+                )
+
             api_type = get_secret("AZURE_API_TYPE") or "azure"
 
             api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
@@ -1006,6 +1017,7 @@ def completion(
                 api_base=api_base,
                 api_version=api_version,
                 api_type=api_type,
+                dynamic_params=dynamic_params,
                 azure_ad_token=azure_ad_token,
                 model_response=model_response,
                 print_verbose=print_verbose,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 62faafe0e..521a034d6 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,15 +1,6 @@
 model_list:
   - model_name: gpt-3.5-turbo
     litellm_params:
-      model: gpt-3.5-turbo
-
-litellm_settings:
-  cache: True # set cache responses to True
-  cache_params: # set cache params for s3
-    type: s3
-    s3_bucket_name: litellm-proxy # AWS Bucket Name for S3
-    s3_region_name: us-west-2 # AWS Region Name for S3
-    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. This is AWS Access Key ID for S3
-    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
-
-
+      model: azure/chatgpt-v-2
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
\ No newline at end of file
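The net effect in `completion()`: a caller-supplied client is only reused when its baked-in api version matches the per-call one. A hedged sketch of the caller-visible behavior (deployment name, endpoint, and versions are invented; with real credentials the first call would reuse `cached_client`, the second would silently get a fresh `AzureOpenAI` inside litellm):

```python
import litellm
from openai import AzureOpenAI

# Placeholder credentials/deployment; real values come from your Azure resource.
cached_client = AzureOpenAI(
    api_key="placeholder-key",
    azure_endpoint="https://example.openai.azure.com",
    api_version="2024-02-01",
)

# api_version matches the client's -> dynamic_params=False, cached_client is reused.
litellm.completion(
    model="azure/my-deployment",
    messages=[{"role": "user", "content": "Hello"}],
    client=cached_client,
    api_version="2024-02-01",
)

# api_version differs -> dynamic_params=True, a new client is built for this call.
litellm.completion(
    model="azure/my-deployment",
    messages=[{"role": "user", "content": "Hello"}],
    client=cached_client,
    api_version="2024-06-01",
)
```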
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index dd39efd6b..3c5d37ee6 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -136,6 +136,17 @@ async def add_litellm_data_to_request(
         "body": copy.copy(data),  # use copy instead of deepcopy
     }
 
+    ## Dynamic api version (Azure OpenAI endpoints) ##
+    query_params = request.query_params
+
+    # Convert query parameters to a dictionary
+    query_dict = dict(query_params)
+
+    ## check for api version in query params
+    dynamic_api_version: Optional[str] = query_dict.get("api-version")
+
+    data["api_version"] = dynamic_api_version
+
     ## Forward any LLM API Provider specific headers in extra_headers
     add_provider_specific_headers_to_request(data=data, headers=_headers)
 
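What this proxy hunk does per request: lift `api-version` off the query string and stash it in the request body that eventually reaches `completion()` as `api_version`. A self-contained sketch with a hand-built Starlette request (the scope values are made up for illustration):

```python
from starlette.requests import Request

# Minimal ASGI scope standing in for POST /chat/completions?api-version=2024-06-01
scope = {
    "type": "http",
    "method": "POST",
    "path": "/chat/completions",
    "query_string": b"api-version=2024-06-01",
    "headers": [],
}
request = Request(scope)

# Mirrors the added lines: query params -> dict -> data["api_version"]
query_dict = dict(request.query_params)
data = {"api_version": query_dict.get("api-version")}
print(data)  # {'api_version': '2024-06-01'}
```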
world!") + else: + mock_client = AsyncMock(return_value="Hello world!") + + ## CHECK IF NEW CLIENT IS USED (PARAM CHANGE) + with patch.object( + client.chat.completions.with_raw_response, "create", new=mock_client + ) as mock_client: + try: + if sync_mode: + _ = completion( + model="azure/chatgpt-v2", + messages=[{"role": "user", "content": "Hello world"}], + client=client, + api_version="my-new-version", + stream=stream, + ) + else: + _ = await litellm.acompletion( + model="azure/chatgpt-v2", + messages=[{"role": "user", "content": "Hello world"}], + client=client, + api_version="my-new-version", + stream=stream, + ) + except Exception: + pass + + try: + mock_client.assert_not_called() + except Exception as e: + traceback.print_stack() + raise e