fix(proxy_server.py): fix /model/new adding new model issue

Krrish Dholakia 2023-12-09 22:38:06 -08:00
parent 3c8603f148
commit 22f04e3b33
2 changed files with 54 additions and 43 deletions

litellm/proxy/_types.py

@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
+class ProxyChatCompletionRequest(LiteLLMBase):
     model: str
     messages: List[Dict[str, str]]
     temperature: Optional[float] = None
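The new LiteLLMBase centralizes the pydantic v1/v2 serialization shim that was previously copy-pasted onto individual classes (see the GenerateKeyRequest hunk below, which deletes one such copy): pydantic v2 exposes model_dump(), pydantic v1 only dict(), so the try/except picks whichever the installed version supports. A minimal standalone sketch of the same pattern, tightening the diff's bare except to except AttributeError:

```python
from pydantic import BaseModel

class LiteLLMBase(BaseModel):
    """Version-agnostic serialization, mirroring the shim in the diff."""

    def json(self, **kwargs):
        try:
            return self.model_dump()  # pydantic v2
        except AttributeError:
            return self.dict()  # pydantic v1 fallback

class Demo(LiteLLMBase):
    name: str

print(Demo(name="gpt-3.5-turbo").json())  # {'name': 'gpt-3.5-turbo'}
```

Note that despite the name, this json() returns a plain dict rather than a JSON string, which is exactly what the config-writing code below needs.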
@@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
 
-class ModelInfoDelete(BaseModel):
+class ModelInfoDelete(LiteLLMBase):
     id: Optional[str]
 
-class ModelInfo(BaseModel):
+class ModelInfo(LiteLLMBase):
     id: Optional[str]
     mode: Optional[Literal['embedding', 'chat', 'completion']]
-    input_cost_per_token: Optional[float]
-    output_cost_per_token: Optional[float]
-    max_tokens: Optional[int]
+    input_cost_per_token: Optional[float] = 0.0
+    output_cost_per_token: Optional[float] = 0.0
+    max_tokens: Optional[int] = 2048 # assume 2048 if not set
     # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
     # we look up the base model in model_prices_and_context_window.json
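The trailing comment explains the intent: an Azure deployment can have an arbitrary name (azure/my-random-model), so cost and context-window data are resolved from the user-supplied base_model rather than the deployment name. A hedged sketch of such a lookup — litellm.model_cost is litellm's in-memory copy of model_prices_and_context_window.json, and price_for_base_model is a hypothetical helper, not code from this commit:

```python
import litellm

def price_for_base_model(base_model: str) -> dict:
    """Hypothetical helper: resolve pricing/context data for a base model."""
    # litellm.model_cost is keyed by model name, e.g. "gpt-3.5-turbo",
    # with input/output cost per token and max_tokens per entry.
    return litellm.model_cost.get(base_model, {})

print(price_for_base_model("gpt-3.5-turbo").get("input_cost_per_token"))
```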
@@ -66,37 +79,40 @@ class ModelInfo(BaseModel):
         extra = Extra.allow # Allow extra fields
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("id") is None:
-    #         values.update({"id": str(uuid.uuid4())})
-    #     if values.get("mode") is None:
-    #         values.update({"mode": str(uuid.uuid4())})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("id") is None:
+            values.update({"id": str(uuid.uuid4())})
+        if values.get("mode") is None:
+            values.update({"mode": None})
+        if values.get("input_cost_per_token") is None:
+            values.update({"input_cost_per_token": None})
+        if values.get("output_cost_per_token") is None:
+            values.update({"output_cost_per_token": None})
+        if values.get("max_tokens") is None:
+            values.update({"max_tokens": None})
+        if values.get("base_model") is None:
+            values.update({"base_model": None})
+        return values
 
-class ModelParams(BaseModel):
+class ModelParams(LiteLLMBase):
     model_name: str
     litellm_params: dict
-    model_info: Optional[ModelInfo]=None
-
-    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
-    #     self.model_name = model_name
-    #     self.litellm_params = litellm_params
-    #     self.model_info = model_info if model_info else ModelInfo()
-    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
+    model_info: ModelInfo
 
     class Config:
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("model_info") is None:
-    #         values.update({"model_info": ModelInfo()})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("model_info") is None:
+            values.update({"model_info": ModelInfo()})
+        return values
 
-class GenerateKeyRequest(BaseModel):
+class GenerateKeyRequest(LiteLLMBase):
     duration: Optional[str] = "1h"
     models: Optional[list] = []
     aliases: Optional[dict] = {}
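With model_info now a required field backed by a pre=True root validator, a ModelParams built without explicit model info gets a default ModelInfo whose id is auto-filled with a fresh uuid4, instead of carrying a None around. A sketch of that path, assuming the types module above is importable as litellm.proxy._types (the import path is inferred from the repo layout, not shown in this diff):

```python
from litellm.proxy._types import ModelParams

params = ModelParams(
    model_name="my-deployment",
    litellm_params={"model": "azure/my-random-model"},
    # model_info intentionally omitted: set_model_info injects ModelInfo()
)
print(params.model_info.id)  # a generated uuid4 string, never None
```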
@@ -105,26 +121,19 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs):
-        try:
-            return self.model_dump() # noqa
-        except:
-            # if using pydantic v1
-            return self.dict()
-
-class GenerateKeyResponse(BaseModel):
+class GenerateKeyResponse(LiteLLMBase):
     key: str
     expires: datetime
     user_id: str
 
-class _DeleteKeyObject(BaseModel):
+class _DeleteKeyObject(LiteLLMBase):
     key: str
 
-class DeleteKeyRequest(BaseModel):
+class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]
 
-class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
     Return the row in the db
     """
@@ -137,7 +146,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
 
-class ConfigGeneralSettings(BaseModel):
+class ConfigGeneralSettings(LiteLLMBase):
     """
     Documents all the fields supported by `general_settings` in config.yaml
     """
@@ -153,7 +162,7 @@ class ConfigGeneralSettings(BaseModel):
     health_check_interval: int = Field(300, description="background health check interval in seconds")
 
-class ConfigYAML(BaseModel):
+class ConfigYAML(LiteLLMBase):
     """
     Documents all the fields supported by the config.yaml
     """

litellm/proxy/proxy_server.py

@@ -1207,10 +1207,12 @@ async def add_new_model(model_params: ModelParams):
         print_verbose(f"Loaded config: {config}")
 
         # Add the new model to the config
+        model_info = model_params.model_info.json()
+        model_info = {k: v for k, v in model_info.items() if v is not None}
         config['model_list'].append({
             'model_name': model_params.model_name,
             'litellm_params': model_params.litellm_params,
-            'model_info': model_params.model_info
+            'model_info': model_info
         })
 
         # Save the updated config
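This hunk is the actual /model/new fix. A plausible reconstruction of the failure: appending the raw ModelInfo object to config['model_list'] made the subsequent YAML save choke, because a pydantic model is not a representable YAML node, while the .json() dict (with None entries dropped) serializes cleanly. A standalone sketch of the difference, assuming pydantic v2 and PyYAML:

```python
from typing import Optional

import yaml
from pydantic import BaseModel

class ModelInfo(BaseModel):
    id: str = "1234"
    mode: Optional[str] = None

info = ModelInfo()

# Old behaviour: the raw model object is not YAML-serializable.
try:
    yaml.safe_dump({"model_info": info})
except yaml.YAMLError as e:
    print(f"safe_dump rejects the raw object: {e}")

# New behaviour: dump to a dict first and drop None values, as the diff does.
as_dict = {k: v for k, v in info.model_dump().items() if v is not None}
print(yaml.safe_dump({"model_info": as_dict}))  # "model_info:\n  id: '1234'\n"
```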
@@ -1227,7 +1229,7 @@ async def add_new_model(model_params: ModelParams):
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
 
-#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use v1/model/info
+#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
 @router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info_v1(request: Request):
     global llm_model_list, general_settings, user_config_file_path
@@ -1256,7 +1258,7 @@ async def model_info_v1(request: Request):
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
-@router.get("v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info(request: Request):
     global llm_model_list, general_settings, user_config_file_path
     # Load existing config
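The last two hunks also restore the leading slash on the beta /v1/model/info route (recent Starlette versions assert that route paths start with "/", so the unprefixed form could not be registered or matched) and point the beta notice at the stable /model/info path. A hedged end-to-end usage sketch for the fixed /model/new endpoint; the host, port, and master key are assumptions for a local proxy, and the payload mirrors the ModelParams schema above:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:8000/model/new",  # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # assumed master key
    json={
        "model_name": "my-deployment",
        "litellm_params": {"model": "azure/my-random-model"},
        "model_info": {"mode": "chat"},  # id auto-generated if omitted
    },
)
print(resp.status_code, resp.json())
```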