Merge pull request #3600 from msabramo/msabramo/fix-pydantic-warnings

Update pydantic code to fix warnings
This commit is contained in:
Krish Dholakia 2024-05-13 22:00:39 -07:00 committed by GitHub
commit 2c867ea9a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 129 additions and 82 deletions

10
.git-blame-ignore-revs Normal file
View file

@ -0,0 +1,10 @@
# Add the commit hash of any commit you want to ignore in `git blame` here.
# One commit hash per line.
#
# The GitHub Blame UI will use this file automatically!
#
# Run this command to always ignore formatting commits in `git blame`
# git config blame.ignoreRevsFile .git-blame-ignore-revs
# Update pydantic code to fix warnings (GH-3600)
876840e9957bc7e9f7d6a2b58c4d7c53dad16481

View file

@ -1,11 +1,20 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
from dataclasses import fields
from pydantic import ConfigDict, BaseModel, Field, root_validator, Json
import enum
from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime
import uuid, json, sys, os
import uuid
import json
from litellm.types.router import UpdateRouterConfig
try:
from pydantic import model_validator # pydantic v2
except ImportError:
from pydantic import root_validator # pydantic v1
def model_validator(mode):
pre = mode == "before"
return root_validator(pre=pre)
def hash_token(token: str):
import hashlib
@ -35,8 +44,9 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1
return self.__fields_set__
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
@ -229,7 +239,7 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
llm_api_system_prompt: Optional[str] = None
llm_api_fail_call_string: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
def check_llm_api_params(cls, values):
llm_api_check = values.get("llm_api_check")
if llm_api_check is True:
@ -287,8 +297,9 @@ class ProxyChatCompletionRequest(LiteLLMBase):
deployment_id: Optional[str] = None
request_timeout: Optional[int] = None
class Config:
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
model_config = ConfigDict(
extra = "allow", # allow params not defined here, these fall in litellm.completion(**kwargs)
)
class ModelInfoDelete(LiteLLMBase):
@ -315,11 +326,12 @@ class ModelInfo(LiteLLMBase):
]
]
class Config:
extra = Extra.allow # Allow extra fields
protected_namespaces = ()
model_config = ConfigDict(
extra = "allow", # Allow extra fields
protected_namespaces = (),
)
@root_validator(pre=True)
@model_validator(mode="before")
def set_model_info(cls, values):
if values.get("id") is None:
values.update({"id": str(uuid.uuid4())})
@ -345,10 +357,11 @@ class ModelParams(LiteLLMBase):
litellm_params: dict
model_info: ModelInfo
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
@root_validator(pre=True)
@model_validator(mode="before")
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
@ -384,8 +397,9 @@ class GenerateKeyRequest(GenerateRequestBase):
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class GenerateKeyResponse(GenerateKeyRequest):
@ -395,7 +409,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
user_id: Optional[str] = None
token_id: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
def set_model_info(cls, values):
if values.get("token") is not None:
values.update({"key": values.get("token")})
@ -435,8 +449,9 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str
updated_by: str
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class NewUserRequest(GenerateKeyRequest):
@ -464,7 +479,7 @@ class UpdateUserRequest(GenerateRequestBase):
user_role: Optional[str] = None
max_budget: Optional[float] = None
@root_validator(pre=True)
@model_validator(mode="before")
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -484,7 +499,7 @@ class NewEndUserRequest(LiteLLMBase):
None # if no equivalent model in allowed region - default all requests to this model
)
@root_validator(pre=True)
@model_validator(mode="before")
def check_user_info(cls, values):
if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
@ -497,7 +512,7 @@ class Member(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -522,8 +537,9 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class GlobalEndUsersSpend(LiteLLMBase):
@ -542,7 +558,7 @@ class TeamMemberDeleteRequest(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -576,10 +592,11 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
@root_validator(pre=True)
@model_validator(mode="before")
def set_model_info(cls, values):
dict_fields = [
"metadata",
@ -615,8 +632,9 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class NewOrganizationRequest(LiteLLM_BudgetTable):
@ -666,8 +684,9 @@ class KeyManagementSettings(LiteLLMBase):
class TeamDefaultSettings(LiteLLMBase):
team_id: str
class Config:
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
model_config = ConfigDict(
extra = "allow", # allow params not defined here, these fall in litellm.completion(**kwargs)
)
class DynamoDBArgs(LiteLLMBase):
@ -808,8 +827,9 @@ class ConfigYAML(LiteLLMBase):
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
)
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class LiteLLM_VerificationToken(LiteLLMBase):
@ -843,8 +863,9 @@ class LiteLLM_VerificationToken(LiteLLMBase):
user_id_rate_limits: Optional[dict] = None
team_id_rate_limits: Optional[dict] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
@ -874,7 +895,7 @@ class UserAPIKeyAuth(
user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
allowed_model_region: Optional[Literal["eu"]] = None
@root_validator(pre=True)
@model_validator(mode="before")
def check_api_key(cls, values):
if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))})
@ -901,7 +922,7 @@ class LiteLLM_UserTable(LiteLLMBase):
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
@root_validator(pre=True)
@model_validator(mode="before")
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
@ -909,8 +930,9 @@ class LiteLLM_UserTable(LiteLLMBase):
values.update({"models": []})
return values
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class LiteLLM_EndUserTable(LiteLLMBase):
@ -922,14 +944,15 @@ class LiteLLM_EndUserTable(LiteLLMBase):
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@root_validator(pre=True)
@model_validator(mode="before")
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class LiteLLM_SpendLogs(LiteLLMBase):

