Merge pull request #3670 from lj-wego/fix-pydantic-warnings-again

Fix warnings from pydantic
This commit is contained in:
Krish Dholakia 2024-05-30 22:57:21 -07:00 committed by GitHub
commit deb87f71e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 67 additions and 91 deletions

View file

@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
+from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
 from dataclasses import fields
 import enum
 from typing import Optional, List, Union, Dict, Literal, Any
@@ -119,8 +119,7 @@ class LiteLLMBase(BaseModel):
             # if using pydantic v1
             return self.__fields_set__

-    class Config:
-        protected_namespaces = ()
+    model_config = ConfigDict(protected_namespaces=())


 class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
@@ -349,7 +348,8 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
         description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
     )

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def check_llm_api_params(cls, values):
         llm_api_check = values.get("llm_api_check")
         if llm_api_check is True:
@@ -407,8 +407,7 @@ class ProxyChatCompletionRequest(LiteLLMBase):
     deployment_id: Optional[str] = None
     request_timeout: Optional[int] = None

-    class Config:
-        extra = "allow"  # allow params not defined here, these fall in litellm.completion(**kwargs)
+    model_config = ConfigDict(extra="allow")  # allow params not defined here, these fall in litellm.completion(**kwargs)


 class ModelInfoDelete(LiteLLMBase):
@@ -435,11 +434,10 @@ class ModelInfo(LiteLLMBase):
         ]
     ]

-    class Config:
-        extra = Extra.allow  # Allow extra fields
-        protected_namespaces = ()
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")

-    @root_validator(pre=True)
+    @model_validator(mode="before")
+    @classmethod
     def set_model_info(cls, values):
         if values.get("id") is None:
             values.update({"id": str(uuid.uuid4())})
@ -470,10 +468,10 @@ class ModelParams(LiteLLMBase):
litellm_params: dict litellm_params: dict
model_info: ModelInfo model_info: ModelInfo
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("model_info") is None: if values.get("model_info") is None:
values.update({"model_info": ModelInfo()}) values.update({"model_info": ModelInfo()})
@ -509,8 +507,7 @@ class GenerateKeyRequest(GenerateRequestBase):
{} {}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class GenerateKeyResponse(GenerateKeyRequest): class GenerateKeyResponse(GenerateKeyRequest):
@ -520,7 +517,8 @@ class GenerateKeyResponse(GenerateKeyRequest):
user_id: Optional[str] = None user_id: Optional[str] = None
token_id: Optional[str] = None token_id: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("token") is not None: if values.get("token") is not None:
values.update({"key": values.get("token")}) values.update({"key": values.get("token")})
@ -560,8 +558,7 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str created_by: str
updated_by: str updated_by: str
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
@ -607,7 +604,8 @@ class UpdateUserRequest(GenerateRequestBase):
] = None ] = None
max_budget: Optional[float] = None max_budget: Optional[float] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -631,7 +629,8 @@ class NewCustomerRequest(LiteLLMBase):
None # if no equivalent model in allowed region - default all requests to this model None # if no equivalent model in allowed region - default all requests to this model
) )
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("max_budget") is not None and values.get("budget_id") is not None: if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.") raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
@ -671,7 +670,8 @@ class Member(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -700,8 +700,7 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase): class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None model_aliases: Optional[dict] = None
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class GlobalEndUsersSpend(LiteLLMBase): class GlobalEndUsersSpend(LiteLLMBase):
@ -721,7 +720,8 @@ class TeamMemberDeleteRequest(LiteLLMBase):
user_id: Optional[str] = None user_id: Optional[str] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def check_user_info(cls, values): def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None: if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided") raise ValueError("Either user id or user email must be provided")
@ -787,10 +787,10 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None model_id: Optional[int] = None
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def set_model_info(cls, values): def set_model_info(cls, values):
dict_fields = [ dict_fields = [
"metadata", "metadata",
@ -826,8 +826,7 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
@ -840,8 +839,7 @@ class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
team_id: Optional[str] = None team_id: Optional[str] = None
budget_id: Optional[str] = None budget_id: Optional[str] = None
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class NewOrganizationRequest(LiteLLM_BudgetTable): class NewOrganizationRequest(LiteLLM_BudgetTable):
@ -920,8 +918,7 @@ class KeyManagementSettings(LiteLLMBase):
class TeamDefaultSettings(LiteLLMBase): class TeamDefaultSettings(LiteLLMBase):
team_id: str team_id: str
class Config: model_config = ConfigDict(extra="allow") # allow params not defined here, these fall in litellm.completion(**kwargs)
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
class DynamoDBArgs(LiteLLMBase): class DynamoDBArgs(LiteLLMBase):
@ -1083,8 +1080,7 @@ class ConfigYAML(LiteLLMBase):
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5", description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
) )
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class LiteLLM_VerificationToken(LiteLLMBase): class LiteLLM_VerificationToken(LiteLLMBase):
@ -1114,9 +1110,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
org_id: Optional[str] = None # org id for a given key org_id: Optional[str] = None # org id for a given key
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
""" """
@ -1161,7 +1155,8 @@ class UserAPIKeyAuth(
] = None ] = None
allowed_model_region: Optional[Literal["eu"]] = None allowed_model_region: Optional[Literal["eu"]] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def check_api_key(cls, values): def check_api_key(cls, values):
if values.get("api_key") is not None: if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))}) values.update({"token": hash_token(values.get("api_key"))})
@ -1188,7 +1183,8 @@ class LiteLLM_UserTable(LiteLLMBase):
tpm_limit: Optional[int] = None tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None rpm_limit: Optional[int] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
@ -1196,8 +1192,7 @@ class LiteLLM_UserTable(LiteLLMBase):
values.update({"models": []}) values.update({"models": []})
return values return values
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class LiteLLM_EndUserTable(LiteLLMBase): class LiteLLM_EndUserTable(LiteLLMBase):
@ -1209,14 +1204,14 @@ class LiteLLM_EndUserTable(LiteLLMBase):
default_model: Optional[str] = None default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@root_validator(pre=True) @model_validator(mode="before")
@classmethod
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
return values return values
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class LiteLLM_SpendLogs(LiteLLMBase): class LiteLLM_SpendLogs(LiteLLMBase):

