fix(proxy_server.py): fix /model/new adding new model issue

Krrish Dholakia 2023-12-09 22:38:06 -08:00
parent 3c8603f148
commit 22f04e3b33
2 changed files with 54 additions and 43 deletions
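The gist: add_new_model() previously appended the raw model_info pydantic object into config['model_list'] before dumping it to YAML; this commit serializes it to a plain dict, strips None values, and makes ModelInfo/ModelParams fill their own defaults. A client-side sketch of the endpoint being fixed (the URL, port, and key below are assumptions for illustration, not taken from the commit):

    import requests

    # Hypothetical call to a locally running litellm proxy.
    # The payload shape mirrors ModelParams below: model_name + litellm_params,
    # with model_info optional from the client's side (the new validator back-fills it).
    resp = requests.post(
        "http://0.0.0.0:8000/model/new",              # assumed proxy address
        headers={"Authorization": "Bearer sk-1234"},  # assumed master key
        json={
            "model_name": "my-gpt",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
    )
    print(resp.status_code, resp.json())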

File 1 of 2: pydantic request/response type definitions

@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
+class ProxyChatCompletionRequest(LiteLLMBase):
     model: str
     messages: List[Dict[str, str]]
     temperature: Optional[float] = None
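Note on the new LiteLLMBase: pydantic v2 renamed .dict() to .model_dump(), and this base class papers over the split. Note also that this json() returns a dict rather than a JSON string, which is exactly what add_new_model() in the second file relies on. A standalone sketch of the pattern (class names here are illustrative, and the bare except is narrowed to AttributeError):

    from pydantic import BaseModel

    class CompatBase(BaseModel):
        def json(self, **kwargs):
            try:
                return self.model_dump()  # pydantic v2
            except AttributeError:
                return self.dict()        # pydantic v1 fallback

    class Demo(CompatBase):
        name: str = "example"

    print(Demo().json())  # {'name': 'example'} under either pydantic major version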
@@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)

-class ModelInfoDelete(BaseModel):
+class ModelInfoDelete(LiteLLMBase):
     id: Optional[str]

-class ModelInfo(BaseModel):
+class ModelInfo(LiteLLMBase):
     id: Optional[str]
     mode: Optional[Literal['embedding', 'chat', 'completion']]
-    input_cost_per_token: Optional[float]
-    output_cost_per_token: Optional[float]
-    max_tokens: Optional[int]
+    input_cost_per_token: Optional[float] = 0.0
+    output_cost_per_token: Optional[float] = 0.0
+    max_tokens: Optional[int] = 2048 # assume 2048 if not set

     # for azure models we need users to specify the base model, on azure you can call deployments - azure/my-random-model
     # we look up the base model in model_prices_and_context_window.json
@@ -66,37 +79,40 @@ class ModelInfo(BaseModel):
         extra = Extra.allow # Allow extra fields
         protected_namespaces = ()

-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("id") is None:
-    #         values.update({"id": str(uuid.uuid4())})
-    #     if values.get("mode") is None:
-    #         values.update({"mode": str(uuid.uuid4())})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("id") is None:
+            values.update({"id": str(uuid.uuid4())})
+        if values.get("mode") is None:
+            values.update({"mode": None})
+        if values.get("input_cost_per_token") is None:
+            values.update({"input_cost_per_token": None})
+        if values.get("output_cost_per_token") is None:
+            values.update({"output_cost_per_token": None})
+        if values.get("max_tokens") is None:
+            values.update({"max_tokens": None})
+        if values.get("base_model") is None:
+            values.update({"base_model": None})
+        return values

-class ModelParams(BaseModel):
+class ModelParams(LiteLLMBase):
     model_name: str
     litellm_params: dict
-    model_info: Optional[ModelInfo]=None
-
-    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
-    #     self.model_name = model_name
-    #     self.litellm_params = litellm_params
-    #     self.model_info = model_info if model_info else ModelInfo()
-    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
+    model_info: ModelInfo

     class Config:
         protected_namespaces = ()

-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("model_info") is None:
-    #         values.update({"model_info": ModelInfo()})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("model_info") is None:
+            values.update({"model_info": ModelInfo()})
+        return values

-class GenerateKeyRequest(BaseModel):
+class GenerateKeyRequest(LiteLLMBase):
     duration: Optional[str] = "1h"
     models: Optional[list] = []
     aliases: Optional[dict] = {}
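This is the substantive type change: ModelInfo now back-fills id with a uuid4 and pins every unset optional field to an explicit None (stripped later in add_new_model), and ModelParams guarantees model_info always exists. A minimal sketch of the same pre-validator pattern under the pydantic v1-style API this commit uses (field names illustrative):

    import uuid
    from typing import Optional
    from pydantic import BaseModel, root_validator

    class MiniInfo(BaseModel):  # hypothetical stand-in for ModelInfo
        id: Optional[str] = None
        mode: Optional[str] = None

        @root_validator(pre=True)
        def set_defaults(cls, values):
            # pre=True: runs on the raw input dict, before field validation
            if values.get("id") is None:
                values.update({"id": str(uuid.uuid4())})
            return values

    print(MiniInfo().id)    # fresh uuid4 string even when no id is passed
    print(MiniInfo().mode)  # None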
@@ -105,26 +121,19 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None

-    def json(self, **kwargs):
-        try:
-            return self.model_dump() # noqa
-        except:
-            # if using pydantic v1
-            return self.dict()
-
-class GenerateKeyResponse(BaseModel):
+class GenerateKeyResponse(LiteLLMBase):
     key: str
     expires: datetime
     user_id: str

-class _DeleteKeyObject(BaseModel):
+class _DeleteKeyObject(LiteLLMBase):
     key: str

-class DeleteKeyRequest(BaseModel):
+class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]

-class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
     Return the row in the db
     """
@@ -137,7 +146,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"

-class ConfigGeneralSettings(BaseModel):
+class ConfigGeneralSettings(LiteLLMBase):
     """
     Documents all the fields supported by `general_settings` in config.yaml
     """
@@ -153,7 +162,7 @@ class ConfigGeneralSettings(BaseModel):
     health_check_interval: int = Field(300, description="background health check interval in seconds")

-class ConfigYAML(BaseModel):
+class ConfigYAML(LiteLLMBase):
     """
     Documents all the fields supported by the config.yaml
     """

File 2 of 2: proxy_server.py endpoints

@@ -1207,10 +1207,12 @@ async def add_new_model(model_params: ModelParams):
         print_verbose(f"Loaded config: {config}")

         # Add the new model to the config
+        model_info = model_params.model_info.json()
+        model_info = {k: v for k, v in model_info.items() if v is not None}
         config['model_list'].append({
             'model_name': model_params.model_name,
             'litellm_params': model_params.litellm_params,
-            'model_info': model_params.model_info
+            'model_info': model_info
         })

         # Save the updated config
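These two added lines are the core of the endpoint fix: model_info is converted to a plain dict via the LiteLLMBase json() helper (see file 1) and scrubbed of None entries, so only meaningful fields are written back to the YAML config instead of a raw pydantic object. In isolation (sample values made up):

    # What the dict comprehension above does to a defaults-filled model_info:
    info = {"id": "abc-123", "mode": None, "max_tokens": 2048, "base_model": None}
    cleaned = {k: v for k, v in info.items() if v is not None}
    print(cleaned)  # {'id': 'abc-123', 'max_tokens': 2048}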
@ -1227,7 +1229,7 @@ async def add_new_model(model_params: ModelParams):
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use v1/model/info
#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path
@@ -1256,7 +1258,7 @@ async def model_info_v1(request: Request):
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
-@router.get("v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info(request: Request):
     global llm_model_list, general_settings, user_config_file_path

     # Load existing config
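The route change above is small but load-bearing: Starlette, which FastAPI builds on, requires routed paths to start with "/" (recent versions assert this at registration time), so "v1/model/info" was never a valid path and the commit adds the leading slash. A minimal sketch of the corrected registration (handler body is illustrative):

    from fastapi import APIRouter

    router = APIRouter()

    @router.get("/v1/model/info")  # leading slash required by Starlette routing
    async def model_info():
        return {"status": "ok"}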