(fix) router - set read os.environ/ values

This commit is contained in:
ishaan-jaff 2023-12-06 08:59:26 -08:00
parent 92b2cbcdc5
commit aab6be654e

View file

@ -891,6 +891,7 @@ class Router:
if api_key and api_key.startswith("os.environ/"): if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "") api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name) api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base") api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url") base_url = litellm_params.get("base_url")
@ -898,31 +899,35 @@ class Router:
if api_base and api_base.startswith("os.environ/"): if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "") api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name) api_base = litellm.get_secret(api_base_env_name)
litellm_params["api_base"] = api_base
api_version = litellm_params.get("api_version") api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"): if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "") api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name) api_version = litellm.get_secret(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None) timeout = litellm_params.pop("timeout", None)
if isinstance(timeout, str) and timeout.startswith("os.environ/"): if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = api_version.replace("os.environ/", "") timeout_env_name = api_version.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name) timeout = litellm.get_secret(timeout_env_name)
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout stream_timeout = litellm_params.pop("stream_timeout", timeout) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"): if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = api_version.replace("os.environ/", "") stream_timeout_env_name = api_version.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name) stream_timeout = litellm.get_secret(stream_timeout_env_name)
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 2) max_retries = litellm_params.pop("max_retries", 2)
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"): if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = api_version.replace("os.environ/", "") max_retries_env_name = api_version.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name) max_retries = litellm.get_secret(max_retries_env_name)
litellm_params["max_retries"] = max_retries
if "azure" in model_name: if "azure" in model_name:
if api_base is None: if api_base is None:
raise ValueError("api_base is required for Azure OpenAI. Set it on your config") raise ValueError("api_base is required for Azure OpenAI. Set it on your config")
self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}")
if api_version is None: if api_version is None:
api_version = "2023-07-01-preview" api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base: if "gateway.ai.cloudflare.com" in api_base:
@ -961,6 +966,7 @@ class Router:
max_retries=max_retries max_retries=max_retries
) )
else: else:
self.print_verbose(f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}")
model["async_client"] = openai.AsyncAzureOpenAI( model["async_client"] = openai.AsyncAzureOpenAI(
api_key=api_key, api_key=api_key,
azure_endpoint=api_base, azure_endpoint=api_base,