mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
use get_azure_openai_client
This commit is contained in:
parent
a0c5fb81b8
commit
e34be5a3b6
1 changed files with 75 additions and 57 deletions
|
@ -76,33 +76,25 @@ class AzureTextCompletion(BaseAzureLLM):
|
||||||
model_name=model,
|
model_name=model,
|
||||||
api_version=api_version,
|
api_version=api_version,
|
||||||
api_base=api_base,
|
api_base=api_base,
|
||||||
|
is_async=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||||
### if so - set the model as part of the base url
|
### if so - set the model as part of the base url
|
||||||
if "gateway.ai.cloudflare.com" in api_base:
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
## build base url - assume api base includes resource name
|
## build base url - assume api base includes resource name
|
||||||
if client is None:
|
client = self._init_azure_client_for_cloudflare_ai_gateway(
|
||||||
if not api_base.endswith("/"):
|
api_key=api_key,
|
||||||
api_base += "/"
|
api_version=api_version,
|
||||||
api_base += f"{model}"
|
api_base=api_base,
|
||||||
|
model=model,
|
||||||
azure_client_params = {
|
client=client,
|
||||||
"api_version": api_version,
|
max_retries=max_retries,
|
||||||
"base_url": f"{api_base}",
|
timeout=timeout,
|
||||||
"http_client": litellm.client_session,
|
azure_ad_token=azure_ad_token,
|
||||||
"max_retries": max_retries,
|
azure_ad_token_provider=azure_ad_token_provider,
|
||||||
"timeout": timeout,
|
acompletion=acompletion,
|
||||||
}
|
)
|
||||||
if api_key is not None:
|
|
||||||
azure_client_params["api_key"] = api_key
|
|
||||||
elif azure_ad_token is not None:
|
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
|
||||||
|
|
||||||
if acompletion is True:
|
|
||||||
client = AsyncAzureOpenAI(**azure_client_params)
|
|
||||||
else:
|
|
||||||
client = AzureOpenAI(**azure_client_params)
|
|
||||||
|
|
||||||
data = {"model": None, "prompt": prompt, **optional_params}
|
data = {"model": None, "prompt": prompt, **optional_params}
|
||||||
else:
|
else:
|
||||||
|
@ -174,17 +166,21 @@ class AzureTextCompletion(BaseAzureLLM):
|
||||||
status_code=422, message="max retries must be an int"
|
status_code=422, message="max retries must be an int"
|
||||||
)
|
)
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
if client is None:
|
azure_client = self.get_azure_openai_client(
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
api_key=api_key,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_version=api_version,
|
||||||
if api_version is not None and isinstance(
|
client=client,
|
||||||
azure_client._custom_query, dict
|
litellm_params=litellm_params,
|
||||||
):
|
_is_async=False,
|
||||||
# set api_version to version passed by user
|
model=model,
|
||||||
azure_client._custom_query.setdefault(
|
)
|
||||||
"api-version", api_version
|
|
||||||
)
|
if not isinstance(azure_client, AzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AzureOpenAI",
|
||||||
|
)
|
||||||
|
|
||||||
raw_response = azure_client.completions.with_raw_response.create(
|
raw_response = azure_client.completions.with_raw_response.create(
|
||||||
**data, timeout=timeout
|
**data, timeout=timeout
|
||||||
|
@ -234,20 +230,27 @@ class AzureTextCompletion(BaseAzureLLM):
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
client=None, # this is the AsyncAzureOpenAI
|
client=None, # this is the AsyncAzureOpenAI
|
||||||
azure_client_params: dict = {},
|
azure_client_params: dict = {},
|
||||||
|
litellm_params: dict = {},
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
# setting Azure client
|
# setting Azure client
|
||||||
if client is None:
|
azure_client = self.get_azure_openai_client(
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_key=api_key,
|
||||||
if api_version is not None and isinstance(
|
model=model,
|
||||||
azure_client._custom_query, dict
|
_is_async=True,
|
||||||
):
|
client=client,
|
||||||
# set api_version to version passed by user
|
litellm_params=litellm_params,
|
||||||
azure_client._custom_query.setdefault("api-version", api_version)
|
)
|
||||||
|
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AsyncAzureOpenAI",
|
||||||
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=data["prompt"],
|
input=data["prompt"],
|
||||||
|
@ -291,6 +294,7 @@ class AzureTextCompletion(BaseAzureLLM):
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
client=None,
|
client=None,
|
||||||
azure_client_params: dict = {},
|
azure_client_params: dict = {},
|
||||||
|
litellm_params: dict = {},
|
||||||
):
|
):
|
||||||
max_retries = data.pop("max_retries", 2)
|
max_retries = data.pop("max_retries", 2)
|
||||||
if not isinstance(max_retries, int):
|
if not isinstance(max_retries, int):
|
||||||
|
@ -298,13 +302,21 @@ class AzureTextCompletion(BaseAzureLLM):
|
||||||
status_code=422, message="max retries must be an int"
|
status_code=422, message="max retries must be an int"
|
||||||
)
|
)
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
if client is None:
|
azure_client = self.get_azure_openai_client(
|
||||||
azure_client = AzureOpenAI(**azure_client_params)
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_key=api_key,
|
||||||
if api_version is not None and isinstance(azure_client._custom_query, dict):
|
model=model,
|
||||||
# set api_version to version passed by user
|
_is_async=False,
|
||||||
azure_client._custom_query.setdefault("api-version", api_version)
|
client=client,
|
||||||
|
litellm_params=litellm_params,
|
||||||
|
)
|
||||||
|
if not isinstance(azure_client, AzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AzureOpenAI",
|
||||||
|
)
|
||||||
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=data["prompt"],
|
input=data["prompt"],
|
||||||
|
@ -340,18 +352,24 @@ class AzureTextCompletion(BaseAzureLLM):
|
||||||
azure_ad_token: Optional[str] = None,
|
azure_ad_token: Optional[str] = None,
|
||||||
client=None,
|
client=None,
|
||||||
azure_client_params: dict = {},
|
azure_client_params: dict = {},
|
||||||
|
litellm_params: dict = {},
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
# init AzureOpenAI Client
|
# init AzureOpenAI Client
|
||||||
if client is None:
|
azure_client = self.get_azure_openai_client(
|
||||||
azure_client = AsyncAzureOpenAI(**azure_client_params)
|
api_version=api_version,
|
||||||
else:
|
api_base=api_base,
|
||||||
azure_client = client
|
api_key=api_key,
|
||||||
if api_version is not None and isinstance(
|
model=model,
|
||||||
azure_client._custom_query, dict
|
_is_async=True,
|
||||||
):
|
client=client,
|
||||||
# set api_version to version passed by user
|
litellm_params=litellm_params,
|
||||||
azure_client._custom_query.setdefault("api-version", api_version)
|
)
|
||||||
|
if not isinstance(azure_client, AsyncAzureOpenAI):
|
||||||
|
raise AzureOpenAIError(
|
||||||
|
status_code=500,
|
||||||
|
message="azure_client is not an instance of AsyncAzureOpenAI",
|
||||||
|
)
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=data["prompt"],
|
input=data["prompt"],
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue