_init_azure_client_for_cloudflare_ai_gateway

This commit is contained in:
Ishaan Jaff 2025-03-18 10:11:54 -07:00
parent 860d96a01e
commit cb55069a70
2 changed files with 54 additions and 31 deletions

View file

@ -238,37 +238,18 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
### CHECK IF CLOUDFLARE AI GATEWAY ### ### CHECK IF CLOUDFLARE AI GATEWAY ###
### if so - set the model as part of the base url ### if so - set the model as part of the base url
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name client = self._init_azure_client_for_cloudflare_ai_gateway(
if client is None: api_base=api_base,
if not api_base.endswith("/"): model=model,
api_base += "/" api_version=api_version,
api_base += f"{model}" max_retries=max_retries,
timeout=timeout,
azure_client_params = { api_key=api_key,
"api_version": api_version, azure_ad_token=azure_ad_token,
"base_url": f"{api_base}", azure_ad_token_provider=azure_ad_token_provider,
"http_client": litellm.client_session, acompletion=acompletion,
"max_retries": max_retries, client=client,
"timeout": timeout, )
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(
azure_ad_token
)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = (
azure_ad_token_provider
)
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
data = {"model": None, "messages": messages, **optional_params} data = {"model": None, "messages": messages, **optional_params}
else: else:

View file

@ -357,3 +357,45 @@ class BaseAzureLLM:
) )
return azure_client_params return azure_client_params
def _init_azure_client_for_cloudflare_ai_gateway(
self,
api_base: str,
model: str,
api_version: str,
max_retries: int,
timeout: Union[float, httpx.Timeout],
api_key: Optional[str],
azure_ad_token: Optional[str],
azure_ad_token_provider: Optional[Callable[[], str]],
acompletion: bool,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
## build base url - assume api base includes resource name
if client is None:
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
azure_client_params = {
"api_version": api_version,
"base_url": f"{api_base}",
"http_client": litellm.client_session,
"max_retries": max_retries,
"timeout": timeout,
}
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
if azure_ad_token.startswith("oidc/"):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
elif azure_ad_token_provider is not None:
azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
if acompletion is True:
client = AsyncAzureOpenAI(**azure_client_params)
else:
client = AzureOpenAI(**azure_client_params)
return client