mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
refactor(azure.py): working client init logic in azure image generation
This commit is contained in:
parent
152bc67d22
commit
2c2404dac9
2 changed files with 10 additions and 18 deletions
|
@ -1152,6 +1152,7 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
azure_ad_token_provider: Optional[Callable] = None,
|
azure_ad_token_provider: Optional[Callable] = None,
|
||||||
client=None,
|
client=None,
|
||||||
aimg_generation=None,
|
aimg_generation=None,
|
||||||
|
litellm_params: Optional[dict] = None,
|
||||||
) -> ImageResponse:
|
) -> ImageResponse:
|
||||||
try:
|
try:
|
||||||
if model and len(model) > 0:
|
if model and len(model) > 0:
|
||||||
|
@ -1176,25 +1177,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||||
)
|
)
|
||||||
|
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
azure_client_params: Dict[str, Any] = {
|
azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
|
||||||
"api_version": api_version,
|
litellm_params=litellm_params or {},
|
||||||
"azure_endpoint": api_base,
|
api_key=api_key,
|
||||||
"azure_deployment": model,
|
model_name=model or "",
|
||||||
"max_retries": max_retries,
|
api_version=api_version,
|
||||||
"timeout": timeout,
|
api_base=api_base,
|
||||||
}
|
|
||||||
azure_client_params = select_azure_base_url_or_endpoint(
|
|
||||||
azure_client_params=azure_client_params
|
|
||||||
)
|
)
|
||||||
if api_key is not None:
|
|
||||||
azure_client_params["api_key"] = api_key
|
|
||||||
elif azure_ad_token is not None:
|
|
||||||
if azure_ad_token.startswith("oidc/"):
|
|
||||||
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
|
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
|
||||||
elif azure_ad_token_provider is not None:
|
|
||||||
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
|
|
||||||
|
|
||||||
if aimg_generation is True:
|
if aimg_generation is True:
|
||||||
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
|
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -4544,6 +4544,8 @@ def image_generation( # noqa: PLR0915
|
||||||
**non_default_params,
|
**non_default_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
litellm_params_dict = get_litellm_params(**kwargs)
|
||||||
|
|
||||||
logging: Logging = litellm_logging_obj
|
logging: Logging = litellm_logging_obj
|
||||||
logging.update_environment_variables(
|
logging.update_environment_variables(
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -4614,6 +4616,7 @@ def image_generation( # noqa: PLR0915
|
||||||
aimg_generation=aimg_generation,
|
aimg_generation=aimg_generation,
|
||||||
client=client,
|
client=client,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
|
litellm_params=litellm_params_dict,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
custom_llm_provider == "openai"
|
custom_llm_provider == "openai"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue