Merge pull request #3600 from msabramo/msabramo/fix-pydantic-warnings

Update pydantic code to fix warnings
This commit is contained in:
Krish Dholakia 2024-05-13 22:00:39 -07:00 committed by GitHub
commit 2c867ea9a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 129 additions and 82 deletions

10
.git-blame-ignore-revs Normal file
View file

@@ -0,0 +1,10 @@
# Add the commit hash of any commit you want to ignore in `git blame` here.
# One commit hash per line.
#
# The GitHub Blame UI will use this file automatically!
#
# Run this command to always ignore formatting commits in `git blame`
# git config blame.ignoreRevsFile .git-blame-ignore-revs
# Update pydantic code to fix warnings (GH-3600)
876840e9957bc7e9f7d6a2b58c4d7c53dad16481

View file

@@ -1,11 +1,20 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator from pydantic import ConfigDict, BaseModel, Field, root_validator, Json
from dataclasses import fields
import enum import enum
from typing import Optional, List, Union, Dict, Literal, Any from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime from datetime import datetime
import uuid, json, sys, os import uuid
import json
from litellm.types.router import UpdateRouterConfig from litellm.types.router import UpdateRouterConfig
try:
from pydantic import model_validator # pydantic v2
except ImportError:
from pydantic import root_validator # pydantic v1
def model_validator(mode):
pre = mode == "before"
return root_validator(pre=pre)
def hash_token(token: str): def hash_token(token: str):
import hashlib import hashlib
@@ -35,8 +44,9 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1 # if using pydantic v1
return self.__fields_set__ return self.__fields_set__
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase): class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
@@ -229,7 +239,7 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
llm_api_system_prompt: Optional[str] = None llm_api_system_prompt: Optional[str] = None
llm_api_fail_call_string: Optional[str] = None llm_api_fail_call_string: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_llm_api_params(cls, values): def check_llm_api_params(cls, values):
llm_api_check = values.get("llm_api_check") llm_api_check = values.get("llm_api_check")
if llm_api_check is True: if llm_api_check is True:
@@ -287,8 +297,9 @@ class ProxyChatCompletionRequest(LiteLLMBase):
deployment_id: Optional[str] = None deployment_id: Optional[str] = None
request_timeout: Optional[int] = None request_timeout: Optional[int] = None
class Config: model_config = ConfigDict(
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) extra = "allow", # allow params not defined here, these fall in litellm.completion(**kwargs)
)
class ModelInfoDelete(LiteLLMBase): class ModelInfoDelete(LiteLLMBase):
@ -315,11 +326,12 @@ class ModelInfo(LiteLLMBase):
] ]
] ]
class Config: model_config = ConfigDict(
extra = Extra.allow # Allow extra fields extra = "allow", # Allow extra fields
protected_namespaces = () protected_namespaces = (),
)
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("id") is None: if values.get("id") is None:
values.update({"id": str(uuid.uuid4())}) values.update({"id": str(uuid.uuid4())})
@ -345,10 +357,11 @@ class ModelParams(LiteLLMBase):
litellm_params: dict litellm_params: dict
model_info: ModelInfo model_info: ModelInfo
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("model_info") is None: if values.get("model_info") is None:
values.update({"model_info": ModelInfo()}) values.update({"model_info": ModelInfo()})
@ -384,8 +397,9 @@ class GenerateKeyRequest(GenerateRequestBase):
{} {}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class GenerateKeyResponse(GenerateKeyRequest): class GenerateKeyResponse(GenerateKeyRequest):
@ -395,7 +409,7 @@ class GenerateKeyResponse(GenerateKeyRequest):
user_id: Optional[str] = None user_id: Optional[str] = None
token_id: Optional[str] = None token_id: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("token") is not None: if values.get("token") is not None:
values.update({"key": values.get("token")}) values.update({"key": values.get("token")})
@ -435,8 +449,9 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str created_by: str
updated_by: str updated_by: str
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
@ -464,7 +479,7 @@ class UpdateUserRequest(GenerateRequestBase):
user_role: Optional[str] = None user_role: Optional[str] = None
max_budget: Optional[float] = None max_budget: Optional[float] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -484,7 +499,7 @@ class NewEndUserRequest(LiteLLMBase):
None # if no equivalent model in allowed region - default all requests to this model None # if no equivalent model in allowed region - default all requests to this model
) )
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("max_budget") is not None and values.get("budget_id") is not None: if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.") raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
@ -497,7 +512,7 @@ class Member(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -522,8 +537,9 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase): class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None model_aliases: Optional[dict] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class GlobalEndUsersSpend(LiteLLMBase): class GlobalEndUsersSpend(LiteLLMBase):
@ -542,7 +558,7 @@ class TeamMemberDeleteRequest(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -576,10 +592,11 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
dict_fields = [ dict_fields = [
"metadata", "metadata",
@ -615,8 +632,9 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class NewOrganizationRequest(LiteLLM_BudgetTable): class NewOrganizationRequest(LiteLLM_BudgetTable):
@ -666,8 +684,9 @@ class KeyManagementSettings(LiteLLMBase):
class TeamDefaultSettings(LiteLLMBase): class TeamDefaultSettings(LiteLLMBase):
team_id: str team_id: str
class Config: model_config = ConfigDict(
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) extra = "allow", # allow params not defined here, these fall in litellm.completion(**kwargs)
)
class DynamoDBArgs(LiteLLMBase): class DynamoDBArgs(LiteLLMBase):
@ -808,8 +827,9 @@ class ConfigYAML(LiteLLMBase):
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5", description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
) )
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_VerificationToken(LiteLLMBase): class LiteLLM_VerificationToken(LiteLLMBase):
@ -843,8 +863,9 @@ class LiteLLM_VerificationToken(LiteLLMBase):
user_id_rate_limits: Optional[dict] = None user_id_rate_limits: Optional[dict] = None
team_id_rate_limits: Optional[dict] = None team_id_rate_limits: Optional[dict] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
@ -874,7 +895,7 @@ class UserAPIKeyAuth(
user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
allowed_model_region: Optional[Literal["eu"]] = None allowed_model_region: Optional[Literal["eu"]] = None
@root_validator(pre=True) @model_validator(mode="before")
def check_api_key(cls, values): def check_api_key(cls, values):
if values.get("api_key") is not None: if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))}) values.update({"token": hash_token(values.get("api_key"))})
@ -901,7 +922,7 @@ class LiteLLM_UserTable(LiteLLMBase):
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
@ -909,8 +930,9 @@ class LiteLLM_UserTable(LiteLLMBase):
values.update({"models": []}) values.update({"models": []})
return values return values
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_EndUserTable(LiteLLMBase): class LiteLLM_EndUserTable(LiteLLMBase):
@ -922,14 +944,15 @@ class LiteLLM_EndUserTable(LiteLLMBase):
default_model: Optional[str] = None default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@root_validator(pre=True) @model_validator(mode="before")
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
return values return values
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLM_SpendLogs(LiteLLMBase): class LiteLLM_SpendLogs(LiteLLMBase):

