diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 68709f34d..7ab0d5fe5 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions that all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except Exception:
+            # if using pydantic v1
+            return self.dict()
+
+
 ######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
+class ProxyChatCompletionRequest(LiteLLMBase):
     model: str
     messages: List[Dict[str, str]]
     temperature: Optional[float] = None
@@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
 
-class ModelInfoDelete(BaseModel):
+class ModelInfoDelete(LiteLLMBase):
     id: Optional[str]
 
-class ModelInfo(BaseModel):
+class ModelInfo(LiteLLMBase):
     id: Optional[str]
     mode: Optional[Literal['embedding', 'chat', 'completion']]
-    input_cost_per_token: Optional[float]
-    output_cost_per_token: Optional[float]
-    max_tokens: Optional[int]
+    input_cost_per_token: Optional[float] = 0.0
+    output_cost_per_token: Optional[float] = 0.0
+    max_tokens: Optional[int] = 2048 # assume 2048 if not set
     # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
     # we look up the base model in model_prices_and_context_window.json
@@ -65,38 +78,41 @@ class ModelInfo(BaseModel):
     class Config:
         extra = Extra.allow  # Allow extra fields
         protected_namespaces = ()
+
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("id") is None:
-    #         values.update({"id": str(uuid.uuid4())})
-    #     if values.get("mode") is None:
-    #         values.update({"mode": str(uuid.uuid4())})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("id") is None:
+            values.update({"id": str(uuid.uuid4())})
+        if values.get("mode") is None:
+            values.update({"mode": None})
+        if values.get("input_cost_per_token") is None:
+            values.update({"input_cost_per_token": None})
+        if values.get("output_cost_per_token") is None:
+            values.update({"output_cost_per_token": None})
+        if values.get("max_tokens") is None:
+            values.update({"max_tokens": None})
+        if values.get("base_model") is None:
+            values.update({"base_model": None})
+        return values
 
-class ModelParams(BaseModel):
+class ModelParams(LiteLLMBase):
     model_name: str
     litellm_params: dict
-    model_info: Optional[ModelInfo]=None
+    model_info: ModelInfo
 
-    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
-    #     self.model_name = model_name
-    #     self.litellm_params = litellm_params
-    #     self.model_info = model_info if model_info else ModelInfo()
-    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
-
     class Config:
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("model_info") is None:
-    #         values.update({"model_info": ModelInfo()})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("model_info") is None:
+            values.update({"model_info": ModelInfo()})
+        return values
 
-class GenerateKeyRequest(BaseModel):
+class GenerateKeyRequest(LiteLLMBase):
     duration: Optional[str] = "1h"
     models: Optional[list] = []
     aliases: Optional[dict] = {}
@@ -105,26 +121,19 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs):
-        try:
-            return self.model_dump() # noqa
-        except:
-            # if using pydantic v1
-            return self.dict()
-
-class GenerateKeyResponse(BaseModel):
+class GenerateKeyResponse(LiteLLMBase):
     key: str
     expires: datetime
     user_id: str
 
-class _DeleteKeyObject(BaseModel):
+class _DeleteKeyObject(LiteLLMBase):
     key: str
 
-class DeleteKeyRequest(BaseModel):
+class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]
 
-class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
     Return the row in the db
     """
@@ -137,7 +146,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
 
-class ConfigGeneralSettings(BaseModel):
+class ConfigGeneralSettings(LiteLLMBase):
    """
    Documents all the fields supported by `general_settings` in config.yaml
    """
@@ -153,7 +162,7 @@
    health_check_interval: int = Field(300, description="background health check interval in seconds")
 
 
-class ConfigYAML(BaseModel):
+class ConfigYAML(LiteLLMBase):
    """
    Documents all the fields supported by the config.yaml
    """
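Editor's note on the `_types.py` changes above: `LiteLLMBase` centralizes the pydantic v2 / v1 serialization fallback that `GenerateKeyRequest` previously carried alone, and `ModelInfo`'s `root_validator` backfills an `id` so every model entry is addressable. A minimal sketch of the resulting behavior (illustrative only, not part of the patch; the model names are hypothetical):

    from litellm.proxy._types import ModelInfo, ModelParams

    # id is backfilled with a uuid4 string by the pre root_validator
    info = ModelInfo(mode="chat")
    assert info.id is not None

    # json() returns a dict: model_dump() on pydantic v2, dict() on v1
    print(info.json())

    # model_info defaults to an empty ModelInfo() when omitted
    params = ModelParams(
        model_name="gpt-3.5-turbo",                # hypothetical values
        litellm_params={"model": "gpt-3.5-turbo"},
    )
    assert isinstance(params.model_info, ModelInfo)
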
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 259bf23f2..259ead7c9 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1207,10 +1207,12 @@ async def add_new_model(model_params: ModelParams):
         print_verbose(f"Loaded config: {config}")
 
         # Add the new model to the config
+        model_info = model_params.model_info.json()
+        model_info = {k: v for k, v in model_info.items() if v is not None}
         config['model_list'].append({
             'model_name': model_params.model_name,
             'litellm_params': model_params.litellm_params,
-            'model_info': model_params.model_info
+            'model_info': model_info
         })
 
         # Save the updated config
@@ -1227,7 +1229,7 @@
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
 
-#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use v1/model/info
+#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
 @router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info_v1(request: Request):
     global llm_model_list, general_settings, user_config_file_path
@@ -1256,7 +1258,7 @@
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
-@router.get("v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info(request: Request):
     global llm_model_list, general_settings, user_config_file_path
     # Load existing config
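Editor's note on the `proxy_server.py` changes: `add_new_model` now serializes `model_info` and strips unset (`None`) fields before persisting, so config.yaml entries stay minimal, and the beta route gains its missing leading slash. A hedged usage sketch (base URL, port, key, and the `/model/new` route are placeholders/assumptions, not values from the patch):

    import requests

    # Register a new model; only non-None model_info fields reach config.yaml
    requests.post(
        "http://0.0.0.0:8000/model/new",            # assumed route for add_new_model
        headers={"Authorization": "Bearer sk-1234"}, # placeholder proxy key
        json={
            "model_name": "my-azure-gpt",
            "litellm_params": {"model": "azure/my-deployment"},
            "model_info": {"mode": "chat"},          # id is auto-generated server-side
        },
    )

    # The beta route now resolves, since "v1/model/info" lacked the leading slash
    resp = requests.get(
        "http://0.0.0.0:8000/v1/model/info",
        headers={"Authorization": "Bearer sk-1234"},
    )
    print(resp.json())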