diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index d8cf83f84..6df6b4fe4 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Extra, Field, root_validator, Json, validator +from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict from dataclasses import fields import enum from typing import Optional, List, Union, Dict, Literal, Any @@ -119,8 +119,7 @@ class LiteLLMBase(BaseModel): # if using pydantic v1 return self.__fields_set__ - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase): @@ -349,7 +348,8 @@ class LiteLLMPromptInjectionParams(LiteLLMBase): description="Return rejected request error message as a string to the user. Default behaviour is to raise an exception.", ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_llm_api_params(cls, values): llm_api_check = values.get("llm_api_check") if llm_api_check is True: @@ -407,8 +407,7 @@ class ProxyChatCompletionRequest(LiteLLMBase): deployment_id: Optional[str] = None request_timeout: Optional[int] = None - class Config: - extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) + model_config = ConfigDict(extra="allow") # allow params not defined here, these fall in litellm.completion(**kwargs) class ModelInfoDelete(LiteLLMBase): @@ -435,11 +434,10 @@ class ModelInfo(LiteLLMBase): ] ] - class Config: - extra = Extra.allow # Allow extra fields - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=(), extra="allow") - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_model_info(cls, values): if values.get("id") is None: values.update({"id": str(uuid.uuid4())}) @@ -470,10 +468,10 @@ class ModelParams(LiteLLMBase): litellm_params: dict model_info: ModelInfo - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_model_info(cls, values): if values.get("model_info") is None: values.update({"model_info": ModelInfo()}) @@ -509,8 +507,7 @@ class GenerateKeyRequest(GenerateRequestBase): {} ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class GenerateKeyResponse(GenerateKeyRequest): @@ -520,7 +517,8 @@ class GenerateKeyResponse(GenerateKeyRequest): user_id: Optional[str] = None token_id: Optional[str] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_model_info(cls, values): if values.get("token") is not None: values.update({"key": values.get("token")}) @@ -560,8 +558,7 @@ class LiteLLM_ModelTable(LiteLLMBase): created_by: str updated_by: str - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class NewUserRequest(GenerateKeyRequest): @@ -607,7 +604,8 @@ class UpdateUserRequest(GenerateRequestBase): ] = None max_budget: Optional[float] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_user_info(cls, values): if values.get("user_id") is None and values.get("user_email") is None: raise ValueError("Either user id or user email must be provided") @@ -631,7 +629,8 @@ class NewCustomerRequest(LiteLLMBase): None # if no equivalent model in allowed region - default all requests to this model ) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_user_info(cls, values): if values.get("max_budget") is not None and values.get("budget_id") is not None: raise ValueError("Set either 'max_budget' or 'budget_id', not both.") @@ -671,7 +670,8 @@ class Member(LiteLLMBase): user_id: Optional[str] = None user_email: Optional[str] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_user_info(cls, values): if values.get("user_id") is None and values.get("user_email") is None: raise ValueError("Either user id or user email must be provided") @@ -700,8 +700,7 @@ class TeamBase(LiteLLMBase): class NewTeamRequest(TeamBase): model_aliases: Optional[dict] = None - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class GlobalEndUsersSpend(LiteLLMBase): @@ -721,7 +720,8 @@ class TeamMemberDeleteRequest(LiteLLMBase): user_id: Optional[str] = None user_email: Optional[str] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_user_info(cls, values): if values.get("user_id") is None and values.get("user_email") is None: raise ValueError("Either user id or user email must be provided") @@ -787,10 +787,10 @@ class LiteLLM_TeamTable(TeamBase): budget_reset_at: Optional[datetime] = None model_id: Optional[int] = None - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_model_info(cls, values): dict_fields = [ "metadata", @@ -826,8 +826,7 @@ class LiteLLM_BudgetTable(LiteLLMBase): model_max_budget: Optional[dict] = None budget_duration: Optional[str] = None - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): @@ -840,8 +839,7 @@ class LiteLLM_TeamMemberTable(LiteLLM_BudgetTable): team_id: Optional[str] = None budget_id: Optional[str] = None - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class NewOrganizationRequest(LiteLLM_BudgetTable): @@ -920,8 +918,7 @@ class KeyManagementSettings(LiteLLMBase): class TeamDefaultSettings(LiteLLMBase): team_id: str - class Config: - extra = "allow" # allow params not defined here, these fall in litellm.completion(**kwargs) + model_config = ConfigDict(extra="allow") # allow params not defined here, these fall in litellm.completion(**kwargs) class DynamoDBArgs(LiteLLMBase): @@ -1083,8 +1080,7 @@ class ConfigYAML(LiteLLMBase): description="litellm router object settings. See router.py __init__ for all, example router.num_retries=5, router.timeout=5, router.max_retries=5, router.retry_after=5", ) - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class LiteLLM_VerificationToken(LiteLLMBase): @@ -1114,9 +1110,7 @@ class LiteLLM_VerificationToken(LiteLLMBase): org_id: Optional[str] = None # org id for a given key - class Config: - protected_namespaces = () - + model_config = ConfigDict(protected_namespaces=()) class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): """ @@ -1161,7 +1155,8 @@ class UserAPIKeyAuth( ] = None allowed_model_region: Optional[Literal["eu"]] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def check_api_key(cls, values): if values.get("api_key") is not None: values.update({"token": hash_token(values.get("api_key"))}) @@ -1188,7 +1183,8 @@ class LiteLLM_UserTable(LiteLLMBase): tpm_limit: Optional[int] = None rpm_limit: Optional[int] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_model_info(cls, values): if values.get("spend") is None: values.update({"spend": 0.0}) @@ -1196,8 +1192,7 @@ class LiteLLM_UserTable(LiteLLMBase): values.update({"models": []}) return values - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class LiteLLM_EndUserTable(LiteLLMBase): @@ -1209,14 +1204,14 @@ class LiteLLM_EndUserTable(LiteLLMBase): default_model: Optional[str] = None litellm_budget_table: Optional[LiteLLM_BudgetTable] = None - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_model_info(cls, values): if values.get("spend") is None: values.update({"spend": 0.0}) return values - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class LiteLLM_SpendLogs(LiteLLMBase): diff --git a/litellm/tests/test_config.py b/litellm/tests/test_config.py index 47f632b96..b62d20422 100644 --- a/litellm/tests/test_config.py +++ b/litellm/tests/test_config.py @@ -13,7 +13,7 @@ sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the, system path import pytest, litellm -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from litellm.proxy.proxy_server import ProxyConfig from litellm.proxy.utils import encrypt_value, ProxyLogging, DualCache from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo @@ -26,8 +26,7 @@ class DBModel(BaseModel): model_info: dict litellm_params: dict - class Config: - protected_namespaces = () + config_dict: ConfigDict = ConfigDict(protected_namespaces=()) @pytest.mark.asyncio diff --git a/litellm/types/completion.py b/litellm/types/completion.py index 78af7667b..c8ddc7449 100644 --- a/litellm/types/completion.py +++ b/litellm/types/completion.py @@ -1,6 +1,6 @@ from typing import List, Optional, Union, Iterable -from pydantic import BaseModel, validator +from pydantic import BaseModel, ConfigDict, validator from typing_extensions import Literal, Required, TypedDict @@ -191,6 +191,4 @@ class CompletionRequest(BaseModel): api_key: Optional[str] = None model_list: Optional[List[str]] = None - class Config: - extra = "allow" - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=(), extra="allow") diff --git a/litellm/types/embedding.py b/litellm/types/embedding.py index 9db0ef290..f8fdebc53 100644 --- a/litellm/types/embedding.py +++ b/litellm/types/embedding.py @@ -1,6 +1,6 @@ from typing import List, Optional, Union -from pydantic import BaseModel, validator +from pydantic import BaseModel, ConfigDict class EmbeddingRequest(BaseModel): @@ -18,6 +18,4 @@ class EmbeddingRequest(BaseModel): litellm_logging_obj: Optional[dict] = None logger_fn: Optional[str] = None - class Config: - # allow kwargs - extra = "allow" + model_config = ConfigDict(extra="allow") diff --git a/litellm/types/router.py b/litellm/types/router.py index 75e792f4c..a35e7a77d 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -1,12 +1,12 @@ """ - litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc +litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc """ from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict import uuid import enum import httpx -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field import datetime from .completion import CompletionRequest from .embedding import EmbeddingRequest @@ -18,8 +18,7 @@ class ModelConfig(BaseModel): tpm: int rpm: int - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class RouterConfig(BaseModel): @@ -50,8 +49,7 @@ class RouterConfig(BaseModel): "latency-based-routing", ] = "simple-shuffle" - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class UpdateRouterConfig(BaseModel): @@ -71,17 +69,14 @@ class UpdateRouterConfig(BaseModel): fallbacks: Optional[List[dict]] = None context_window_fallbacks: Optional[List[dict]] = None - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class ModelInfo(BaseModel): id: Optional[ str ] # Allow id to be optional on input, but it will always be present as a str in the model instance - db_model: bool = ( - False # used for proxy - to separate models which are stored in the db vs. config. - ) + db_model: bool = False # used for proxy - to separate models which are stored in the db vs. config. updated_at: Optional[datetime.datetime] = None updated_by: Optional[str] = None @@ -99,8 +94,7 @@ class ModelInfo(BaseModel): id = str(id) super().__init__(id=id, **params) - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") def __contains__(self, key): # Define custom behavior for the 'in' operator @@ -155,6 +149,8 @@ class GenericLiteLLMParams(BaseModel): input_cost_per_second: Optional[float] = None output_cost_per_second: Optional[float] = None + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + def __init__( self, custom_llm_provider: Optional[str] = None, @@ -184,7 +180,7 @@ class GenericLiteLLMParams(BaseModel): output_cost_per_token: Optional[float] = None, input_cost_per_second: Optional[float] = None, output_cost_per_second: Optional[float] = None, - **params + **params, ): args = locals() args.pop("max_retries", None) @@ -195,10 +191,6 @@ class GenericLiteLLMParams(BaseModel): max_retries = int(max_retries) # cast to int super().__init__(max_retries=max_retries, **args, **params) - class Config: - extra = "allow" - arbitrary_types_allowed = True - def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -222,6 +214,7 @@ class LiteLLM_Params(GenericLiteLLMParams): """ model: str + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) def __init__( self, @@ -245,7 +238,7 @@ class LiteLLM_Params(GenericLiteLLMParams): aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_region_name: Optional[str] = None, - **params + **params, ): args = locals() args.pop("max_retries", None) @@ -256,10 +249,6 @@ class LiteLLM_Params(GenericLiteLLMParams): max_retries = int(max_retries) # cast to int super().__init__(max_retries=max_retries, **args, **params) - class Config: - extra = "allow" - arbitrary_types_allowed = True - def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) @@ -288,8 +277,7 @@ class updateDeployment(BaseModel): litellm_params: Optional[updateLiteLLMParams] = None model_info: Optional[ModelInfo] = None - class Config: - protected_namespaces = () + model_config = ConfigDict(protected_namespaces=()) class LiteLLMParamsTypedDict(TypedDict, total=False): @@ -338,12 +326,14 @@ class Deployment(BaseModel): litellm_params: LiteLLM_Params model_info: ModelInfo + model_config = ConfigDict(extra="allow", protected_namespaces=()) + def __init__( self, model_name: str, litellm_params: LiteLLM_Params, model_info: Optional[Union[ModelInfo, dict]] = None, - **params + **params, ): if model_info is None: model_info = ModelInfo() @@ -353,7 +343,7 @@ class Deployment(BaseModel): model_info=model_info, model_name=model_name, litellm_params=litellm_params, - **params + **params, ) def to_json(self, **kwargs): @@ -363,10 +353,6 @@ class Deployment(BaseModel): # if using pydantic v1 return self.dict(**kwargs) - class Config: - extra = "allow" - protected_namespaces = () - def __contains__(self, key): # Define custom behavior for the 'in' operator return hasattr(self, key) diff --git a/litellm/utils.py b/litellm/utils.py index 9d2fcaec2..3e9fdccd9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -18,7 +18,7 @@ from functools import wraps, lru_cache import datetime, time import tiktoken import uuid -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict import aiohttp import textwrap import logging @@ -337,9 +337,7 @@ class HiddenParams(OpenAIObject): model_id: Optional[str] = None # used in Router for individual deployments api_base: Optional[str] = None # returns api base used for making completion call - class Config: - extra = "allow" - protected_namespaces = () + model_config = ConfigDict(extra="allow", protected_namespaces=()) def get(self, key, default=None): # Custom .get() method to access attributes with a default value if the attribute doesn't exist diff --git a/pyproject.toml b/pyproject.toml index 49ca81db6..b970eb9fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,3 +84,5 @@ version_files = [ "pyproject.toml:^version" ] +[tool.mypy] +plugins = "pydantic.mypy"