use get_azure_openai_client

This commit is contained in:
Ishaan Jaff 2025-03-18 10:28:39 -07:00
parent a0c5fb81b8
commit e34be5a3b6

View file

@ -76,33 +76,25 @@ class AzureTextCompletion(BaseAzureLLM):
model_name=model, model_name=model,
api_version=api_version, api_version=api_version,
api_base=api_base, api_base=api_base,
is_async=False,
) )
### CHECK IF CLOUDFLARE AI GATEWAY ### ### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url ### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name ## build base url - assume api base includes resource name
if client is None: client = self._init_azure_client_for_cloudflare_ai_gateway(
if not api_base.endswith("/"): api_key=api_key,
api_base += "/" api_version=api_version,
api_base += f"{model}" api_base=api_base,
model=model,
azure_client_params = { client=client,
"api_version": api_version, max_retries=max_retries,
"base_url": f"{api_base}", timeout=timeout,
"http_client": litellm.client_session, azure_ad_token=azure_ad_token,
"max_retries": max_retries, azure_ad_token_provider=azure_ad_token_provider,
"timeout": timeout, acompletion=acompletion,
} )
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
data = {"model": None, "prompt": prompt, **optional_params} data = {"model": None, "prompt": prompt, **optional_params}
else: else:
@ -174,17 +166,21 @@ class AzureTextCompletion(BaseAzureLLM):
status_code=422, message="max retries must be an int" status_code=422, message="max retries must be an int"
) )
# init AzureOpenAI Client # init AzureOpenAI Client
if client is None: azure_client = self.get_azure_openai_client(
azure_client = AzureOpenAI(**azure_client_params) api_key=api_key,
else: api_base=api_base,
azure_client = client api_version=api_version,
if api_version is not None and isinstance( client=client,
azure_client._custom_query, dict litellm_params=litellm_params,
): _is_async=False,
# set api_version to version passed by user model=model,
azure_client._custom_query.setdefault( )
"api-version", api_version
) if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
raw_response = azure_client.completions.with_raw_response.create( raw_response = azure_client.completions.with_raw_response.create(
**data, timeout=timeout **data, timeout=timeout
@ -234,20 +230,27 @@ class AzureTextCompletion(BaseAzureLLM):
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, # this is the AsyncAzureOpenAI client=None, # this is the AsyncAzureOpenAI
azure_client_params: dict = {}, azure_client_params: dict = {},
litellm_params: dict = {},
): ):
response = None response = None
try: try:
# init AzureOpenAI Client # init AzureOpenAI Client
# setting Azure client # setting Azure client
if client is None: azure_client = self.get_azure_openai_client(
azure_client = AsyncAzureOpenAI(**azure_client_params) api_version=api_version,
else: api_base=api_base,
azure_client = client api_key=api_key,
if api_version is not None and isinstance( model=model,
azure_client._custom_query, dict _is_async=True,
): client=client,
# set api_version to version passed by user litellm_params=litellm_params,
azure_client._custom_query.setdefault("api-version", api_version) )
if not isinstance(azure_client, AsyncAzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AsyncAzureOpenAI",
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["prompt"], input=data["prompt"],
@ -291,6 +294,7 @@ class AzureTextCompletion(BaseAzureLLM):
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, client=None,
azure_client_params: dict = {}, azure_client_params: dict = {},
litellm_params: dict = {},
): ):
max_retries = data.pop("max_retries", 2) max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
@ -298,13 +302,21 @@ class AzureTextCompletion(BaseAzureLLM):
status_code=422, message="max retries must be an int" status_code=422, message="max retries must be an int"
) )
# init AzureOpenAI Client # init AzureOpenAI Client
if client is None: azure_client = self.get_azure_openai_client(
azure_client = AzureOpenAI(**azure_client_params) api_version=api_version,
else: api_base=api_base,
azure_client = client api_key=api_key,
if api_version is not None and isinstance(azure_client._custom_query, dict): model=model,
# set api_version to version passed by user _is_async=False,
azure_client._custom_query.setdefault("api-version", api_version) client=client,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["prompt"], input=data["prompt"],
@ -340,18 +352,24 @@ class AzureTextCompletion(BaseAzureLLM):
azure_ad_token: Optional[str] = None, azure_ad_token: Optional[str] = None,
client=None, client=None,
azure_client_params: dict = {}, azure_client_params: dict = {},
litellm_params: dict = {},
): ):
try: try:
# init AzureOpenAI Client # init AzureOpenAI Client
if client is None: azure_client = self.get_azure_openai_client(
azure_client = AsyncAzureOpenAI(**azure_client_params) api_version=api_version,
else: api_base=api_base,
azure_client = client api_key=api_key,
if api_version is not None and isinstance( model=model,
azure_client._custom_query, dict _is_async=True,
): client=client,
# set api_version to version passed by user litellm_params=litellm_params,
azure_client._custom_query.setdefault("api-version", api_version) )
if not isinstance(azure_client, AsyncAzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AsyncAzureOpenAI",
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=data["prompt"], input=data["prompt"],