feat(azure.py): support dynamic api versions

Closes https://github.com/BerriAI/litellm/issues/5228

parent 417547b6f9
commit 49416e121c

5 changed files with 176 additions and 32 deletions
@@ -403,6 +403,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
     return azure_ad_token_access_token
 
 
+def _check_dynamic_azure_params(
+    azure_client_params: dict,
+    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
+) -> bool:
+    """
+    Returns True if the user-passed client params != the initialized azure client.
+
+    Currently only implemented for api version.
+    """
+    if azure_client is None:
+        return True
+
+    dynamic_params = ["api_version"]
+    for k, v in azure_client_params.items():
+        if k in dynamic_params and k == "api_version":
+            if v is not None and v != azure_client._custom_query["api-version"]:
+                return True
+
+    return False
+
+
 class AzureChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
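The check keys off the api version that the OpenAI SDK stores on the Azure client as a default query param (`_custom_query["api-version"]`, the same private attribute the new function reads). A minimal sketch of the intended behavior, with hypothetical credentials and endpoint:

from openai import AzureOpenAI

from litellm.llms.azure import _check_dynamic_azure_params

# hypothetical key/endpoint, for illustration only
client = AzureOpenAI(
    api_key="sk-test",
    azure_endpoint="https://example.openai.azure.com",
    api_version="2023-07-01-preview",
)

# the SDK records the api version as a default query param
assert client._custom_query["api-version"] == "2023-07-01-preview"

# same version -> the existing client can be reused
assert not _check_dynamic_azure_params(
    azure_client_params={"api_version": "2023-07-01-preview"}, azure_client=client
)

# different version -> signal that a new client is needed
assert _check_dynamic_azure_params(
    azure_client_params={"api_version": "2024-02-01"}, azure_client=client
)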
@@ -462,6 +483,28 @@ class AzureChatCompletion(BaseLLM):
 
         return azure_client
 
+    def make_sync_azure_openai_chat_completion_request(
+        self,
+        azure_client: AzureOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call chat.completions.with_raw_response.create when litellm.return_response_headers is True
+        - call chat.completions.create by default
+        """
+        try:
+            raw_response = azure_client.chat.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
+        except Exception as e:
+            raise e
+
     async def make_azure_openai_chat_completion_request(
         self,
         azure_client: AsyncAzureOpenAI,
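Going through with_raw_response keeps the provider's response headers (e.g. rate-limit headers) available alongside the parsed completion. A hedged sketch of consuming the helper's return values, assuming a configured client and a hypothetical deployment name:

from openai import AzureOpenAI

from litellm.llms.azure import AzureChatCompletion

# hypothetical client and deployment, for illustration only
azure_client = AzureOpenAI(
    api_key="sk-test",
    azure_endpoint="https://example.openai.azure.com",
    api_version="2024-02-01",
)
headers, response = AzureChatCompletion().make_sync_azure_openai_chat_completion_request(
    azure_client=azure_client,
    data={
        "model": "my-deployment",
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=30.0,
)
print(headers.get("x-ratelimit-remaining-requests"))  # provider header, if present
print(response.choices[0].message.content)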
@@ -494,6 +537,7 @@ class AzureChatCompletion(BaseLLM):
         api_version: str,
         api_type: str,
         azure_ad_token: str,
+        dynamic_params: bool,
         print_verbose: Callable,
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,
@@ -558,6 +602,7 @@ class AzureChatCompletion(BaseLLM):
                 return self.async_streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    dynamic_params=dynamic_params,
                     data=data,
                     model=model,
                     api_key=api_key,
@@ -575,6 +620,7 @@ class AzureChatCompletion(BaseLLM):
                     api_version=api_version,
                     model=model,
                     azure_ad_token=azure_ad_token,
+                    dynamic_params=dynamic_params,
                     timeout=timeout,
                     client=client,
                     logging_obj=logging_obj,
@@ -583,6 +629,7 @@ class AzureChatCompletion(BaseLLM):
                 return self.streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    dynamic_params=dynamic_params,
                     data=data,
                     model=model,
                     api_key=api_key,
@@ -628,7 +675,8 @@ class AzureChatCompletion(BaseLLM):
                 if azure_ad_token.startswith("oidc/"):
                     azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                     azure_client_params["azure_ad_token"] = azure_ad_token
-                if client is None:
+
+                if client is None or dynamic_params:
                     azure_client = AzureOpenAI(**azure_client_params)
                 else:
                     azure_client = client
@@ -640,7 +688,9 @@ class AzureChatCompletion(BaseLLM):
                             "api-version", api_version
                         )
 
-                response = azure_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
+                headers, response = self.make_sync_azure_openai_chat_completion_request(
+                    azure_client=azure_client, data=data, timeout=timeout
+                )
                 stringified_response = response.model_dump()
                 ## LOGGING
                 logging_obj.post_call(
@@ -674,6 +724,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         data: dict,
         timeout: Any,
+        dynamic_params: bool,
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
         azure_ad_token: Optional[str] = None,
@@ -707,15 +758,11 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["azure_ad_token"] = azure_ad_token
 
             # setting Azure client
-            if client is None:
+            if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
             else:
                 azure_client = client
-                if api_version is not None and isinstance(
-                    azure_client._custom_query, dict
-                ):
-                    # set api_version to version passed by user
-                    azure_client._custom_query.setdefault("api-version", api_version)
 
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
@@ -786,6 +833,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         api_key: str,
         api_version: str,
+        dynamic_params: bool,
         data: dict,
         model: str,
         timeout: Any,
@@ -815,13 +863,11 @@ class AzureChatCompletion(BaseLLM):
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
-        if client is None:
+
+        if client is None or dynamic_params:
             azure_client = AzureOpenAI(**azure_client_params)
         else:
             azure_client = client
-            if api_version is not None and isinstance(azure_client._custom_query, dict):
-                # set api_version to version passed by user
-                azure_client._custom_query.setdefault("api-version", api_version)
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
@@ -833,7 +879,9 @@ class AzureChatCompletion(BaseLLM):
                 "complete_input_dict": data,
             },
         )
-        response = azure_client.chat.completions.create(**data, timeout=timeout)
+        headers, response = self.make_sync_azure_openai_chat_completion_request(
+            azure_client=azure_client, data=data, timeout=timeout
+        )
         streamwrapper = CustomStreamWrapper(
             completion_stream=response,
             model=model,
@@ -848,6 +896,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         api_key: str,
         api_version: str,
+        dynamic_params: bool,
         data: dict,
         model: str,
         timeout: Any,
@@ -873,15 +922,10 @@ class AzureChatCompletion(BaseLLM):
             if azure_ad_token.startswith("oidc/"):
                 azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
-        if client is None:
+        if client is None or dynamic_params:
             azure_client = AsyncAzureOpenAI(**azure_client_params)
         else:
             azure_client = client
-            if api_version is not None and isinstance(
-                azure_client._custom_query, dict
-            ):
-                # set api_version to version passed by user
-                azure_client._custom_query.setdefault("api-version", api_version)
         ## LOGGING
         logging_obj.pre_call(
             input=data["messages"],
@@ -104,7 +104,7 @@ from .llms import (
 )
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
-from .llms.azure import AzureChatCompletion
+from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -967,6 +967,17 @@ def completion(
 
     if custom_llm_provider == "azure":
         # azure configs
+        ## check dynamic params ##
+        dynamic_params = False
+        if client is not None and (
+            isinstance(client, openai.AzureOpenAI)
+            or isinstance(client, openai.AsyncAzureOpenAI)
+        ):
+            dynamic_params = _check_dynamic_azure_params(
+                azure_client_params={"api_version": api_version},
+                azure_client=client,
+            )
+
         api_type = get_secret("AZURE_API_TYPE") or "azure"
 
         api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
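The net effect: a caller can hand completion() a pre-built Azure client and still override the api version per call; when the override differs from the client's initialized version, litellm builds a fresh client instead of silently reusing the stale one. A hedged usage sketch, with hypothetical endpoint, key, and deployment name:

import litellm
from openai import AzureOpenAI

# hypothetical values, for illustration only
client = AzureOpenAI(
    api_key="sk-test",
    azure_endpoint="https://example.openai.azure.com",
    api_version="2023-07-01-preview",
)

# matches the client's version -> the passed-in client is reused
litellm.completion(
    model="azure/my-deployment",
    messages=[{"role": "user", "content": "Hello"}],
    client=client,
)

# differs from the client's version -> dynamic_params=True, a new client is built
litellm.completion(
    model="azure/my-deployment",
    messages=[{"role": "user", "content": "Hello"}],
    client=client,
    api_version="2024-02-01",
)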
@@ -1006,6 +1017,7 @@ def completion(
             api_base=api_base,
             api_version=api_version,
             api_type=api_type,
+            dynamic_params=dynamic_params,
             azure_ad_token=azure_ad_token,
             model_response=model_response,
             print_verbose=print_verbose,
@@ -1,15 +1,6 @@
 model_list:
   - model_name: gpt-3.5-turbo
     litellm_params:
-      model: gpt-3.5-turbo
-
-litellm_settings:
-  cache: True # set cache responses to True
-  cache_params: # set cache params for s3
-    type: s3
-    s3_bucket_name: litellm-proxy # AWS Bucket Name for S3
-    s3_region_name: us-west-2 # AWS Region Name for S3
-    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is the AWS Access Key ID for S3
-    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
-
-
+      model: azure/chatgpt-v-2
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
@@ -136,6 +136,17 @@ async def add_litellm_data_to_request(
         "body": copy.copy(data),  # use copy instead of deepcopy
     }
 
+    ## Dynamic api version (Azure OpenAI endpoints) ##
+    query_params = request.query_params
+
+    # Convert query parameters to a dictionary (optional)
+    query_dict = dict(query_params)
+
+    ## check for api version in query params
+    dynamic_api_version: Optional[str] = query_dict.get("api-version")
+
+    data["api_version"] = dynamic_api_version
+
     ## Forward any LLM API Provider specific headers in extra_headers
     add_provider_specific_headers_to_request(data=data, headers=_headers)
 
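On the proxy side, this lets a request pin the Azure api version via the standard ?api-version= query string, which add_litellm_data_to_request copies into the request body as api_version. A hedged sketch of calling a locally running proxy (hypothetical address, virtual key, and model alias):

import httpx

# hypothetical proxy address, key, and model alias, for illustration only
resp = httpx.post(
    "http://localhost:4000/chat/completions?api-version=2024-02-01",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())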
@@ -4338,3 +4338,89 @@ def test_moderation():
     output = response.results[0]
     print(output)
     return output
+
+
+@pytest.mark.parametrize("stream", [False, True])
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_dynamic_azure_params(stream, sync_mode):
+    """
+    If dynamic params are given that differ from the initialized client, use a new client
+    """
+    from openai import AsyncAzureOpenAI, AzureOpenAI
+
+    if sync_mode:
+        client = AzureOpenAI(
+            api_key="my-test-key",
+            base_url="my-test-base",
+            api_version="my-test-version",
+        )
+        mock_client = MagicMock(return_value="Hello world!")
+    else:
+        client = AsyncAzureOpenAI(
+            api_key="my-test-key",
+            base_url="my-test-base",
+            api_version="my-test-version",
+        )
+        mock_client = AsyncMock(return_value="Hello world!")
+
+    ## CHECK IF CLIENT IS USED (NO PARAM CHANGE)
+    with patch.object(
+        client.chat.completions.with_raw_response, "create", new=mock_client
+    ) as mock_client:
+        try:
+            # client.chat.completions.with_raw_response.create = mock_client
+            if sync_mode:
+                _ = completion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    stream=stream,
+                )
+            else:
+                _ = await litellm.acompletion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    stream=stream,
+                )
+        except Exception:
+            pass
+
+        mock_client.assert_called()
+
+    ## recreate mock client
+    if sync_mode:
+        mock_client = MagicMock(return_value="Hello world!")
+    else:
+        mock_client = AsyncMock(return_value="Hello world!")
+
+    ## CHECK IF NEW CLIENT IS USED (PARAM CHANGE)
+    with patch.object(
+        client.chat.completions.with_raw_response, "create", new=mock_client
+    ) as mock_client:
+        try:
+            if sync_mode:
+                _ = completion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    api_version="my-new-version",
+                    stream=stream,
+                )
+            else:
+                _ = await litellm.acompletion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    api_version="my-new-version",
+                    stream=stream,
+                )
+        except Exception:
+            pass
+
+        try:
+            mock_client.assert_not_called()
+        except Exception as e:
+            traceback.print_stack()
+            raise e