View file

@@ -5,6 +5,7 @@
import sys, os import sys, os
import traceback import traceback
from dotenv import load_dotenv from dotenv import load_dotenv
from pydantic import ConfigDict
load_dotenv() load_dotenv()
import os, io import os, io
@@ -25,9 +26,7 @@ class DBModel(BaseModel):
model_name: str model_name: str
model_info: dict model_info: dict
litellm_params: dict litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
class Config:
protected_namespaces = ()
@pytest.mark.asyncio @pytest.mark.asyncio

View file

@@ -1,6 +1,6 @@
from typing import List, Optional, Union, Iterable from typing import List, Optional, Union, Iterable
from pydantic import BaseModel, validator from pydantic import ConfigDict, BaseModel, validator
from typing_extensions import Literal, Required, TypedDict from typing_extensions import Literal, Required, TypedDict
@ -190,7 +190,4 @@ class CompletionRequest(BaseModel):
api_version: Optional[str] = None api_version: Optional[str] = None
api_key: Optional[str] = None api_key: Optional[str] = None
model_list: Optional[List[str]] = None model_list: Optional[List[str]] = None
model_config = ConfigDict(extra="allow", protected_namespaces=())
class Config:
extra = "allow"
protected_namespaces = ()

View file

@@ -1,6 +1,6 @@
from typing import List, Optional, Union from typing import List, Optional, Union
from pydantic import BaseModel, validator from pydantic import ConfigDict, BaseModel, validator
class EmbeddingRequest(BaseModel): class EmbeddingRequest(BaseModel):
@ -17,7 +17,4 @@ class EmbeddingRequest(BaseModel):
litellm_call_id: Optional[str] = None litellm_call_id: Optional[str] = None
litellm_logging_obj: Optional[dict] = None litellm_logging_obj: Optional[dict] = None
logger_fn: Optional[str] = None logger_fn: Optional[str] = None
model_config = ConfigDict(extra="allow")
class Config:
# allow kwargs
extra = "allow"

