Add pyright to ci/cd + Fix remaining type-checking errors (#6082)

* fix: fix type-checking errors

* fix: fix additional type-checking errors

* fix: additional type-checking error fixes

* fix: fix additional type-checking errors

* fix: additional type-check fixes

* fix: fix all type-checking errors + add pyright to ci/cd

* fix: fix incorrect import

* ci(config.yml): use mypy on ci/cd

* fix: fix type-checking errors in utils.py

* fix: fix all type-checking errors on main.py

* fix: fix mypy linting errors

* fix(anthropic/cost_calculator.py): fix linting errors

* fix: fix mypy linting errors

* fix: fix linting errors
This commit is contained in:
Krish Dholakia 2024-10-05 17:04:00 -04:00 committed by GitHub
parent f7ce1173f3
commit fac3b2ee42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 619 additions and 522 deletions

View file

@ -7,6 +7,7 @@ import httpx
import openai
import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_router_logger
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
from litellm.secret_managers.get_azure_ad_token_provider import (
@ -111,17 +112,17 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
api_key = litellm_params.get("api_key") or default_api_key
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
api_key = get_secret_str(api_key_env_name)
litellm_params["api_key"] = api_key
api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
base_url: Optional[str] = litellm_params.get("base_url")
api_base = (
api_base or base_url or default_api_base
) # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
api_base = get_secret_str(api_base_env_name)
litellm_params["api_base"] = api_base
## AZURE AI STUDIO MISTRAL CHECK ##
@ -147,33 +148,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
api_version = litellm_params.get("api_version")
if api_version and api_version.startswith("os.environ/"):
api_version_env_name = api_version.replace("os.environ/", "")
api_version = litellm.get_secret(api_version_env_name)
api_version = get_secret_str(api_version_env_name)
litellm_params["api_version"] = api_version
timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
timeout: Optional[float] = (
litellm_params.pop("timeout", None) or litellm.request_timeout
)
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
timeout_env_name = timeout.replace("os.environ/", "")
timeout = litellm.get_secret(timeout_env_name)
timeout = get_secret(timeout_env_name) # type: ignore
litellm_params["timeout"] = timeout
stream_timeout = litellm_params.pop(
stream_timeout: Optional[float] = litellm_params.pop(
"stream_timeout", timeout
) # if no stream_timeout is set, default to timeout
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
stream_timeout = litellm.get_secret(stream_timeout_env_name)
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
litellm_params["stream_timeout"] = stream_timeout
max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic
max_retries: Optional[int] = litellm_params.pop(
"max_retries", 0
) # router handles retry logic
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
max_retries_env_name = max_retries.replace("os.environ/", "")
max_retries = litellm.get_secret(max_retries_env_name)
max_retries = get_secret(max_retries_env_name) # type: ignore
litellm_params["max_retries"] = max_retries
organization = litellm_params.get("organization", None)
if isinstance(organization, str) and organization.startswith("os.environ/"):
organization_env_name = organization.replace("os.environ/", "")
organization = litellm.get_secret(organization_env_name)
organization = get_secret_str(organization_env_name)
litellm_params["organization"] = organization
azure_ad_token_provider: Optional[Callable[[], str]] = None
if litellm_params.get("tenant_id"):
@ -227,8 +232,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -253,8 +258,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -276,8 +281,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -302,8 +307,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
azure_ad_token_provider=azure_ad_token_provider,
base_url=api_base,
api_version=api_version,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -350,8 +355,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -371,8 +376,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -391,8 +396,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.AsyncClient(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -413,8 +418,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
http_client=httpx.Client(
limits=httpx.Limits(
max_connections=1000, max_keepalive_connections=100
@ -441,8 +446,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
@ -465,8 +470,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=timeout,
max_retries=max_retries,
timeout=timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.Client(
limits=httpx.Limits(
@ -487,8 +492,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.AsyncOpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.AsyncClient(
limits=httpx.Limits(
@ -512,8 +517,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
_client = openai.OpenAI( # type: ignore
api_key=api_key,
base_url=api_base,
timeout=stream_timeout,
max_retries=max_retries,
timeout=stream_timeout, # type: ignore
max_retries=max_retries, # type: ignore
organization=organization,
http_client=httpx.Client(
limits=httpx.Limits(
@ -542,20 +547,29 @@ def get_azure_ad_token_from_entrata_id(
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
if tenant_id.startswith("os.environ/"):
tenant_id = litellm.get_secret(tenant_id)
_tenant_id = get_secret_str(tenant_id)
else:
_tenant_id = tenant_id
if client_id.startswith("os.environ/"):
client_id = litellm.get_secret(client_id)
_client_id = get_secret_str(client_id)
else:
_client_id = client_id
if client_secret.startswith("os.environ/"):
client_secret = litellm.get_secret(client_secret)
_client_secret = get_secret_str(client_secret)
else:
_client_secret = client_secret
verbose_router_logger.debug(
"tenant_id %s, client_id %s, client_secret %s",
tenant_id,
client_id,
client_secret,
_tenant_id,
_client_id,
_client_secret,
)
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
if _tenant_id is None or _client_id is None or _client_secret is None:
raise ValueError("tenant_id, client_id, and client_secret must be provided")
credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
verbose_router_logger.debug("credential %s", credential)