(feat) router: init client for OpenAI compatible providers

This commit is contained in:
ishaan-jaff 2023-11-28 17:49:33 -08:00
parent b9ae6275ca
commit afd20098be

View file

@ -73,6 +73,7 @@ class Router:
context_window_fallbacks: List = [],
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
self.set_verbose = set_verbose
if model_list:
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
@ -83,7 +84,6 @@ class Router:
self.allowed_fails = allowed_fails or litellm.allowed_fails
self.failed_calls = InMemoryCache() # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
self.num_retries = num_retries or litellm.num_retries or 0
self.set_verbose = set_verbose
self.timeout = timeout or litellm.request_timeout
self.routing_strategy = routing_strategy
self.fallbacks = fallbacks or litellm.fallbacks
@ -818,47 +818,61 @@ class Router:
for model in self.model_list:
litellm_params = model.get("litellm_params", {})
model_name = litellm_params.get("model")
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
if custom_llm_provider is None:
custom_llm_provider = model_name.split("/",1)[0]
if (
model_name in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or "ft:gpt-3.5-turbo" in model_name
):
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key")
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = os.getenv(api_key_env_name)
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key")
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = os.getenv(api_key_env_name)
api_base = litellm_params.get("api_base")
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = os.getenv(api_base_env_name)
api_base = litellm_params.get("api_base")
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = os.getenv(api_base_env_name)
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = os.getenv(api_version_env_name)
if api_version is None:
api_version = "2023-07-01-preview"
if "azure" in model_name:
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version
)
elif model_name in litellm.open_ai_chat_completion_models:
model["async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
)
model["client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
)
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = os.getenv(api_version_env_name)
self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
if "azure" in model_name:
if api_version is None:
api_version = "2023-07-01-preview"
model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version
)
model["client"] = openai.AzureOpenAI(
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version
)
else:
model["async_client"] = openai.AsyncOpenAI(
api_key=api_key,
base_url=api_base,
)
model["client"] = openai.OpenAI(
api_key=api_key,
base_url=api_base,
)
############ End of initializing Clients for OpenAI/Azure ###################
model_id = ""
for key in model["litellm_params"]:
model_id+= str(model["litellm_params"][key])