View file

@@ -1,6 +1,6 @@
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import httpx import httpx
from pydantic import BaseModel, validator, Field from pydantic import ConfigDict, BaseModel, validator, Field, __version__ as pydantic_version
from .completion import CompletionRequest from .completion import CompletionRequest
from .embedding import EmbeddingRequest from .embedding import EmbeddingRequest
import uuid, enum import uuid, enum
@ -12,8 +12,9 @@ class ModelConfig(BaseModel):
tpm: int tpm: int
rpm: int rpm: int
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class RouterConfig(BaseModel): class RouterConfig(BaseModel):
@ -44,8 +45,9 @@ class RouterConfig(BaseModel):
"latency-based-routing", "latency-based-routing",
] = "simple-shuffle" ] = "simple-shuffle"
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class UpdateRouterConfig(BaseModel): class UpdateRouterConfig(BaseModel):
@ -65,8 +67,9 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None context_window_fallbacks: Optional[List[dict]] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
@ -84,8 +87,9 @@ class ModelInfo(BaseModel):
id = str(id) id = str(id)
super().__init__(id=id, **params) super().__init__(id=id, **params)
class Config: model_config = ConfigDict(
extra = "allow" extra = "allow",
)
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator
@@ -180,8 +184,17 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params) super().__init__(max_retries=max_retries, **args, **params)
model_config = ConfigDict(
extra = "allow",
arbitrary_types_allowed = True,
)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config: class Config:
extra = "allow"
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __contains__(self, key): def __contains__(self, key):
@@ -241,8 +254,17 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params) super().__init__(max_retries=max_retries, **args, **params)
model_config = ConfigDict(
extra = "allow",
arbitrary_types_allowed = True,
)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config: class Config:
extra = "allow"
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __contains__(self, key): def __contains__(self, key):
@ -273,8 +295,9 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None model_info: Optional[ModelInfo] = None
class Config: model_config = ConfigDict(
protected_namespaces = () protected_namespaces = (),
)
class LiteLLMParamsTypedDict(TypedDict, total=False): class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -348,9 +371,10 @@ class Deployment(BaseModel):
# if using pydantic v1 # if using pydantic v1
return self.dict(**kwargs) return self.dict(**kwargs)
class Config: model_config = ConfigDict(
extra = "allow" extra = "allow",
protected_namespaces = () protected_namespaces = (),
)
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator

View file

@@ -20,7 +20,7 @@ from functools import wraps
import datetime, time import datetime, time
import tiktoken import tiktoken
import uuid import uuid
from pydantic import BaseModel from pydantic import ConfigDict, BaseModel
import aiohttp import aiohttp
import textwrap import textwrap
import logging import logging
@@ -328,10 +328,7 @@ class HiddenParams(OpenAIObject):
original_response: Optional[str] = None original_response: Optional[str] = None
model_id: Optional[str] = None # used in Router for individual deployments model_id: Optional[str] = None # used in Router for individual deployments
api_base: Optional[str] = None # returns api base used for making completion call api_base: Optional[str] = None # returns api base used for making completion call
model_config = ConfigDict(extra="allow", protected_namespaces=())
class Config:
extra = "allow"
protected_namespaces = ()
def get(self, key, default=None): def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist # Custom .get() method to access attributes with a default value if the attribute doesn't exist