mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
refactor(azure.py): refactor to have client init work across all endpoints
This commit is contained in:
parent
1516240bab
commit
2f262ed9b4
10 changed files with 296 additions and 129 deletions
|
@ -35,40 +35,6 @@ class AzureOpenAIError(BaseLLMException):
|
|||
)
|
||||
|
||||
|
||||
def get_azure_openai_client(
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
organization: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
received_args = locals()
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client" or k == "_is_async":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["azure_endpoint"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
if "api_version" not in data:
|
||||
data["api_version"] = litellm.AZURE_DEFAULT_API_VERSION
|
||||
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**data)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
|
||||
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
|
||||
openai_headers = {}
|
||||
if "x-ratelimit-limit-requests" in headers:
|
||||
|
@ -277,6 +243,33 @@ def select_azure_base_url_or_endpoint(azure_client_params: dict):
|
|||
|
||||
|
||||
class BaseAzureLLM:
|
||||
def get_azure_openai_client(
|
||||
self,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
|
||||
_is_async: bool = False,
|
||||
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]:
|
||||
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None
|
||||
if client is None:
|
||||
azure_client_params = self.initialize_azure_sdk_client(
|
||||
litellm_params=litellm_params,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_name="",
|
||||
api_version=api_version,
|
||||
)
|
||||
if _is_async is True:
|
||||
openai_client = AsyncAzureOpenAI(**azure_client_params)
|
||||
else:
|
||||
openai_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
def initialize_azure_sdk_client(
|
||||
self,
|
||||
litellm_params: dict,
|
||||
|
@ -294,6 +287,8 @@ class BaseAzureLLM:
|
|||
client_secret = litellm_params.get("client_secret")
|
||||
azure_username = litellm_params.get("azure_username")
|
||||
azure_password = litellm_params.get("azure_password")
|
||||
max_retries = litellm_params.get("max_retries")
|
||||
timeout = litellm_params.get("timeout")
|
||||
if not api_key and tenant_id and client_id and client_secret:
|
||||
verbose_logger.debug("Using Azure AD Token Provider for Azure Auth")
|
||||
azure_ad_token_provider = get_azure_ad_token_from_entrata_id(
|
||||
|
@ -338,6 +333,10 @@ class BaseAzureLLM:
|
|||
"azure_ad_token": azure_ad_token,
|
||||
"azure_ad_token_provider": azure_ad_token_provider,
|
||||
}
|
||||
if max_retries is not None:
|
||||
azure_client_params["max_retries"] = max_retries
|
||||
if timeout is not None:
|
||||
azure_client_params["timeout"] = timeout
|
||||
|
||||
if azure_ad_token_provider is not None:
|
||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue