mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix(router.py): support cloudflare ai gateway for azure models on router
This commit is contained in:
parent
936c27c9ee
commit
032f71adb2
3 changed files with 50 additions and 32 deletions
|
@ -118,28 +118,29 @@ class AzureChatCompletion(BaseLLM):
|
||||||
|
|
||||||
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
### CHECK IF CLOUDFLARE AI GATEWAY ###
|
||||||
### if so - set the model as part of the base url
|
### if so - set the model as part of the base url
|
||||||
if "gateway.ai.cloudflare.com" in api_base and client is None:
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
## build base url - assume api base includes resource name
|
## build base url - assume api base includes resource name
|
||||||
if not api_base.endswith("/"):
|
if client is None:
|
||||||
api_base += "/"
|
if not api_base.endswith("/"):
|
||||||
api_base += f"{model}"
|
api_base += "/"
|
||||||
|
api_base += f"{model}"
|
||||||
|
|
||||||
azure_client_params = {
|
azure_client_params = {
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"base_url": f"{api_base}",
|
"base_url": f"{api_base}",
|
||||||
"http_client": litellm.client_session,
|
"http_client": litellm.client_session,
|
||||||
"max_retries": max_retries,
|
"max_retries": max_retries,
|
||||||
"timeout": timeout
|
"timeout": timeout
|
||||||
}
|
}
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
|
|
||||||
if acompletion is True:
|
if acompletion is True:
|
||||||
client = AsyncAzureOpenAI(**azure_client_params)
|
client = AsyncAzureOpenAI(**azure_client_params)
|
||||||
else:
|
else:
|
||||||
client = AzureOpenAI(**azure_client_params)
|
client = AzureOpenAI(**azure_client_params)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": None,
|
"model": None,
|
||||||
|
@ -162,7 +163,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
"azure_ad_token": azure_ad_token
|
"azure_ad_token": azure_ad_token
|
||||||
},
|
},
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"api_base": api_base,
|
"api_base": client.base_url,
|
||||||
"complete_input_dict": data,
|
"complete_input_dict": data,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -856,16 +856,32 @@ class Router:
|
||||||
if "azure" in model_name:
|
if "azure" in model_name:
|
||||||
if api_version is None:
|
if api_version is None:
|
||||||
api_version = "2023-07-01-preview"
|
api_version = "2023-07-01-preview"
|
||||||
model["async_client"] = openai.AsyncAzureOpenAI(
|
if "gateway.ai.cloudflare.com" in api_base:
|
||||||
api_key=api_key,
|
if not api_base.endswith("/"):
|
||||||
azure_endpoint=api_base,
|
api_base += "/"
|
||||||
api_version=api_version
|
azure_model = model_name.replace("azure/", "")
|
||||||
)
|
api_base += f"{azure_model}"
|
||||||
model["client"] = openai.AzureOpenAI(
|
model["async_client"] = openai.AsyncAzureOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
azure_endpoint=api_base,
|
base_url=api_base,
|
||||||
api_version=api_version
|
api_version=api_version
|
||||||
)
|
)
|
||||||
|
model["client"] = openai.AzureOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=api_base,
|
||||||
|
api_version=api_version
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model["async_client"] = openai.AsyncAzureOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
azure_endpoint=api_base,
|
||||||
|
api_version=api_version
|
||||||
|
)
|
||||||
|
model["client"] = openai.AzureOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
azure_endpoint=api_base,
|
||||||
|
api_version=api_version
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model["async_client"] = openai.AsyncOpenAI(
|
model["async_client"] = openai.AsyncOpenAI(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
|
|
|
@ -65,7 +65,7 @@ def test_async_response_azure():
|
||||||
user_message = "What do you know?"
|
user_message = "What do you know?"
|
||||||
messages = [{"content": user_message, "role": "user"}]
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
try:
|
try:
|
||||||
response = await acompletion(model="azure/chatgpt-v-2", messages=messages, timeout=5)
|
response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY"))
|
||||||
print(f"response: {response}")
|
print(f"response: {response}")
|
||||||
except litellm.Timeout as e:
|
except litellm.Timeout as e:
|
||||||
pass
|
pass
|
||||||
|
@ -76,6 +76,7 @@ def test_async_response_azure():
|
||||||
|
|
||||||
# test_async_response_azure()
|
# test_async_response_azure()
|
||||||
|
|
||||||
|
|
||||||
def test_async_anyscale_response():
|
def test_async_anyscale_response():
|
||||||
import asyncio
|
import asyncio
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue