Add pyright to ci/cd + Fix remaining type-checking errors (#6082)
* fix: fix type-checking errors
* fix: fix additional type-checking errors
* fix: additional type-checking error fixes
* fix: fix additional type-checking errors
* fix: additional type-check fixes
* fix: fix all type-checking errors + add pyright to ci/cd
* fix: fix incorrect import
* ci(config.yml): use mypy on ci/cd
* fix: fix type-checking errors in utils.py
* fix: fix all type-checking errors on main.py
* fix: fix mypy linting errors
* fix(anthropic/cost_calculator.py): fix linting errors
* fix: fix mypy linting errors
* fix: fix linting errors
This commit is contained in:
parent f7ce1173f3
commit fac3b2ee42

65 changed files with 619 additions and 522 deletions
@@ -8,7 +8,7 @@ import requests
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching import DualCache

-from typing import Literal, Union
+from typing import Literal, Union, Optional

 import traceback
@@ -26,9 +26,9 @@ from litellm._logging import print_verbose, verbose_logger

 class GenericAPILogger:
     # Class variables or attributes
-    def __init__(self, endpoint=None, headers=None):
+    def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
         try:
-            if endpoint == None:
+            if endpoint is None:
                 # check env for "GENERIC_LOGGER_ENDPOINT"
                 if os.getenv("GENERIC_LOGGER_ENDPOINT"):
                     # Do something with the endpoint
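The two changes in this hunk are the commit's most common fixes: a parameter that defaults to None gets an explicit Optional[...] annotation, and None checks use identity (`is`) rather than equality. A minimal standalone sketch of the pattern, with an illustrative function name and fallback value:

from typing import Optional


def connect(endpoint: Optional[str] = None) -> str:
    # pyright rejects "endpoint: str = None"; Optional makes the None case
    # explicit, and the guard below narrows it away
    if endpoint is None:  # "== None" also trips linters (pycodestyle E711)
        endpoint = "http://localhost:4000"  # hypothetical fallback
    return endpoint  # narrowed to str on every path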
@@ -36,9 +36,15 @@ class GenericAPILogger:
                 else:
                     # Handle the case when the endpoint is not found in the environment variables
                     raise ValueError(
-                        f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
+                        "endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
                     )
             headers = headers or litellm.generic_logger_headers

+            if endpoint is None:
+                raise ValueError("endpoint not set for GenericAPILogger")
+            if headers is None:
+                raise ValueError("headers not set for GenericAPILogger")
+
             self.endpoint = endpoint
             self.headers = headers

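Raising on the None case just before assignment is what lets the checker infer non-Optional types for self.endpoint and self.headers. A reduced sketch of the same narrowing, with an illustrative class name:

from typing import Optional


class APILogger:
    def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
        if endpoint is None:
            raise ValueError("endpoint not set")
        if headers is None:
            raise ValueError("headers not set")
        # past the guards, both locals are narrowed, so the attributes
        # are inferred as str and dict rather than Optional[...]
        self.endpoint = endpoint
        self.headers = headers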
@@ -48,8 +48,6 @@ class AporiaGuardrail(CustomGuardrail):
         )
         self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
         self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
-        self.event_hook: GuardrailEventHooks
-
         super().__init__(**kwargs)

     #### CALL HOOKS - proxy only ####
@@ -84,7 +84,7 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
                 )

             cache_key = f"litellm:end_user_id:{user}"
-            end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
+            end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache(  # type: ignore
                 key=cache_key
             )
             if end_user_cache_obj is None and self.prisma_client is not None:
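cache.get_cache is untyped and a lookup can miss, so the annotation is widened to Optional and the assignment silenced with `# type: ignore`. An alternative that avoids the ignore is typing.cast; a hedged sketch with placeholder types (Record stands in for LiteLLM_EndUserTable):

from typing import Any, Optional, cast


class Record:  # placeholder for LiteLLM_EndUserTable
    blocked: bool = False


def get_cache(key: str) -> Any:  # stand-in for the untyped cache API
    return None


end_user_obj = cast(Optional[Record], get_cache(key="litellm:end_user_id:42"))
if end_user_obj is not None:
    print(end_user_obj.blocked)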
@@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     # Class variables or attributes
     def __init__(self):
         try:
-            from google.cloud import language_v1
+            from google.cloud import language_v1  # type: ignore
         except Exception:
             raise Exception(
                 "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
@@ -57,8 +57,8 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
         # Instantiates a client
         self.client = language_v1.LanguageServiceClient()
         self.moderate_text_request = language_v1.ModerateTextRequest
-        self.language_document = language_v1.types.Document
-        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
+        self.language_document = language_v1.types.Document  # type: ignore
+        self.document_type = language_v1.types.Document.Type.PLAIN_TEXT  # type: ignore

         default_confidence_threshold = (
             litellm.google_moderation_confidence_threshold or 0.8
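The targeted ignores in these two hunks work around google-cloud-language shipping without complete type stubs, so pyright cannot resolve the nested `types` attributes even though they exist at runtime. A sketch of the guarded-import shape, assuming the same error message; chaining the ImportError is a small deviation from the diff:

try:
    from google.cloud import language_v1  # type: ignore
except ImportError as e:
    raise ImportError(
        "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
    ) from e

# client = language_v1.LanguageServiceClient()  # needs GCP credentials at runtime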
@@ -8,6 +8,7 @@
 # Thank you users! We ❤️ you! - Krrish & Ishaan

 import sys, os
+from collections.abc import Iterable

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -19,11 +20,12 @@ from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 from litellm._logging import verbose_proxy_logger
-from litellm.utils import (
+from litellm.types.utils import (
     ModelResponse,
     EmbeddingResponse,
     ImageResponse,
     StreamingChoices,
+    Choices,
 )
 from datetime import datetime
 import aiohttp, asyncio
@@ -34,7 +36,10 @@ litellm.set_verbose = True
 class _ENTERPRISE_LlamaGuard(CustomLogger):
     # Class variables or attributes
     def __init__(self, model_name: Optional[str] = None):
-        self.model = model_name or litellm.llamaguard_model_name
+        _model = model_name or litellm.llamaguard_model_name
+        if _model is None:
+            raise ValueError("model_name not set for LlamaGuard")
+        self.model = _model
         file_path = litellm.llamaguard_unsafe_content_categories
         data = None

@@ -124,7 +129,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
             hf_model_name="meta-llama/LlamaGuard-7b",
         )

-        if "unsafe" in response.choices[0].message.content:
+        if (
+            isinstance(response, ModelResponse)
+            and isinstance(response.choices[0], Choices)
+            and response.choices[0].message.content is not None
+            and isinstance(response.choices[0].message.content, Iterable)
+            and "unsafe" in response.choices[0].message.content
+        ):
             raise HTTPException(
                 status_code=400, detail={"error": "Violated content safety policy"}
             )
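The single membership test had to grow into a guard chain because completion-style calls are typed as a union (a ModelResponse or a streaming wrapper), choices may be StreamingChoices, and message.content is Optional. A reduced sketch of how isinstance narrowing satisfies the checker, using simplified stand-in classes rather than litellm's real types:

from typing import Optional, Union


class Choices:
    def __init__(self, content: Optional[str]) -> None:
        self.content = content


class ModelResponse:
    def __init__(self, choices: list) -> None:
        self.choices = choices


class StreamWrapper:  # stand-in for a streaming response type
    ...


def is_unsafe(response: Union[ModelResponse, StreamWrapper]) -> bool:
    # each isinstance/None check narrows the next attribute access
    return (
        isinstance(response, ModelResponse)
        and isinstance(response.choices[0], Choices)
        and response.choices[0].content is not None
        and "unsafe" in response.choices[0].content
    )


print(is_unsafe(ModelResponse([Choices("unsafe S1")])))  # True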
@@ -8,7 +8,11 @@
 ## This provides an LLM Guard Integration for content moderation on the proxy

 from typing import Optional, Literal, Union
-import litellm, traceback, sys, uuid, os
+import litellm
+import traceback
+import sys
+import uuid
+import os
 from litellm.caching import DualCache
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
@@ -21,8 +25,10 @@ from litellm.utils import (
     StreamingChoices,
 )
 from datetime import datetime
-import aiohttp, asyncio
+import aiohttp
+import asyncio
 from litellm.utils import get_formatted_prompt
+from litellm.secret_managers.main import get_secret_str

 litellm.set_verbose = True

@@ -38,7 +44,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
         self.llm_guard_mode = litellm.llm_guard_mode
         if mock_testing == True:  # for testing purposes only
             return
-        self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
+        self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
         if self.llm_guard_api_base is None:
             raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
         elif not self.llm_guard_api_base.endswith("/"):
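litellm.get_secret can return non-string values, which pyright cannot reconcile with the later str operations such as .endswith; get_secret_str (imported in the previous hunk of this file) pins the result to Optional[str]. A hedged sketch of the difference — signatures are approximated, not copied from litellm:

import os
from typing import Any, Optional


def get_secret(name: str, default: Any = None) -> Any:  # approximated
    return os.getenv(name, default)


def get_secret_str(name: str, default: Optional[str] = None) -> Optional[str]:
    value = os.getenv(name, default)
    return value if isinstance(value, str) else None


api_base = get_secret_str("LLM_GUARD_API_BASE")
if api_base is not None and not api_base.endswith("/"):
    api_base += "/"  # fine: api_base is narrowed to str here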
@@ -51,8 +51,8 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
             "audio_transcription",
         ],
     ):
+        text = ""
         if "messages" in data and isinstance(data["messages"], list):
-            text = ""
             for m in data["messages"]:  # assume messages is a list
                 if "content" in m and isinstance(m["content"], str):
                     text += m["content"]
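Hoisting text = "" out of the branch fixes a possibly-unbound diagnostic: when the messages check fails, later uses of text would reference a name that was never assigned. A minimal reproduction with an illustrative function name:

def collect_text(data: dict) -> str:
    text = ""  # assigning here (not inside the if) binds text on every path
    if "messages" in data and isinstance(data["messages"], list):
        for m in data["messages"]:
            if "content" in m and isinstance(m["content"], str):
                text += m["content"]
    return text  # with the old placement, pyright flags this as possibly unbound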
@@ -67,7 +67,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
         )

         verbose_proxy_logger.debug("Moderation response: %s", moderation_response)
-        if moderation_response.results[0].flagged == True:
+        if moderation_response.results[0].flagged is True:
             raise HTTPException(
                 status_code=403, detail={"error": "Violated content safety policy"}
             )
@@ -6,7 +6,9 @@ import collections
 from datetime import datetime


-async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
+async def get_spend_by_tags(
+    prisma_client: PrismaClient, start_date=None, end_date=None
+):
     response = await prisma_client.db.query_raw(
         """
         SELECT
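Making prisma_client a required, typed first parameter removes both the None default (which forced a guard before touching prisma_client.db) and the untyped attribute access. A short sketch of the effect, where PrismaClient below is a stand-in for litellm's wrapper, not the real class:

import asyncio


class PrismaClient:  # stand-in for litellm's Prisma wrapper
    class _DB:
        async def query_raw(self, sql: str) -> list:
            return []

    db = _DB()


async def get_spend_by_tags(prisma_client: PrismaClient, start_date=None, end_date=None):
    # no None guard needed: the parameter is required and typed
    return await prisma_client.db.query_raw("SELECT 1")


asyncio.run(get_spend_by_tags(PrismaClient()))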