feat(main.py): add support for azure-openai via cloudflare ai gateway
parent: be8bdb580a
commit: 4f07c8565a
3 changed files with 43 additions and 2 deletions
@@ -111,7 +111,7 @@ class AzureChatCompletion(BaseLLM):
         exception_mapping_worked = False
         try:

-            if model is None or messages is None:
+            if (model is None or messages is None) and client is None:
                 raise AzureOpenAIError(status_code=422, message=f"Missing model or messages")

             data = {
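This guard change lets a caller hand in a pre-built client: the Cloudflare branch added to completion() below constructs an AzureOpenAI client with the model baked into its base_url and then passes model=None, which the old check would have rejected. A minimal sketch of the relaxed validation, using stand-in values:

    # Stand-in illustration of the relaxed check: a pre-built client makes a
    # missing model acceptable, because the model is already part of its base_url.
    model = None
    messages = [{"role": "user", "content": "hi"}]
    client = object()  # placeholder for an openai.AzureOpenAI instance

    if (model is None or messages is None) and client is None:
        raise ValueError("Missing model or messages")  # AzureOpenAIError in the real code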
@@ -494,6 +494,31 @@ def completion(
             if k not in optional_params: # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in
                 optional_params[k] = v

+        ### CHECK IF CLOUDFLARE AI GATEWAY ###
+        ### if so - set the model as part of the base url
+        if "gateway.ai.cloudflare.com" in api_base and client is None:
+            ## build base url - assume api base includes resource name
+            if not api_base.endswith("/"):
+                api_base += "/"
+            api_base += f"{model}"
+
+            azure_client_params = {
+                "api_version": api_version,
+                "base_url": f"{api_base}",
+                "http_client": litellm.client_session,
+                "max_retries": max_retries,
+                "timeout": timeout
+            }
+            if api_key is not None:
+                azure_client_params["api_key"] = api_key
+            elif azure_ad_token is not None:
+                azure_client_params["azure_ad_token"] = azure_ad_token
+            if acompletion is True:
+                client = openai.AsyncAzureOpenAI(**azure_client_params)
+            else:
+                client = openai.AzureOpenAI(**azure_client_params)
+            model = None
+
         ## COMPLETION CALL
         response = azure_chat_completions.completion(
             model=model,
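For reference, an Azure OpenAI endpoint behind Cloudflare AI Gateway has the shape https://gateway.ai.cloudflare.com/v1/{account_id}/{gateway}/azure-openai/{resource_name}, so appending the model segment yields the deployment-specific URL the AzureOpenAI client expects as base_url. A small sketch of the construction above, with hypothetical account, gateway, and resource names:

    # Hypothetical gateway URL; the account, gateway, and resource names are placeholders.
    api_base = "https://gateway.ai.cloudflare.com/v1/my-account/my-gateway/azure-openai/my-resource"
    model = "gpt-turbo"

    if not api_base.endswith("/"):
        api_base += "/"
    api_base += f"{model}"

    # -> .../azure-openai/my-resource/gpt-turbo
    print(api_base)

Note that model is reset to None after the client is built, so the downstream Azure handler does not try to re-apply a deployment name; the relaxed guard in AzureChatCompletion.completion (above) is what makes that legal.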
@@ -933,7 +933,7 @@ def test_replicate_custom_prompt_dict():
         print(f"response: {response}")
         litellm.custom_prompt_dict = {} # reset

-test_replicate_custom_prompt_dict()
+# test_replicate_custom_prompt_dict()

 # commenthing this out since we won't be always testing a custom replicate deployment
 # def test_completion_replicate_deployments():
@@ -1323,6 +1323,22 @@ def test_completion_anyscale_api():

 # test_completion_anyscale_api()

+def test_azure_cloudflare_api():
+    try:
+        messages = [
+            {
+                "role": "user",
+                "content": "How do I output all files in a directory using Python?",
+            },
+        ]
+        response = completion(model="azure/gpt-turbo", messages=messages, base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"), api_key=os.getenv("AZURE_FRANCE_API_KEY"))
+        print(f"response: {response}")
+    except Exception as e:
+        print(e)
+        pass
+
+test_azure_cloudflare_api()
+
 def test_completion_anyscale_2():
     try:
         # litellm.set_verbose=True
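Since the new branch also builds an AsyncAzureOpenAI client when acompletion is True, the same gateway setup should work asynchronously. A hedged sketch, assuming the same CLOUDFLARE_AZURE_BASE_URL and AZURE_FRANCE_API_KEY environment variables as the test above:

    import asyncio
    import os

    from litellm import acompletion

    async def main():
        # Async variant of test_azure_cloudflare_api; reuses the same (assumed) env vars.
        response = await acompletion(
            model="azure/gpt-turbo",
            messages=[{"role": "user", "content": "How do I output all files in a directory using Python?"}],
            base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
            api_key=os.getenv("AZURE_FRANCE_API_KEY"),
        )
        print(f"response: {response}")

    asyncio.run(main())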