fix(router.py): support openai-compatible endpoints

This commit is contained in:
Krrish Dholakia 2023-12-15 14:47:54 -08:00
parent d2e9798de9
commit e5268fa6bc
3 changed files with 37 additions and 15 deletions

View file

@ -889,25 +889,23 @@ class Router:
model["model_info"] = model_info
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider is None:
custom_llm_provider = model_name.split("/",1)[0]
custom_llm_provider = custom_llm_provider or model_name.split("/",1)[0] or ""
default_api_base = None
default_api_key = None
if custom_llm_provider in litellm.openai_compatible_providers:
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(model=model_name)
default_api_base = api_base
default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "mistral"
or custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "mistral"
or custom_llm_provider in litellm.openai_compatible_providers
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key")
api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
@ -915,7 +913,7 @@ class Router:
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure
api_base = api_base or base_url or default_api_base # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)