Merge pull request #3670 from lj-wego/fix-pydantic-warnings-again

Fix warnings from pydantic
This commit is contained in:
Krish Dholakia 2024-05-30 22:57:21 -07:00 committed by GitHub
commit deb87f71e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 67 additions and 91 deletions

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any
@ -119,8 +119,7 @@ class LiteLLMBase(BaseModel):
# if using pydantic v1
return self.__fields_set__
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
@ -349,7 +348,8 @@ class LiteLLMPromptInjectionParams(LiteLLMBase):
description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.",
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_llm_api_params(cls, values):
llm_api_check = values.get("llm_api_check")
if llm_api_check is True:
@ -407,8 +407,7 @@ class ProxyChatCompletionRequest(LiteLLMBase):
deployment_id: Optional[str] = None
request_timeout: Optional[int] = None
class Config:
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
model_config = ConfigDict(extra="allow") # allow params not defined here, these fall in litellm.completion(**kwargs)
class ModelInfoDelete(LiteLLMBase):
@ -435,11 +434,10 @@ class ModelInfo(LiteLLMBase):
]
]
class Config:
extra = Extra.allow # Allow extra fields
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=(), extra="allow")
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("id") is None:
values.update({"id": str(uuid.uuid4())})
@ -470,10 +468,10 @@ class ModelParams(LiteLLMBase):
litellm_params: dict
model_info: ModelInfo
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("model_info") is None:
values.update({"model_info": ModelInfo()})
@ -509,8 +507,7 @@ class GenerateKeyRequest(GenerateRequestBase):
{}
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class GenerateKeyResponse(GenerateKeyRequest):
@ -520,7 +517,8 @@ class GenerateKeyResponse(GenerateKeyRequest):
user_id: Optional[str] = None
token_id: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("token") is not None:
values.update({"key": values.get("token")})
@ -560,8 +558,7 @@ class LiteLLM_ModelTable(LiteLLMBase):
created_by: str
updated_by: str
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class NewUserRequest(GenerateKeyRequest):
@ -607,7 +604,8 @@ class UpdateUserRequest(GenerateRequestBase):
] = None
max_budget: Optional[float] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -631,7 +629,8 @@ class NewCustomerRequest(LiteLLMBase):
None # if no equivalent model in allowed region - default all requests to this model
)
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_user_info(cls, values):
if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
@ -671,7 +670,8 @@ class Member(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -700,8 +700,7 @@ class TeamBase(LiteLLMBase):
class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class GlobalEndUsersSpend(LiteLLMBase):
@ -721,7 +720,8 @@ class TeamMemberDeleteRequest(LiteLLMBase):
user_id: Optional[str] = None
user_email: Optional[str] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_user_info(cls, values):
if values.get("user_id") is None and values.get("user_email") is None:
raise ValueError("Either user id or user email must be provided")
@ -787,10 +787,10 @@ class LiteLLM_TeamTable(TeamBase):
budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
dict_fields = [
"metadata",
@ -826,8 +826,7 @@ class LiteLLM_BudgetTable(LiteLLMBase):
model_max_budget: Optional[dict] = None
budget_duration: Optional[str] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
@ -840,8 +839,7 @@ class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable):
team_id: Optional[str] = None
budget_id: Optional[str] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class NewOrganizationRequest(LiteLLM_BudgetTable):
@ -920,8 +918,7 @@ class KeyManagementSettings(LiteLLMBase):
class TeamDefaultSettings(LiteLLMBase):
team_id: str
class Config:
extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs)
model_config = ConfigDict(extra="allow") # allow params not defined here, these fall in litellm.completion(**kwargs)
class DynamoDBArgs(LiteLLMBase):
@ -1083,8 +1080,7 @@ class ConfigYAML(LiteLLMBase):
description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5",
)
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_VerificationToken(LiteLLMBase):
@ -1114,9 +1110,7 @@ class LiteLLM_VerificationToken(LiteLLMBase):
org_id: Optional[str] = None # org id for a given key
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
"""
@ -1161,7 +1155,8 @@ class UserAPIKeyAuth(
] = None
allowed_model_region: Optional[Literal["eu"]] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_api_key(cls, values):
if values.get("api_key") is not None:
values.update({"token": hash_token(values.get("api_key"))})
@ -1188,7 +1183,8 @@ class LiteLLM_UserTable(LiteLLMBase):
tpm_limit: Optional[int] = None
rpm_limit: Optional[int] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
@ -1196,8 +1192,7 @@ class LiteLLM_UserTable(LiteLLMBase):
values.update({"models": []})
return values
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_EndUserTable(LiteLLMBase):
@ -1209,14 +1204,14 @@ class LiteLLM_EndUserTable(LiteLLMBase):
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_model_info(cls, values):
if values.get("spend") is None:
values.update({"spend": 0.0})
return values
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLM_SpendLogs(LiteLLMBase):

View file

@ -13,7 +13,7 @@ sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the, system path
import pytest, litellm
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from litellm.proxy.proxy_server import ProxyConfig
from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
@ -26,8 +26,7 @@ class DBModel(BaseModel):
model_info: dict
litellm_params: dict
class Config:
protected_namespaces = ()
config_dict: ConfigDict = ConfigDict(protected_namespaces=())
@pytest.mark.asyncio

View file

@ -1,6 +1,6 @@
from typing import List, Optional, Union, Iterable
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, validator
from typing_extensions import Literal, Required, TypedDict
@ -191,6 +191,4 @@ class CompletionRequest(BaseModel):
api_key: Optional[str] = None
model_list: Optional[List[str]] = None
class Config:
extra = "allow"
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=(), extra="allow")

View file

@ -1,6 +1,6 @@
from typing import List, Optional, Union
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict
class EmbeddingRequest(BaseModel):
@ -18,6 +18,4 @@ class EmbeddingRequest(BaseModel):
litellm_logging_obj: Optional[dict] = None
logger_fn: Optional[str] = None
class Config:
# allow kwargs
extra = "allow"
model_config = ConfigDict(extra="allow")

View file

@ -1,12 +1,12 @@
"""
litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
"""
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import uuid
import enum
import httpx
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
import datetime
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
@ -18,8 +18,7 @@ class ModelConfig(BaseModel):
tpm: int
rpm: int
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class RouterConfig(BaseModel):
@ -50,8 +49,7 @@ class RouterConfig(BaseModel):
"latency-based-routing",
] = "simple-shuffle"
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class UpdateRouterConfig(BaseModel):
@ -71,17 +69,14 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class ModelInfo(BaseModel):
id: Optional[
str
] # Allow id to be optional on input, but it will always be present as a str in the model instance
db_model: bool = (
False # used for proxy - to separate models which are stored in the db vs. config.
)
db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config.
updated_at: Optional[datetime.datetime] = None
updated_by: Optional[str] = None
@ -99,8 +94,7 @@ class ModelInfo(BaseModel):
id = str(id)
super().__init__(id=id, **params)
class Config:
extra = "allow"
model_config = ConfigDict(extra="allow")
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -155,6 +149,8 @@ class GenericLiteLLMParams(BaseModel):
input_cost_per_second: Optional[float] = None
output_cost_per_second: Optional[float] = None
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
def __init__(
self,
custom_llm_provider: Optional[str] = None,
@ -184,7 +180,7 @@ class GenericLiteLLMParams(BaseModel):
output_cost_per_token: Optional[float] = None,
input_cost_per_second: Optional[float] = None,
output_cost_per_second: Optional[float] = None,
**params
**params,
):
args = locals()
args.pop("max_retries", None)
@ -195,10 +191,6 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
@ -222,6 +214,7 @@ class LiteLLM_Params(GenericLiteLLMParams):
"""
model: str
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
def __init__(
self,
@ -245,7 +238,7 @@ class LiteLLM_Params(GenericLiteLLMParams):
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
**params
**params,
):
args = locals()
args.pop("max_retries", None)
@ -256,10 +249,6 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)
@ -288,8 +277,7 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None
class Config:
protected_namespaces = ()
model_config = ConfigDict(protected_namespaces=())
class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -338,12 +326,14 @@ class Deployment(BaseModel):
litellm_params: LiteLLM_Params
model_info: ModelInfo
model_config = ConfigDict(extra="allow", protected_namespaces=())
def __init__(
self,
model_name: str,
litellm_params: LiteLLM_Params,
model_info: Optional[Union[ModelInfo, dict]] = None,
**params
**params,
):
if model_info is None:
model_info = ModelInfo()
@ -353,7 +343,7 @@ class Deployment(BaseModel):
model_info=model_info,
model_name=model_name,
litellm_params=litellm_params,
**params
**params,
)
def to_json(self, **kwargs):
@ -363,10 +353,6 @@ class Deployment(BaseModel):
# if using pydantic v1
return self.dict(**kwargs)
class Config:
extra = "allow"
protected_namespaces = ()
def __contains__(self, key):
# Define custom behavior for the 'in' operator
return hasattr(self, key)

View file

@ -18,7 +18,7 @@ from functools import wraps, lru_cache
import datetime, time
import tiktoken
import uuid
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
import aiohttp
import textwrap
import logging
@ -337,9 +337,7 @@ class HiddenParams(OpenAIObject):
model_id: Optional[str] = None # used in Router for individual deployments
api_base: Optional[str] = None # returns api base used for making completion call
class Config:
extra = "allow"
protected_namespaces = ()
model_config = ConfigDict(extra="allow", protected_namespaces=())
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist

View file

@ -84,3 +84,5 @@ version_files = [
"pyproject.toml:^version"
]
[tool.mypy]
plugins = "pydantic.mypy"