feat(azure.py): support dynamic api versions

Closes https://github.com/BerriAI/litellm/issues/5228
This commit is contained in:
Krrish Dholakia 2024-08-19 12:17:43 -07:00
parent 417547b6f9
commit 49416e121c
5 changed files with 176 additions and 32 deletions

View file

@ -403,6 +403,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
return azure_ad_token_access_token return azure_ad_token_access_token
def _check_dynamic_azure_params(
azure_client_params: dict,
azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
) -> bool:
"""
Returns True if user passed in client params != initialized azure client
Currently only implemented for api version
"""
if azure_client is None:
return True
dynamic_params = ["api_version"]
for k, v in azure_client_params.items():
if k in dynamic_params and k == "api_version":
if v is not None and v != azure_client._custom_query["api-version"]:
return True
return False
class AzureChatCompletion(BaseLLM): class AzureChatCompletion(BaseLLM):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -462,6 +483,28 @@ class AzureChatCompletion(BaseLLM):
return azure_client return azure_client
def make_sync_azure_openai_chat_completion_request(
self,
azure_client: AzureOpenAI,
data: dict,
timeout: Union[float, httpx.Timeout],
):
"""
Helper to:
- call chat.completions.create.with_raw_response when litellm.return_response_headers is True
- call chat.completions.create by default
"""
try:
raw_response = azure_client.chat.completions.with_raw_response.create(
**data, timeout=timeout
)
headers = dict(raw_response.headers)
response = raw_response.parse()
return headers, response
except Exception as e:
raise e
async def make_azure_openai_chat_completion_request( async def make_azure_openai_chat_completion_request(
self, self,
azure_client: AsyncAzureOpenAI, azure_client: AsyncAzureOpenAI,
@ -494,6 +537,7 @@ class AzureChatCompletion(BaseLLM):
api_version: str, api_version: str,
api_type: str, api_type: str,
azure_ad_token: str, azure_ad_token: str,
dynamic_params: bool,
print_verbose: Callable, print_verbose: Callable,
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
@ -558,6 +602,7 @@ class AzureChatCompletion(BaseLLM):
return self.async_streaming( return self.async_streaming(
logging_obj=logging_obj, logging_obj=logging_obj,
api_base=api_base, api_base=api_base,
dynamic_params=dynamic_params,
data=data, data=data,
model=model, model=model,
api_key=api_key, api_key=api_key,
@ -575,6 +620,7 @@ class AzureChatCompletion(BaseLLM):
api_version=api_version, api_version=api_version,
model=model, model=model,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
dynamic_params=dynamic_params,
timeout=timeout, timeout=timeout,
client=client, client=client,
logging_obj=logging_obj, logging_obj=logging_obj,
@ -583,6 +629,7 @@ class AzureChatCompletion(BaseLLM):
return self.streaming( return self.streaming(
logging_obj=logging_obj, logging_obj=logging_obj,
api_base=api_base, api_base=api_base,
dynamic_params=dynamic_params,
data=data, data=data,
model=model, model=model,
api_key=api_key, api_key=api_key,
@ -628,7 +675,8 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
else: else:
azure_client = client azure_client = client
@ -640,7 +688,9 @@ class AzureChatCompletion(BaseLLM):
"api-version", api_version "api-version", api_version
) )
response = azure_client.chat.completions.create(**data, timeout=timeout) # type: ignore headers, response = self.make_sync_azure_openai_chat_completion_request(
azure_client=azure_client, data=data, timeout=timeout
)
stringified_response = response.model_dump() stringified_response = response.model_dump()
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
@ -674,6 +724,7 @@ class AzureChatCompletion(BaseLLM):
api_base: str, api_base: str,
data: dict, data: dict,
timeout: Any, timeout: Any,
dynamic_params: bool,
model_response: ModelResponse, model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj, logging_obj: LiteLLMLoggingObj,
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
@ -707,15 +758,11 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
# setting Azure client # setting Azure client
if client is None: if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params) azure_client = AsyncAzureOpenAI(**azure_client_params)
else: else:
azure_client = client azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],
@ -786,6 +833,7 @@ class AzureChatCompletion(BaseLLM):
api_base: str, api_base: str,
api_key: str, api_key: str,
api_version: str, api_version: str,
dynamic_params: bool,
data: dict, data: dict,
model: str, model: str,
timeout: Any, timeout: Any,
@ -815,13 +863,11 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None:
if client is None or dynamic_params:
azure_client = AzureOpenAI(**azure_client_params) azure_client = AzureOpenAI(**azure_client_params)
else: else:
azure_client = client azure_client = client
if api_version is not None and isinstance(azure_client._custom_query, dict):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],
@ -833,7 +879,9 @@ class AzureChatCompletion(BaseLLM):
"complete_input_dict": data, "complete_input_dict": data,
}, },
) )
response = azure_client.chat.completions.create(**data, timeout=timeout) headers, response = self.make_sync_azure_openai_chat_completion_request(
azure_client=azure_client, data=data, timeout=timeout
)
streamwrapper = CustomStreamWrapper( streamwrapper = CustomStreamWrapper(
completion_stream=response, completion_stream=response,
model=model, model=model,
@ -848,6 +896,7 @@ class AzureChatCompletion(BaseLLM):
api_base: str, api_base: str,
api_key: str, api_key: str,
api_version: str, api_version: str,
dynamic_params: bool,
data: dict, data: dict,
model: str, model: str,
timeout: Any, timeout: Any,
@ -873,15 +922,10 @@ class AzureChatCompletion(BaseLLM):
if azure_ad_token.startswith("oidc/"): if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if client is None: if client is None or dynamic_params:
azure_client = AsyncAzureOpenAI(**azure_client_params) azure_client = AsyncAzureOpenAI(**azure_client_params)
else: else:
azure_client = client azure_client = client
if api_version is not None and isinstance(
azure_client._custom_query, dict
):
# set api_version to version passed by user
azure_client._custom_query.setdefault("api-version", api_version)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["messages"], input=data["messages"],

View file

@ -104,7 +104,7 @@ from .llms import (
) )
from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion from .llms.anthropic_text import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure_text import AzureTextCompletion from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.custom_llm import CustomLLM, custom_chat_llm_router from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@ -967,6 +967,17 @@ def completion(
if custom_llm_provider == "azure": if custom_llm_provider == "azure":
# azure configs # azure configs
## check dynamic params ##
dynamic_params = False
if client is not None and (
isinstance(client, openai.AzureOpenAI)
or isinstance(client, openai.AsyncAzureOpenAI)
):
dynamic_params = _check_dynamic_azure_params(
azure_client_params={"api_version": api_version},
azure_client=client,
)
api_type = get_secret("AZURE_API_TYPE") or "azure" api_type = get_secret("AZURE_API_TYPE") or "azure"
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
@ -1006,6 +1017,7 @@ def completion(
api_base=api_base, api_base=api_base,
api_version=api_version, api_version=api_version,
api_type=api_type, api_type=api_type,
dynamic_params=dynamic_params,
azure_ad_token=azure_ad_token, azure_ad_token=azure_ad_token,
model_response=model_response, model_response=model_response,
print_verbose=print_verbose, print_verbose=print_verbose,

View file

@ -1,15 +1,6 @@
model_list: model_list:
- model_name: gpt-3.5-turbo - model_name: gpt-3.5-turbo
litellm_params: litellm_params:
model: gpt-3.5-turbo model: azure/chatgpt-v-2
api_key: os.environ/AZURE_API_KEY
litellm_settings: api_base: os.environ/AZURE_API_BASE
cache: True # set cache responses to True
cache_params: # set cache params for s3
type: s3
s3_bucket_name: litellm-proxy # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -136,6 +136,17 @@ async def add_litellm_data_to_request(
"body": copy.copy(data), # use copy instead of deepcopy "body": copy.copy(data), # use copy instead of deepcopy
} }
## Dynamic api version (Azure OpenAI endpoints) ##
query_params = request.query_params
# Convert query parameters to a dictionary (optional)
query_dict = dict(query_params)
## check for api version in query params
dynamic_api_version: Optional[str] = query_dict.get("api-version")
data["api_version"] = dynamic_api_version
## Forward any LLM API Provider specific headers in extra_headers ## Forward any LLM API Provider specific headers in extra_headers
add_provider_specific_headers_to_request(data=data, headers=_headers) add_provider_specific_headers_to_request(data=data, headers=_headers)

View file

@ -4338,3 +4338,89 @@ def test_moderation():
output = response.results[0] output = response.results[0]
print(output) print(output)
return output return output
@pytest.mark.parametrize("stream", [False, True])
@pytest.mark.parametrize("sync_mode", [False, True])
@pytest.mark.asyncio
async def test_dynamic_azure_params(stream, sync_mode):
"""
If dynamic params are given, which are different from the initialized client, use a new client
"""
from openai import AsyncAzureOpenAI, AzureOpenAI
if sync_mode:
client = AzureOpenAI(
api_key="my-test-key",
base_url="my-test-base",
api_version="my-test-version",
)
mock_client = MagicMock(return_value="Hello world!")
else:
client = AsyncAzureOpenAI(
api_key="my-test-key",
base_url="my-test-base",
api_version="my-test-version",
)
mock_client = AsyncMock(return_value="Hello world!")
## CHECK IF CLIENT IS USED (NO PARAM CHANGE)
with patch.object(
client.chat.completions.with_raw_response, "create", new=mock_client
) as mock_client:
try:
# client.chat.completions.with_raw_response.create = mock_client
if sync_mode:
_ = completion(
model="azure/chatgpt-v2",
messages=[{"role": "user", "content": "Hello world"}],
client=client,
stream=stream,
)
else:
_ = await litellm.acompletion(
model="azure/chatgpt-v2",
messages=[{"role": "user", "content": "Hello world"}],
client=client,
stream=stream,
)
except Exception:
pass
mock_client.assert_called()
## recreate mock client
if sync_mode:
mock_client = MagicMock(return_value="Hello world!")
else:
mock_client = AsyncMock(return_value="Hello world!")
## CHECK IF NEW CLIENT IS USED (PARAM CHANGE)
with patch.object(
client.chat.completions.with_raw_response, "create", new=mock_client
) as mock_client:
try:
if sync_mode:
_ = completion(
model="azure/chatgpt-v2",
messages=[{"role": "user", "content": "Hello world"}],
client=client,
api_version="my-new-version",
stream=stream,
)
else:
_ = await litellm.acompletion(
model="azure/chatgpt-v2",
messages=[{"role": "user", "content": "Hello world"}],
client=client,
api_version="my-new-version",
stream=stream,
)
except Exception:
pass
try:
mock_client.assert_not_called()
except Exception as e:
traceback.print_stack()
raise e