View file

@@ -13,7 +13,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the, system path
 import pytest, litellm
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from litellm.proxy.proxy_server import ProxyConfig
 from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
 from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
@@ -26,8 +26,7 @@ class DBModel(BaseModel):
     model_info: dict
     litellm_params: dict

-    class Config:
-        protected_namespaces = ()
+    config_dict: ConfigDict = ConfigDict(protected_namespaces=())


 @pytest.mark.asyncio

View file

@@ -1,6 +1,6 @@
 from typing import List, Optional, Union, Iterable

-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict, validator

 from typing_extensions import Literal, Required, TypedDict
@@ -191,6 +191,4 @@ class CompletionRequest(BaseModel):
     api_key: Optional[str] = None
     model_list: Optional[List[str]] = None

-    class Config:
-        extra = "allow"
-        protected_namespaces = ()
+    model_config = ConfigDict(protected_namespaces=(), extra="allow")

View file

@@ -1,6 +1,6 @@
 from typing import List, Optional, Union

-from pydantic import BaseModel, validator
+from pydantic import BaseModel, ConfigDict


 class EmbeddingRequest(BaseModel):
@@ -18,6 +18,4 @@ class EmbeddingRequest(BaseModel):
     litellm_logging_obj: Optional[dict] = None
     logger_fn: Optional[str] = None

-    class Config:
-        # allow kwargs
-        extra = "allow"
+    model_config = ConfigDict(extra="allow")

View file

@@ -1,12 +1,12 @@
 """
 litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
 """

 from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
 import uuid
 import enum
 import httpx
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 import datetime
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
@@ -18,8 +18,7 @@ class ModelConfig(BaseModel):
     tpm: int
     rpm: int

-    class Config:
-        protected_namespaces = ()
+    model_config = ConfigDict(protected_namespaces=())


 class RouterConfig(BaseModel):
