fix(types/router.py): fix router pydantic v1 errors

This commit is contained in:
Krrish Dholakia 2024-05-14 16:49:55 -07:00
parent 2b41f09268
commit e5a1050e8d

View file

@ -1,20 +1,42 @@
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import httpx
from pydantic import ConfigDict, BaseModel, validator, Field, __version__ as pydantic_version
from pydantic import (
ConfigDict,
BaseModel,
validator,
Field,
__version__ as pydantic_version,
VERSION,
)
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
import uuid, enum
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow) # type: ignore
return model_config
class ModelConfig(BaseModel):
model_name: str
litellm_params: Union[CompletionRequest, EmbeddingRequest]
tpm: int
rpm: int
model_config = ConfigDict(
protected_namespaces = (),
)
model_config = get_model_config()
class RouterConfig(BaseModel):
@ -45,9 +67,7 @@ class RouterConfig(BaseModel):
"latency-based-routing",
] = "simple-shuffle"
model_config = ConfigDict(
protected_namespaces = (),
)
model_config = get_model_config()
class UpdateRouterConfig(BaseModel):
@ -67,9 +87,7 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None
model_config = ConfigDict(
protected_namespaces = (),
)
model_config = get_model_config()
class ModelInfo(BaseModel):
@ -87,9 +105,7 @@ class ModelInfo(BaseModel):
id = str(id)
super().__init__(id=id, **params)
model_config = ConfigDict(
extra = "allow",
)
model_config = get_model_config()
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -184,10 +200,7 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = ConfigDict(
extra = "allow",
arbitrary_types_allowed = True,
)
model_config = get_model_config()
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
@ -254,10 +267,8 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = ConfigDict(
extra = "allow",
arbitrary_types_allowed = True,
)
model_config = get_model_config(arbitrary_types_allowed=True)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
@ -295,9 +306,7 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None
model_config = ConfigDict(
protected_namespaces = (),
)
model_config = get_model_config()
class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -371,10 +380,7 @@ class Deployment(BaseModel):
# if using pydantic v1
return self.dict(**kwargs)
model_config = ConfigDict(
extra = "allow",
protected_namespaces = (),
)
model_config = get_model_config()
def __contains__(self, key):
# Define custom behavior for the 'in' operator