fix types/router.py

This commit is contained in:
Ishaan Jaff 2024-05-15 19:46:30 -07:00
parent 86223bc703
commit 1a67f244fb

View file

@ -1,42 +1,19 @@
from typing import List, Optional, Union, Dict, Tuple, Literal, TypedDict
import httpx
from pydantic import (
ConfigDict,
BaseModel,
validator,
Field,
__version__ as pydantic_version,
VERSION,
)
from pydantic import BaseModel, validator, Field
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
import uuid, enum
# Function to get Pydantic version
def is_pydantic_v2() -> int:
return int(VERSION.split(".")[0])
def get_model_config(arbitrary_types_allowed: bool = False) -> ConfigDict:
# Version-specific configuration
if is_pydantic_v2() >= 2:
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=arbitrary_types_allowed, protected_namespaces=()) # type: ignore
else:
from pydantic import Extra
model_config = ConfigDict(extra=Extra.allow, arbitrary_types_allowed=arbitrary_types_allowed) # type: ignore
return model_config
class ModelConfig(BaseModel):
model_name: str
litellm_params: Union[CompletionRequest, EmbeddingRequest]
tpm: int
rpm: int
model_config = get_model_config()
class Config:
protected_namespaces = ()
class RouterConfig(BaseModel):
@ -67,7 +44,8 @@ class RouterConfig(BaseModel):
"latency-based-routing",
] = "simple-shuffle"
model_config = get_model_config()
class Config:
protected_namespaces = ()
class UpdateRouterConfig(BaseModel):
@ -87,7 +65,8 @@ class UpdateRouterConfig(BaseModel):
fallbacks: Optional[List[dict]] = None
context_window_fallbacks: Optional[List[dict]] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class ModelInfo(BaseModel):
@ -105,7 +84,8 @@ class ModelInfo(BaseModel):
id = str(id)
super().__init__(id=id, **params)
model_config = get_model_config()
class Config:
extra = "allow"
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -200,15 +180,9 @@ class GenericLiteLLMParams(BaseModel):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = get_model_config(arbitrary_types_allowed=True)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config:
arbitrary_types_allowed = True
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -267,16 +241,9 @@ class LiteLLM_Params(GenericLiteLLMParams):
max_retries = int(max_retries) # cast to int
super().__init__(max_retries=max_retries, **args, **params)
model_config = get_model_config(arbitrary_types_allowed=True)
if pydantic_version.startswith("1"):
# pydantic v2 warns about using a Config class.
# But without this, pydantic v1 will raise an error:
# RuntimeError: no validator found for <class 'openai.Timeout'>,
# see `arbitrary_types_allowed` in Config
# Putting arbitrary_types_allowed = True in the ConfigDict doesn't work in pydantic v1.
class Config:
arbitrary_types_allowed = True
class Config:
extra = "allow"
arbitrary_types_allowed = True
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -306,7 +273,8 @@ class updateDeployment(BaseModel):
litellm_params: Optional[updateLiteLLMParams] = None
model_info: Optional[ModelInfo] = None
model_config = get_model_config()
class Config:
protected_namespaces = ()
class LiteLLMParamsTypedDict(TypedDict, total=False):
@ -380,7 +348,9 @@ class Deployment(BaseModel):
# if using pydantic v1
return self.dict(**kwargs)
model_config = get_model_config()
class Config:
extra = "allow"
protected_namespaces = ()
def __contains__(self, key):
# Define custom behavior for the 'in' operator