mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix - correctly re-use azure openai client
This commit is contained in:
parent
d991b3c398
commit
8b54873e9f
2 changed files with 203 additions and 34 deletions
|
@ -149,8 +149,8 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
azure_ad_token: Optional[str],
|
azure_ad_token: Optional[str],
|
||||||
azure_ad_token_provider: Optional[Callable],
|
azure_ad_token_provider: Optional[Callable],
|
||||||
model: str,
|
model: str,
|
||||||
max_retries: int,
|
max_retries: Optional[int],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
client: Optional[Any],
|
client: Optional[Any],
|
||||||
client_type: Literal["sync", "async"],
|
client_type: Literal["sync", "async"],
|
||||||
litellm_params: Optional[dict] = None,
|
litellm_params: Optional[dict] = None,
|
||||||
|
@ -366,6 +366,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
max_retries=max_retries,
|
max_retries=max_retries,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -387,21 +388,19 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
status_code=422, message="max retries must be an int"
|
status_code=422, message="max retries must be an int"
|
||||||
)
|
)
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
if (
|
azure_client = self._get_azure_openai_client(
|
||||||
client is None
|
api_version=api_version,
|
||||||
or not isinstance(client, AzureOpenAI)
|
api_base=api_base,
|
||||||
or dynamic_params
|
api_key=api_key,
|
||||||
):
|
azure_ad_token=azure_ad_token,
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
else:
|
model=model,
|
||||||
azure_client = client
|
max_retries=max_retries,
|
||||||
if api_version is not None and isinstance(
|
timeout=timeout,
|
||||||
azure_client._custom_query, dict
|
client=client,
|
||||||
):
|
client_type="sync",
|
||||||
# set api_version to version passed by user
|
litellm_params=litellm_params,
|
||||||
azure_client._custom_query.setdefault(
|
)
|
||||||
"api-version", api_version
|
|
||||||
)
|
|
||||||
if not isinstance(azure_client, AzureOpenAI):
|
if not isinstance(azure_client, AzureOpenAI):
|
||||||
raise AzureOpenAIError(
|
raise AzureOpenAIError(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
|
@ -567,6 +566,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
azure_ad_token_provider: Optional[Callable] = None,
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
client=None,
|
client=None,
|
||||||
|
litellm_params: Optional[dict] = {},
|
||||||
):
|
):
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
azure_client_params = {
|
azure_client_params = {
|
||||||
|
@ -589,10 +589,24 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
elif azure_ad_token_provider is not None:
|
elif azure_ad_token_provider is not None:
|
||||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||||
|
|
||||||
if client is None or dynamic_params:
|
azure_client = self._get_azure_openai_client(
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
|
model=model,
|
||||||
|
max_retries=max_retries,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
client_type="sync",
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
if not isinstance(azure_client, AzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AzureOpenAI",
|
||||||
|
)
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=data["messages"],
|
input=data["messages"],
|
||||||
|
@ -638,10 +652,22 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
litellm_params: Optional[dict] = {},
|
litellm_params: Optional[dict] = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if client is None or dynamic_params:
|
azure_client = self._get_azure_openai_client(
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
|
model=model,
|
||||||
|
max_retries=max_retries,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
client_type="async",
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||||
|
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=data["messages"],
|
input=data["messages"],
|
||||||
|
@ -692,6 +718,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
|
|
||||||
async def aembedding(
|
async def aembedding(
|
||||||
self,
|
self,
|
||||||
|
model: str,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: EmbeddingResponse,
|
model_response: EmbeddingResponse,
|
||||||
azure_client_params: dict,
|
azure_client_params: dict,
|
||||||
|
@ -699,15 +726,33 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
logging_obj: LiteLLMLoggingObj,
|
logging_obj: LiteLLMLoggingObj,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
client: Optional[AsyncAzureOpenAI] = None,
|
client: Optional[AsyncAzureOpenAI] = None,
|
||||||
timeout=None,
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
|
max_retries: Optional[int] = None,
|
||||||
|
api_version: Optional[str] = None,
|
||||||
|
api_base: Optional[str] = None,
|
||||||
|
azure_ad_token: Optional[str] = None,
|
||||||
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
|
litellm_params: Optional[dict] = {},
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if client is None:
|
openai_aclient = self._get_azure_openai_client(
|
||||||
openai_aclient = AsyncAzureOpenAI(**azure_client_params)
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
openai_aclient = client
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
|
model=model,
|
||||||
|
max_retries=max_retries,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
client_type="async",
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
if not isinstance(openai_aclient, AsyncAzureOpenAI):
|
||||||
|
raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
|
||||||
|
|
||||||
raw_response = await openai_aclient.embeddings.with_raw_response.create(
|
raw_response = await openai_aclient.embeddings.with_raw_response.create(
|
||||||
**data, timeout=timeout
|
**data, timeout=timeout
|
||||||
)
|
)
|
||||||
|
@ -799,11 +844,27 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
azure_client_params=azure_client_params,
|
azure_client_params=azure_client_params,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
client=client,
|
client=client,
|
||||||
|
litellm_params=litellm_params,
|
||||||
)
|
)
|
||||||
if client is None:
|
azure_client = self._get_azure_openai_client(
|
||||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_key=api_key,
|
||||||
|
azure_ad_token=azure_ad_token,
|
||||||
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
|
model=model,
|
||||||
|
max_retries=max_retries,
|
||||||
|
timeout=timeout,
|
||||||
|
client=client,
|
||||||
|
client_type="sync",
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
if not isinstance(azure_client, AzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AzureOpenAI",
|
||||||
|
)
|
||||||
|
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
|
raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
|
||||||
headers = dict(raw_response.headers)
|
headers = dict(raw_response.headers)
|
||||||
|
|
108
tests/code_coverage_tests/azure_client_usage_test.py
Normal file
108
tests/code_coverage_tests/azure_client_usage_test.py
Normal file
|
@ -0,0 +1,108 @@
|
||||||
|
import ast
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def find_azure_files(base_dir):
    """
    Collect the paths of all Python files under ``base_dir``.

    The directory tree is walked recursively with ``os.walk`` and every file
    whose name ends in ``.py`` is kept.

    Args:
        base_dir: Directory to search.

    Returns:
        A sorted list of paths, each built as ``os.path.join(root, name)``.
        The list is empty when nothing matches (``os.walk`` also yields
        nothing for a directory it cannot list).
    """
    azure_files = []
    for root, _, names in os.walk(base_dir):
        azure_files.extend(
            os.path.join(root, name) for name in names if name.endswith(".py")
        )
    # os.walk follows the filesystem's listing order; sort so the file list
    # (and any report built from it) is the same on every machine.
    return sorted(azure_files)
|
||||||
|
|
||||||
|
|
||||||
|
def check_direct_instantiation(file_path):
    """
    Report direct ``AzureOpenAI`` / ``AsyncAzureOpenAI`` constructor calls in a file.

    Only functions defined directly in a class body are inspected. The class
    named ``BaseAzureLLM`` is skipped entirely, and so are methods named
    ``get_azure_openai_client`` or ``initialize_azure_sdk_client`` in any
    other class, since those are the places allowed to build a client.
    A call counts as an instantiation when the called expression is either
    the bare name (``AzureOpenAI(...)``) or an attribute access ending in
    that name (``openai.AzureOpenAI(...)``).

    Args:
        file_path: Path of the Python source file to analyse.

    Returns:
        A list of human-readable issue strings, one per offending call;
        empty when the file is clean.

    Raises:
        OSError: If the file cannot be opened.
        SyntaxError: If the file is not valid Python.
    """
    # NOTE(review): the exempt names carry no leading underscore; confirm they
    # match the real helper names used by the Azure handlers.
    client_factories = ("get_azure_openai_client", "initialize_azure_sdk_client")
    client_classes = ("AzureOpenAI", "AsyncAzureOpenAI")

    # Read raw bytes and let the parser decode them: it applies Python's own
    # source-encoding rules (UTF-8 by default, or a declared coding cookie)
    # instead of whatever the platform's default text encoding happens to be.
    with open(file_path, "rb") as file:
        tree = ast.parse(file.read(), filename=file_path)

    issues = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.ClassDef) or node.name == "BaseAzureLLM":
            continue

        for method in node.body:
            if not isinstance(method, (ast.FunctionDef, ast.AsyncFunctionDef)):
                continue
            if method.name in client_factories:
                continue

            for subnode in ast.walk(method):
                if not isinstance(subnode, ast.Call):
                    continue

                func = subnode.func
                if isinstance(func, ast.Name):
                    called = func.id
                elif isinstance(func, ast.Attribute):
                    called = func.attr
                else:
                    # e.g. a call on a subscript or on another call's result
                    continue

                if called in client_classes:
                    issues.append(
                        f"Direct instantiation of {called} in {node.name}.{method.name}"
                    )

    return issues
|
||||||
|
|
||||||
|
|
||||||
|
def main(base_dir=None):
    """
    Run the check over the Azure LLM sources and fail on any violation.

    Args:
        base_dir: Directory to scan. When omitted, it defaults to
            ``../../litellm/llms/azure`` resolved relative to this file's
            location, so the result no longer depends on the directory the
            script is launched from.

    Raises:
        FileNotFoundError: If the directory to scan does not exist.
        Exception: If any direct ``AzureOpenAI`` / ``AsyncAzureOpenAI``
            instantiation is found; each finding is printed first.
    """
    if base_dir is None:
        # Same relative location as before, but anchored to this file instead
        # of the current working directory.
        base_dir = os.path.normpath(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "..",
                "..",
                "litellm",
                "llms",
                "azure",
            )
        )

    # A missing directory would otherwise scan zero files and report success.
    if not os.path.isdir(base_dir):
        raise FileNotFoundError(f"Azure source directory not found: {base_dir}")

    azure_files = find_azure_files(base_dir)
    print(f"Found {len(azure_files)} Azure Python files to check")

    all_issues = []
    for file_path in azure_files:
        all_issues.extend(
            f"{file_path}: {issue}" for issue in check_direct_instantiation(file_path)
        )

    if not all_issues:
        print("All Azure modules are correctly using get_azure_openai_client!")
        return

    print("Found direct instantiations of AzureOpenAI or AsyncAzureOpenAI:")
    for issue in all_issues:
        print(f"  - {issue}")
    raise Exception(
        f"Found {len(all_issues)} direct instantiations of AzureOpenAI or AsyncAzureOpenAI classes. Use get_azure_openai_client instead."
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the check only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
|
Loading…
Add table
Add a link
Reference in a new issue