View file

@ -5,6 +5,7 @@
import sys, os
import traceback
from dotenv import load_dotenv
from pydantic import ConfigDict
load_dotenv()
import os, io
@ -25,9 +26,7 @@ class DBModel(BaseModel):
model_name: str
model_info: dict
litellm_params: dict
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
@pytest.mark.asyncio

View file

@ -1,6 +1,6 @@
from typing import List, Optional, Union, Iterable
from pydantic import BaseModel, validator
from pydantic import ConfigDict, BaseModel, validator
from typing_extensions import Literal, Required, TypedDict
@ -190,7 +190,4 @@ class CompletionRequest(BaseModel):
api_version: Optional[str] = None
api_key: Optional[str] = None
model_list: Optional[List[str]] = None
class Config:
extra = "allow"
protected_namespaces = ()
model_config = ConfigDict(extra="allow", protected_namespaces=())

View file

@ -1,6 +1,6 @@
from typing import List, Optional, Union
from pydantic import BaseModel, validator
from pydantic import ConfigDict, BaseModel, validator
class EmbeddingRequest(BaseModel):
@ -17,7 +17,4 @@ class EmbeddingRequest(BaseModel):
litellm_call_id: Optional[str] = None
litellm_logging_obj: Optional[dict] = None
logger_fn: Optional[str] = None
class Config:
# allow kwargs
extra = "allow"
model_config = ConfigDict(extra="allow")

View file

@ -1,6 +1,6 @@
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import httpx
from pydantic import BaseModel, validator, Field
from pydantic import ConfigDict, BaseModel, validator, Field, __version__ as pydantic_version
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
import uuid, enum
@ -12,8 +12,9 @@ class ModelConfig(BaseModel):
tpm: int
rpm: int
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class RouterConfig(BaseModel):
@ -44,8 +45,9 @@ class RouterConfig(BaseModel):
"latency-based-routing",
] = "simple-shuffle"
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class UpdateRouterConfig(BaseModel):
@ -65,8 +67,9 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class ModelInfo(BaseModel):
@ -84,8 +87,9 @@ class ModelInfo(BaseModel):
id = str(id)
super().__init__(id=id, **params)
class Config:
extra = "allow"
model_config = ConfigDict(
extra = "allow",
)
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -180,8 +184,17 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = ConfigDict(
extra = "allow",
arbitrary_types_allowed = True,
)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
@ -241,8 +254,17 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = ConfigDict(
extra = "allow",
arbitrary_types_allowed = True,
)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
@ -273,8 +295,9 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(
protected_namespaces = (),
)
class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -348,9 +371,10 @@ class Deployment(BaseModel):
# if using pydantic v1
return self.dict(**kwargs)
class Config:
extra = "allow"
protected_namespaces = ()
model_config = ConfigDict(
extra = "allow",
protected_namespaces = (),
)
def __contains__(self, key):
# Define custom behavior for the 'in' operator

View file

@ -20,7 +20,7 @@ from functools import wraps
import datetime, time
import tiktoken
import uuid
from pydantic import BaseModel
from pydantic import ConfigDict, BaseModel
import aiohttp
import textwrap
import logging
@ -328,10 +328,7 @@ class HiddenParams(OpenAIObject):
original_response: Optional[str] = None
model_id: Optional[str] = None # used in Router for individual deployments
api_base: Optional[str] = None # returns api base used for making completion call
class Config:
extra = "allow"
protected_namespaces = ()
model_config = ConfigDict(extra="allow", protected_namespaces=())
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist