fix(router.py): support cloudflare ai gateway for azure models on router

Krrish Dholakia 2023-11-30 14:08:52 -08:00
parent 936c27c9ee
commit 032f71adb2
3 changed files with 50 additions and 32 deletions

@@ -118,8 +118,9 @@ class AzureChatCompletion(BaseLLM):
         ### CHECK IF CLOUDFLARE AI GATEWAY ###
         ### if so - set the model as part of the base url
-        if "gateway.ai.cloudflare.com" in api_base and client is None:
+        if "gateway.ai.cloudflare.com" in api_base:
             ## build base url - assume api base includes resource name
-            if not api_base.endswith("/"):
-                api_base += "/"
-            api_base += f"{model}"
+            if client is None:
+                if not api_base.endswith("/"):
+                    api_base += "/"
+                api_base += f"{model}"
@@ -162,7 +163,7 @@ class AzureChatCompletion(BaseLLM):
                     "azure_ad_token": azure_ad_token
                 },
                 "api_version": api_version,
-                "api_base": api_base,
+                "api_base": client.base_url,
                 "complete_input_dict": data,
             },
         )
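
For context, here is a minimal standalone sketch of what this change does on the azure.py side, assuming the openai>=1.0 SDK: when api_base points at a Cloudflare AI Gateway, the Azure deployment name is appended to the gateway path and the client is constructed with base_url (used as-is) rather than azure_endpoint. The gateway URL shape and deployment name are illustrative; the env var names mirror the updated test further down.

import os
from openai import AzureOpenAI

# Illustrative gateway base, e.g.
#   https://gateway.ai.cloudflare.com/v1/<account_id>/<gateway>/azure-openai/<resource>
api_base = os.environ["CLOUDFLARE_AZURE_BASE_URL"]   # env var taken from the test below
model = "gpt-turbo"                                  # Azure deployment name (example)

# Same logic as azure.py: append the deployment name to the gateway base.
# (In azure.py this only runs when no pre-built client was passed in.)
if "gateway.ai.cloudflare.com" in api_base:
    if not api_base.endswith("/"):
        api_base += "/"
    api_base += model

# base_url is used verbatim by the SDK, so requests go to
# <gateway>/.../azure-openai/<resource>/<deployment>/chat/completions
client = AzureOpenAI(
    api_key=os.environ["AZURE_FRANCE_API_KEY"],      # env var taken from the test below
    base_url=api_base,
    api_version="2023-07-01-preview",
)

response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What do you know?"}],
)
print(response.choices[0].message.content)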

@@ -856,6 +856,22 @@ class Router:
             if "azure" in model_name:
                 if api_version is None:
                     api_version = "2023-07-01-preview"
-                model["async_client"] = openai.AsyncAzureOpenAI(
-                    api_key=api_key,
-                    azure_endpoint=api_base,
+                if "gateway.ai.cloudflare.com" in api_base:
+                    if not api_base.endswith("/"):
+                        api_base += "/"
+                    azure_model = model_name.replace("azure/", "")
+                    api_base += f"{azure_model}"
+                    model["async_client"] = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version
+                    )
+                    model["client"] = openai.AzureOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        api_version=api_version
+                    )
+                else:
+                    model["async_client"] = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        azure_endpoint=api_base,
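
A sketch of how this router branch might be exercised from user code, assuming litellm's standard model_list format. The model alias and deployment name are placeholders; the env var names are the ones used in the updated test.

import os
from litellm import Router

# The api_base below is a Cloudflare AI Gateway URL; with this change the router
# appends the deployment name ("gpt-turbo") to it before building the Azure clients.
model_list = [
    {
        "model_name": "gpt-3.5-turbo",                        # alias used by callers (placeholder)
        "litellm_params": {
            "model": "azure/gpt-turbo",                       # azure/<deployment> (placeholder)
            "api_key": os.getenv("AZURE_FRANCE_API_KEY"),
            "api_version": "2023-07-01-preview",
            "api_base": os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
        },
    }
]

router = Router(model_list=model_list)
response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What do you know?"}],
)
print(response)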

@@ -65,7 +65,7 @@ def test_async_response_azure():
         user_message = "What do you know?"
         messages = [{"content": user_message, "role": "user"}]
         try:
-            response = await acompletion(model="azure/chatgpt-v-2", messages=messages, timeout=5)
+            response = await acompletion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY"))
             print(f"response: {response}")
         except litellm.Timeout as e:
             pass
@@ -76,6 +76,7 @@ def test_async_response_azure():
 # test_async_response_azure()
 def test_async_anyscale_response():
     import asyncio
     litellm.set_verbose = True
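
For reference, the updated test call as a standalone async snippet (same model, env vars, and keyword arguments as above):

import asyncio
import os
from litellm import acompletion

async def main():
    # Route an Azure deployment through a Cloudflare AI Gateway by passing the
    # gateway URL as base_url, exactly as the updated test does.
    response = await acompletion(
        model="azure/gpt-turbo",
        messages=[{"role": "user", "content": "What do you know?"}],
        base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
        api_key=os.getenv("AZURE_FRANCE_API_KEY"),
    )
    print(f"response: {response}")

if __name__ == "__main__":
    asyncio.run(main())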