@ -50,8 +49,7 @@ class RouterConfig(BaseModel):
"latency-based-routing", "latency-based-routing",
] = "simple-shuffle" ] = "simple-shuffle"
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class UpdateRouterConfig(BaseModel): class UpdateRouterConfig(BaseModel):
@ -71,17 +69,14 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None context_window_fallbacks: Optional[List[dict]] = None
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
id: Optional[ id: Optional[
str str
] # Allow id to be optional on input, but it will always be present as a str in the model instance ] # Allow id to be optional on input, but it will always be present as a str in the model instance
db_model: bool = ( db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config.
False # used for proxy - to separate models which are stored in the db vs. config.
)
updated_at: Optional[datetime.datetime] = None updated_at: Optional[datetime.datetime] = None
updated_by: Optional[str] = None updated_by: Optional[str] = None
@ -99,8 +94,7 @@ class ModelInfo(BaseModel):
id = str(id) id = str(id)
super().__init__(id=id, **params) super().__init__(id=id, **params)
class Config: model_config = ConfigDict(extra="allow")
extra = "allow"
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator
@ -155,6 +149,8 @@ class GenericLiteLLMParams(BaseModel):
input_cost_per_second: Optional[float] = None input_cost_per_second: Optional[float] = None
output_cost_per_second: Optional[float] = None output_cost_per_second: Optional[float] = None
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
def __init__( def __init__(
self, self,
custom_llm_provider: Optional[str] = None, custom_llm_provider: Optional[str] = None,
@ -184,7 +180,7 @@ class GenericLiteLLMParams(BaseModel):
output_cost_per_token: Optional[float] = None, output_cost_per_token: Optional[float] = None,
input_cost_per_second: Optional[float] = None, input_cost_per_second: Optional[float] = None,
output_cost_per_second: Optional[float] = None, output_cost_per_second: Optional[float] = None,
**params **params,
): ):
args = locals() args = locals()
args.pop("max_retries", None) args.pop("max_retries", None)
@ -195,10 +191,6 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params) super().__init__(max_retries=max_retries, **args, **params)
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator
return hasattr(self, key) return hasattr(self, key)
@ -222,6 +214,7 @@ class LiteLLM_Params(GenericLiteLLMParams):
""" """
model: str model: str
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
def __init__( def __init__(
self, self,
@ -245,7 +238,7 @@ class LiteLLM_Params(GenericLiteLLMParams):
aws_access_key_id: Optional[str] = None, aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None, aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None, aws_region_name: Optional[str] = None,
**params **params,
): ):
args = locals() args = locals()
args.pop("max_retries", None) args.pop("max_retries", None)
@ -256,10 +249,6 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params) super().__init__(max_retries=max_retries, **args, **params)
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator
return hasattr(self, key) return hasattr(self, key)
@ -288,8 +277,7 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None model_info: Optional[ModelInfo] = None
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class LiteLLMParamsTypedDict(TypedDict, total=False): class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -338,12 +326,14 @@ class Deployment(BaseModel):
litellm_params: LiteLLM_Params litellm_params: LiteLLM_Params
model_info: ModelInfo model_info: ModelInfo
model_config = ConfigDict(extra="allow", protected_namespaces=())
def __init__( def __init__(
self, self,
model_name: str, model_name: str,
litellm_params: LiteLLM_Params, litellm_params: LiteLLM_Params,
model_info: Optional[Union[ModelInfo, dict]] = None, model_info: Optional[Union[ModelInfo, dict]] = None,
**params **params,
): ):
if model_info is None: if model_info is None:
model_info = ModelInfo() model_info = ModelInfo()
@ -353,7 +343,7 @@ class Deployment(BaseModel):
model_info=model_info, model_info=model_info,
model_name=model_name, model_name=model_name,
litellm_params=litellm_params, litellm_params=litellm_params,
**params **params,
) )
def to_json(self, **kwargs): def to_json(self, **kwargs):
@ -363,10 +353,6 @@ class Deployment(BaseModel):
# if using pydantic v1 # if using pydantic v1
return self.dict(**kwargs) return self.dict(**kwargs)
class Config:
extra = "allow"
protected_namespaces = ()
def __contains__(self, key): def __contains__(self, key):
# Define custom behavior for the 'in' operator # Define custom behavior for the 'in' operator
return hasattr(self, key) return hasattr(self, key)

View file

@@ -18,7 +18,7 @@ from functools import wraps, lru_cache
 import datetime, time
 import tiktoken
 import uuid
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 import aiohttp
 import textwrap
 import logging
@@ -337,9 +337,7 @@ class HiddenParams(OpenAIObject):
     model_id: Optional[str] = None  # used in Router for individual deployments
     api_base: Optional[str] = None  # returns api base used for making completion call

-    class Config:
-        extra = "allow"
-        protected_namespaces = ()
+    model_config = ConfigDict(extra="allow", protected_namespaces=())

     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist

View file

@@ -84,3 +84,5 @@ version_files = [
     "pyproject.toml:^version"
 ]

+[tool.mypy]
+plugins = "pydantic.mypy"