(feat) router: init client for OpenAI compatible providers
parent b9ae6275ca
commit afd20098be

1 changed file with 54 additions and 40 deletions
@@ -73,6 +73,7 @@ class Router:
                  context_window_fallbacks: List = [],
                  routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
 
+        self.set_verbose = set_verbose
         if model_list:
             self.set_model_list(model_list)
             self.healthy_deployments: List = self.model_list
@@ -83,7 +84,6 @@ class Router:
         self.allowed_fails = allowed_fails or litellm.allowed_fails
         self.failed_calls = InMemoryCache()  # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown
         self.num_retries = num_retries or litellm.num_retries or 0
-        self.set_verbose = set_verbose
         self.timeout = timeout or litellm.request_timeout
         self.routing_strategy = routing_strategy
         self.fallbacks = fallbacks or litellm.fallbacks
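The failed_calls comment above captures the router's cooldown rule: failures are counted per deployment in a short-lived cache, and a deployment that fails more than allowed_fails times within a minute is put on cooldown. A minimal sketch of that bookkeeping, using a stand-in TTLCache rather than litellm's actual InMemoryCache:

import time


class TTLCache:
    """Stand-in for litellm's InMemoryCache: values expire after ttl seconds."""

    def __init__(self):
        self._store = {}  # key -> (value, expiry timestamp)

    def get(self, key, default=0):
        value, expires_at = self._store.get(key, (default, 0))
        return value if time.time() < expires_at else default

    def set(self, key, value, ttl=60):
        self._store[key] = (value, time.time() + ttl)


def record_failure(cache, deployment_id, allowed_fails=3):
    """Count a failed call; return True when the deployment should cool down."""
    fails = cache.get(deployment_id) + 1
    cache.set(deployment_id, fails, ttl=60)  # window resets a minute after the last failure
    return fails > allowed_fails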
@@ -818,47 +818,61 @@ class Router:
         for model in self.model_list:
             litellm_params = model.get("litellm_params", {})
             model_name = litellm_params.get("model")
             #### for OpenAI / Azure we need to initialize the Client for High Traffic ########
+            custom_llm_provider = litellm_params.get("custom_llm_provider")
+            if custom_llm_provider is None:
+                custom_llm_provider = model_name.split("/", 1)[0]
             if (
                 model_name in litellm.open_ai_chat_completion_models
+                or custom_llm_provider == "custom_openai"
+                or custom_llm_provider == "deepinfra"
+                or custom_llm_provider == "perplexity"
+                or custom_llm_provider == "anyscale"
+                or custom_llm_provider == "openai"
+                or custom_llm_provider == "azure"
                 or "ft:gpt-3.5-turbo" in model_name
             ):
                 # glorified / complicated reading of configs
                 # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
                 # we do this here because we init clients for Azure, OpenAI and we need to set the right key
                 api_key = litellm_params.get("api_key")
                 if api_key and api_key.startswith("os.environ/"):
                     api_key_env_name = api_key.replace("os.environ/", "")
                     api_key = os.getenv(api_key_env_name)
 
                 api_base = litellm_params.get("api_base")
                 if api_base and api_base.startswith("os.environ/"):
                     api_base_env_name = api_base.replace("os.environ/", "")
                     api_base = os.getenv(api_base_env_name)
 
                 api_version = litellm_params.get("api_version")
                 if api_version and api_version.startswith("os.environ/"):
                     api_version_env_name = api_version.replace("os.environ/", "")
                     api_version = os.getenv(api_version_env_name)
-                if api_version is None:
-                    api_version = "2023-07-01-preview"
-
-                if "azure" in model_name:
-                    model["async_client"] = openai.AsyncAzureOpenAI(
-                        api_key=api_key,
-                        azure_endpoint=api_base,
-                        api_version=api_version
-                    )
-                    model["client"] = openai.AzureOpenAI(
-                        api_key=api_key,
-                        azure_endpoint=api_base,
-                        api_version=api_version
-                    )
-                elif model_name in litellm.open_ai_chat_completion_models:
-                    model["async_client"] = openai.AsyncOpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                    )
-                    model["client"] = openai.OpenAI(
-                        api_key=api_key,
-                        base_url=api_base,
-                    )
+                self.print_verbose(f"Initializing OpenAI Client for {model_name}, {str(api_base)}")
+                if "azure" in model_name:
+                    if api_version is None:
+                        api_version = "2023-07-01-preview"
+                    model["async_client"] = openai.AsyncAzureOpenAI(
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version
+                    )
+                    model["client"] = openai.AzureOpenAI(
+                        api_key=api_key,
+                        azure_endpoint=api_base,
+                        api_version=api_version
+                    )
+                else:
+                    model["async_client"] = openai.AsyncOpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                    )
+                    model["client"] = openai.OpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                    )
             ############ End of initializing Clients for OpenAI/Azure ###################
             model_id = ""
             for key in model["litellm_params"]:
                 model_id += str(model["litellm_params"][key])
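The "glorified / complicated reading of configs" above is one convention applied three times: a litellm_params value is either a literal, or an os.environ/<VAR_NAME> reference that is resolved from the environment at client-init time. The same logic, factored into a standalone helper (resolve_config_value is an illustrative name, not part of the diff):

import os
from typing import Optional


def resolve_config_value(value: Optional[str]) -> Optional[str]:
    """Resolve "os.environ/AZURE_API_KEY"-style references; pass literals through."""
    if value is not None and value.startswith("os.environ/"):
        env_var_name = value.replace("os.environ/", "")
        return os.getenv(env_var_name)
    return value


# Mirrors what the diff does inline for api_key, api_base and api_version:
api_key = resolve_config_value("os.environ/AZURE_API_KEY")    # read from the env
api_base = resolve_config_value("https://api.perplexity.ai")  # literal, unchanged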
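The practical effect of the commit: a model_list entry for any OpenAI-compatible provider now gets sync and async OpenAI clients built once at Router construction rather than per request. A hypothetical configuration that would exercise the new path; the alias, model name, and env var below are illustrative:

from litellm import Router

model_list = [
    {
        "model_name": "pplx-llama",  # alias callers route by
        "litellm_params": {
            # provider is inferred from the "perplexity/" prefix via model_name.split("/", 1)[0]
            "model": "perplexity/llama-2-70b-chat",
            "api_key": "os.environ/PERPLEXITYAI_API_KEY",  # resolved with os.getenv at init
            "api_base": "https://api.perplexity.ai",
        },
    },
]

router = Router(model_list=model_list, routing_strategy="simple-shuffle")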