refactor(azure.py): working client init logic in azure image generation

This commit is contained in:
Krrish Dholakia 2025-03-11 14:22:25 -07:00
parent 152bc67d22
commit 2c2404dac9
2 changed files with 10 additions and 18 deletions

View file

@ -1152,6 +1152,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
azure_ad_token_provider: Optional[Callable] = None,
client=None,
aimg_generation=None,
litellm_params: Optional[dict] = None,
) -> ImageResponse:
try:
if model and len(model) > 0:
@ -1176,25 +1177,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
)
# init AzureOpenAI Client
azure_client_params: Dict[str, Any] = {
"api_version": api_version,
"azure_endpoint": api_base,
"azure_deployment": model,
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
litellm_params=litellm_params or {},
api_key=api_key,
model_name=model or "",
api_version=api_version,
api_base=api_base,
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if aimg_generation is True:
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore

View file

@ -4544,6 +4544,8 @@ def image_generation( # noqa: PLR0915
**non_default_params,
)
litellm_params_dict = get_litellm_params(**kwargs)
logging: Logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
@ -4614,6 +4616,7 @@ def image_generation( # noqa: PLR0915
aimg_generation=aimg_generation,
client=client,
headers=headers,
litellm_params=litellm_params_dict,
)
elif (
custom_llm_provider == "openai"