Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

feat(azure.py): support dynamic api versions

Closes https://github.com/BerriAI/litellm/issues/5228

parent 417547b6f9
commit 49416e121c

5 changed files with 176 additions and 32 deletions
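A hedged usage sketch of what this change enables: with litellm's `completion()`, a caller can now pass an `api_version` that differs from an already-initialized Azure client, and litellm detects the mismatch and constructs a fresh client rather than silently reusing the stale one. The endpoint, key, deployment name, and version strings below are illustrative placeholders, not values from this diff.

```python
from openai import AzureOpenAI

import litellm

# Pre-initialized client pinned to one API version (placeholder credentials).
client = AzureOpenAI(
    api_key="my-azure-key",
    azure_endpoint="https://my-endpoint.openai.azure.com",
    api_version="2023-07-01-preview",
)

# Passing a different api_version at call time now triggers the dynamic-params
# check below, so a new client is built for this request.
response = litellm.completion(
    model="azure/my-deployment",
    messages=[{"role": "user", "content": "Hello world"}],
    client=client,
    api_version="2024-02-01",
)
```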
litellm/llms/azure.py

@@ -403,6 +403,27 @@ def get_azure_ad_token_from_oidc(azure_ad_token: str):
     return azure_ad_token_access_token
 
 
+def _check_dynamic_azure_params(
+    azure_client_params: dict,
+    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
+) -> bool:
+    """
+    Returns True if user passed in client params != initialized azure client
+
+    Currently only implemented for api version
+    """
+    if azure_client is None:
+        return True
+
+    dynamic_params = ["api_version"]
+    for k, v in azure_client_params.items():
+        if k in dynamic_params and k == "api_version":
+            if v is not None and v != azure_client._custom_query["api-version"]:
+                return True
+
+    return False
+
+
 class AzureChatCompletion(BaseLLM):
     def __init__(self) -> None:
         super().__init__()
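A minimal sketch of the new helper's decision rule, assuming `_check_dynamic_azure_params` is importable from `litellm.llms.azure` (the `litellm/main.py` import later in this diff suggests exactly that). Constructing an `AzureOpenAI` client makes no network call, so throwaway credentials are safe; the SDK stores the api version under `"api-version"` in the client's `_custom_query` dict, which is what the helper compares against.

```python
from openai import AzureOpenAI

from litellm.llms.azure import _check_dynamic_azure_params

client = AzureOpenAI(
    api_key="test-key", base_url="https://example.test", api_version="2024-02-01"
)

# No initialized client at all -> always treat params as dynamic.
assert _check_dynamic_azure_params({"api_version": "2024-02-01"}, azure_client=None)

# api_version matches the client's "api-version" query param -> reuse the client.
assert not _check_dynamic_azure_params(
    {"api_version": "2024-02-01"}, azure_client=client
)

# api_version differs -> the caller should build a fresh client.
assert _check_dynamic_azure_params({"api_version": "2024-06-01"}, azure_client=client)
```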
@@ -462,6 +483,28 @@ class AzureChatCompletion(BaseLLM):
         return azure_client
 
+    def make_sync_azure_openai_chat_completion_request(
+        self,
+        azure_client: AzureOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+        - call chat.completions.create by default
+        """
+        try:
+            raw_response = azure_client.chat.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
+        except Exception as e:
+            raise e
+
     async def make_azure_openai_chat_completion_request(
         self,
         azure_client: AsyncAzureOpenAI,
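For context, a standalone sketch of the `with_raw_response` pattern this helper wraps. Note the docstring says the raw-response path applies when `litellm.return_response_headers` is True, while the body shown in this version always goes through it and returns `(headers, response)`. Endpoint, key, and deployment below are placeholders, not from the diff.

```python
from openai import AzureOpenAI

client = AzureOpenAI(
    api_key="my-azure-key",
    azure_endpoint="https://my-endpoint.openai.azure.com",
    api_version="2024-02-01",
)

# with_raw_response exposes the HTTP layer: headers (rate limits, request ids)
# plus .parse() to recover the usual typed ChatCompletion object.
raw_response = client.chat.completions.with_raw_response.create(
    model="my-deployment",
    messages=[{"role": "user", "content": "Hi"}],
)
headers = dict(raw_response.headers)  # e.g. "x-ratelimit-remaining-requests"
response = raw_response.parse()
```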
@@ -494,6 +537,7 @@ class AzureChatCompletion(BaseLLM):
         api_version: str,
         api_type: str,
         azure_ad_token: str,
+        dynamic_params: bool,
         print_verbose: Callable,
         timeout: Union[float, httpx.Timeout],
         logging_obj: LiteLLMLoggingObj,

@@ -558,6 +602,7 @@ class AzureChatCompletion(BaseLLM):
             return self.async_streaming(
                 logging_obj=logging_obj,
                 api_base=api_base,
+                dynamic_params=dynamic_params,
                 data=data,
                 model=model,
                 api_key=api_key,

@@ -575,6 +620,7 @@ class AzureChatCompletion(BaseLLM):
                     api_version=api_version,
                     model=model,
                     azure_ad_token=azure_ad_token,
+                    dynamic_params=dynamic_params,
                     timeout=timeout,
                     client=client,
                     logging_obj=logging_obj,

@@ -583,6 +629,7 @@ class AzureChatCompletion(BaseLLM):
                 return self.streaming(
                     logging_obj=logging_obj,
                     api_base=api_base,
+                    dynamic_params=dynamic_params,
                     data=data,
                     model=model,
                     api_key=api_key,

@@ -628,7 +675,8 @@ class AzureChatCompletion(BaseLLM):
                 if azure_ad_token.startswith("oidc/"):
                     azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                     azure_client_params["azure_ad_token"] = azure_ad_token
-                if client is None:
+
+                if client is None or dynamic_params:
                     azure_client = AzureOpenAI(**azure_client_params)
                 else:
                     azure_client = client

@@ -640,7 +688,9 @@ class AzureChatCompletion(BaseLLM):
                         "api-version", api_version
                     )
 
-                response = azure_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
+                headers, response = self.make_sync_azure_openai_chat_completion_request(
+                    azure_client=azure_client, data=data, timeout=timeout
+                )
                 stringified_response = response.model_dump()
                 ## LOGGING
                 logging_obj.post_call(

@@ -674,6 +724,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         data: dict,
         timeout: Any,
+        dynamic_params: bool,
         model_response: ModelResponse,
         logging_obj: LiteLLMLoggingObj,
         azure_ad_token: Optional[str] = None,

@@ -707,15 +758,11 @@ class AzureChatCompletion(BaseLLM):
                 azure_client_params["azure_ad_token"] = azure_ad_token
 
             # setting Azure client
-            if client is None:
+            if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
             else:
                 azure_client = client
-                if api_version is not None and isinstance(
-                    azure_client._custom_query, dict
-                ):
-                    # set api_version to version passed by user
-                    azure_client._custom_query.setdefault("api-version", api_version)
+
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],

@@ -786,6 +833,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         api_key: str,
         api_version: str,
+        dynamic_params: bool,
         data: dict,
         model: str,
         timeout: Any,

@@ -815,13 +863,11 @@ class AzureChatCompletion(BaseLLM):
                 if azure_ad_token.startswith("oidc/"):
                     azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
-            if client is None:
+
+            if client is None or dynamic_params:
                 azure_client = AzureOpenAI(**azure_client_params)
             else:
                 azure_client = client
-                if api_version is not None and isinstance(azure_client._custom_query, dict):
-                    # set api_version to version passed by user
-                    azure_client._custom_query.setdefault("api-version", api_version)
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],

@@ -833,7 +879,9 @@ class AzureChatCompletion(BaseLLM):
                     "complete_input_dict": data,
                 },
             )
-            headers, response = azure_client.chat.completions.create(**data, timeout=timeout)
+            headers, response = self.make_sync_azure_openai_chat_completion_request(
+                azure_client=azure_client, data=data, timeout=timeout
+            )
             streamwrapper = CustomStreamWrapper(
                 completion_stream=response,
                 model=model,

@@ -848,6 +896,7 @@ class AzureChatCompletion(BaseLLM):
         api_base: str,
         api_key: str,
         api_version: str,
+        dynamic_params: bool,
         data: dict,
         model: str,
         timeout: Any,

@@ -873,15 +922,10 @@ class AzureChatCompletion(BaseLLM):
                 if azure_ad_token.startswith("oidc/"):
                     azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
                 azure_client_params["azure_ad_token"] = azure_ad_token
-            if client is None:
+            if client is None or dynamic_params:
                 azure_client = AsyncAzureOpenAI(**azure_client_params)
             else:
                 azure_client = client
-                if api_version is not None and isinstance(
-                    azure_client._custom_query, dict
-                ):
-                    # set api_version to version passed by user
-                    azure_client._custom_query.setdefault("api-version", api_version)
             ## LOGGING
             logging_obj.pre_call(
                 input=data["messages"],
litellm/main.py

@@ -104,7 +104,7 @@ from .llms import (
 )
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
-from .llms.azure import AzureChatCompletion
+from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
 from .llms.azure_text import AzureTextCompletion
 from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
 from .llms.custom_llm import CustomLLM, custom_chat_llm_router
@@ -967,6 +967,17 @@ def completion(
 
     if custom_llm_provider == "azure":
         # azure configs
+        ## check dynamic params ##
+        dynamic_params = False
+        if client is not None and (
+            isinstance(client, openai.AzureOpenAI)
+            or isinstance(client, openai.AsyncAzureOpenAI)
+        ):
+            dynamic_params = _check_dynamic_azure_params(
+                azure_client_params={"api_version": api_version},
+                azure_client=client,
+            )
+
         api_type = get_secret("AZURE_API_TYPE") or "azure"
 
         api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
@@ -1006,6 +1017,7 @@ def completion(
             api_base=api_base,
             api_version=api_version,
             api_type=api_type,
+            dynamic_params=dynamic_params,
             azure_ad_token=azure_ad_token,
             model_response=model_response,
             print_verbose=print_verbose,
litellm/proxy/proxy_config.yaml

@@ -1,15 +1,6 @@
 model_list:
   - model_name: gpt-3.5-turbo
     litellm_params:
-      model: gpt-3.5-turbo
-
-litellm_settings:
-  cache: True # set cache responses to True
-  cache_params: # set cache params for s3
-    type: s3
-    s3_bucket_name: litellm-proxy # AWS Bucket Name for S3
-    s3_region_name: us-west-2 # AWS Region Name for S3
-    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
-    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
+      model: azure/chatgpt-v-2
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
litellm/proxy/litellm_pre_call_utils.py

@@ -136,6 +136,17 @@ async def add_litellm_data_to_request(
         "body": copy.copy(data),  # use copy instead of deepcopy
     }
 
+    ## Dynamic api version (Azure OpenAI endpoints) ##
+    query_params = request.query_params
+
+    # Convert query parameters to a dictionary (optional)
+    query_dict = dict(query_params)
+
+    ## check for api version in query params
+    dynamic_api_version: Optional[str] = query_dict.get("api-version")
+
+    data["api_version"] = dynamic_api_version
+
     ## Forward any LLM API Provider specific headers in extra_headers
     add_provider_specific_headers_to_request(data=data, headers=_headers)
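A hedged client-side sketch of the proxy behavior added above: an `api-version` query parameter on the incoming request is copied into the request body as `api_version`, so per-request Azure API versions flow through the proxy. The proxy URL and key below are the usual local-dev defaults, assumed here rather than taken from this diff.

```python
import httpx

# Assumes a litellm proxy running locally on the default port with a test key.
resp = httpx.post(
    "http://0.0.0.0:4000/chat/completions",
    params={"api-version": "2024-02-01"},  # picked up by add_litellm_data_to_request
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())
```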
litellm/tests/test_completion.py

@@ -4338,3 +4338,89 @@ def test_moderation():
     output = response.results[0]
     print(output)
     return output
+
+
+@pytest.mark.parametrize("stream", [False, True])
+@pytest.mark.parametrize("sync_mode", [False, True])
+@pytest.mark.asyncio
+async def test_dynamic_azure_params(stream, sync_mode):
+    """
+    If dynamic params are given, which are different from the initialized client, use a new client
+    """
+    from openai import AsyncAzureOpenAI, AzureOpenAI
+
+    if sync_mode:
+        client = AzureOpenAI(
+            api_key="my-test-key",
+            base_url="my-test-base",
+            api_version="my-test-version",
+        )
+        mock_client = MagicMock(return_value="Hello world!")
+    else:
+        client = AsyncAzureOpenAI(
+            api_key="my-test-key",
+            base_url="my-test-base",
+            api_version="my-test-version",
+        )
+        mock_client = AsyncMock(return_value="Hello world!")
+
+    ## CHECK IF CLIENT IS USED (NO PARAM CHANGE)
+    with patch.object(
+        client.chat.completions.with_raw_response, "create", new=mock_client
+    ) as mock_client:
+        try:
+            # client.chat.completions.with_raw_response.create = mock_client
+            if sync_mode:
+                _ = completion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    stream=stream,
+                )
+            else:
+                _ = await litellm.acompletion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    stream=stream,
+                )
+        except Exception:
+            pass
+
+        mock_client.assert_called()
+
+    ## recreate mock client
+    if sync_mode:
+        mock_client = MagicMock(return_value="Hello world!")
+    else:
+        mock_client = AsyncMock(return_value="Hello world!")
+
+    ## CHECK IF NEW CLIENT IS USED (PARAM CHANGE)
+    with patch.object(
+        client.chat.completions.with_raw_response, "create", new=mock_client
+    ) as mock_client:
+        try:
+            if sync_mode:
+                _ = completion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    api_version="my-new-version",
+                    stream=stream,
+                )
+            else:
+                _ = await litellm.acompletion(
+                    model="azure/chatgpt-v2",
+                    messages=[{"role": "user", "content": "Hello world"}],
+                    client=client,
+                    api_version="my-new-version",
+                    stream=stream,
+                )
+        except Exception:
+            pass
+
+        try:
+            mock_client.assert_not_called()
+        except Exception as e:
+            traceback.print_stack()
+            raise e