Add pyright to ci/cd + Fix remaining type-checking errors (#6082)

* fix: fix type-checking errors

* fix: fix additional type-checking errors

* fix: additional type-checking error fixes

* fix: fix additional type-checking errors

* fix: additional type-check fixes

* fix: fix all type-checking errors + add pyright to ci/cd

* fix: fix incorrect import

* ci(config.yml): use mypy on ci/cd

* fix: fix type-checking errors in utils.py

* fix: fix all type-checking errors on main.py

* fix: fix mypy linting errors

* fix(anthropic/cost_calculator.py): fix linting errors

* fix: fix mypy linting errors

* fix: fix linting errors
This commit is contained in:
Krish Dholakia 2024-10-05 17:04:00 -04:00 committed by GitHub
parent f7ce1173f3
commit fac3b2ee42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
65 changed files with 619 additions and 522 deletions

View file

@ -8,7 +8,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
from typing import Literal, Union, Optional
import traceback
@ -26,9 +26,9 @@ from litellm._logging import print_verbose, verbose_logger
class GenericAPILogger:
# Class variables or attributes
def __init__(self, endpoint=None, headers=None):
def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
try:
if endpoint == None:
if endpoint is None:
# check env for "GENERIC_LOGGER_ENDPOINT"
if os.getenv("GENERIC_LOGGER_ENDPOINT"):
# Do something with the endpoint
@ -36,9 +36,15 @@ class GenericAPILogger:
else:
# Handle the case when the endpoint is not found in the environment variables
raise ValueError(
f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
)
headers = headers or litellm.generic_logger_headers
if endpoint is None:
raise ValueError("endpoint not set for GenericAPILogger")
if headers is None:
raise ValueError("headers not set for GenericAPILogger")
self.endpoint = endpoint
self.headers = headers

View file

@ -48,8 +48,6 @@ class AporiaGuardrail(CustomGuardrail):
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs)
#### CALL HOOKS - proxy only ####

View file

@ -84,7 +84,7 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
)
cache_key = f"litellm:end_user_id:{user}"
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache( # type: ignore
key=cache_key
)
if end_user_cache_obj is None and self.prisma_client is not None:

View file

@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Class variables or attributes
def __init__(self):
try:
from google.cloud import language_v1
from google.cloud import language_v1 # type: ignore
except Exception:
raise Exception(
"Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
@ -57,8 +57,8 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Instantiates a client
self.client = language_v1.LanguageServiceClient()
self.moderate_text_request = language_v1.ModerateTextRequest
self.language_document = language_v1.types.Document
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
self.language_document = language_v1.types.Document # type: ignore
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore
default_confidence_threshold = (
litellm.google_moderation_confidence_threshold or 0.8

View file

@ -8,6 +8,7 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
from collections.abc import Iterable
sys.path.insert(
0, os.path.abspath("../..")
@ -19,11 +20,12 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
from litellm.types.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
Choices,
)
from datetime import datetime
import aiohttp, asyncio
@ -34,7 +36,10 @@ litellm.set_verbose = True
class _ENTERPRISE_LlamaGuard(CustomLogger):
# Class variables or attributes
def __init__(self, model_name: Optional[str] = None):
self.model = model_name or litellm.llamaguard_model_name
_model = model_name or litellm.llamaguard_model_name
if _model is None:
raise ValueError("model_name not set for LlamaGuard")
self.model = _model
file_path = litellm.llamaguard_unsafe_content_categories
data = None
@ -124,7 +129,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
hf_model_name="meta-llama/LlamaGuard-7b",
)
if "unsafe" in response.choices[0].message.content:
if (
isinstance(response, ModelResponse)
and isinstance(response.choices[0], Choices)
and response.choices[0].message.content is not None
and isinstance(response.choices[0].message.content, Iterable)
and "unsafe" in response.choices[0].message.content
):
raise HTTPException(
status_code=400, detail={"error": "Violated content safety policy"}
)

View file

@ -8,7 +8,11 @@
## This provides an LLM Guard Integration for content moderation on the proxy
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid, os
import litellm
import traceback
import sys
import uuid
import os
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
@ -21,8 +25,10 @@ from litellm.utils import (
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
import aiohttp
import asyncio
from litellm.utils import get_formatted_prompt
from litellm.secret_managers.main import get_secret_str
litellm.set_verbose = True
@ -38,7 +44,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
self.llm_guard_mode = litellm.llm_guard_mode
if mock_testing == True: # for testing purposes only
return
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
if self.llm_guard_api_base is None:
raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
elif not self.llm_guard_api_base.endswith("/"):

View file

@ -51,8 +51,8 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
"audio_transcription",
],
):
text = ""
if "messages" in data and isinstance(data["messages"], list):
text = ""
for m in data["messages"]: # assume messages is a list
if "content" in m and isinstance(m["content"], str):
text += m["content"]
@ -67,7 +67,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
)
verbose_proxy_logger.debug("Moderation response: %s", moderation_response)
if moderation_response.results[0].flagged == True:
if moderation_response.results[0].flagged is True:
raise HTTPException(
status_code=403, detail={"error": "Violated content safety policy"}
)

View file

@ -6,7 +6,9 @@ import collections
from datetime import datetime
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None
):
response = await prisma_client.db.query_raw(
"""
SELECT