fix(proxy_server.py): fix /model/new adding new model issue
parent 3c8603f148
commit 22f04e3b33
2 changed files with 54 additions and 43 deletions

@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
+class ProxyChatCompletionRequest(LiteLLMBase):
     model: str
     messages: List[Dict[str, str]]
     temperature: Optional[float] = None
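
The LiteLLMBase.json() helper added above is what lets every proxy type serialize under both pydantic v1 and v2 (model_dump() only exists in v2). A minimal standalone sketch of the same pattern; the Demo subclass is illustrative, not part of the commit:

from pydantic import BaseModel

class LiteLLMBase(BaseModel):
    def json(self, **kwargs):
        try:
            return self.model_dump()  # pydantic v2
        except AttributeError:
            # pydantic v1 has no model_dump(); fall back to .dict()
            return self.dict()

class Demo(LiteLLMBase):
    name: str = "gpt-3.5-turbo"

print(Demo().json())  # {'name': 'gpt-3.5-turbo'} on either pydantic major version
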
@@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
 
-class ModelInfoDelete(BaseModel):
+class ModelInfoDelete(LiteLLMBase):
     id: Optional[str]
 
 
-class ModelInfo(BaseModel):
+class ModelInfo(LiteLLMBase):
     id: Optional[str]
     mode: Optional[Literal['embedding', 'chat', 'completion']]
-    input_cost_per_token: Optional[float]
-    output_cost_per_token: Optional[float]
-    max_tokens: Optional[int]
+    input_cost_per_token: Optional[float] = 0.0
+    output_cost_per_token: Optional[float] = 0.0
+    max_tokens: Optional[int] = 2048 # assume 2048 if not set
 
     # for azure models we need users to specify the base model, on azure you can call deployments - azure/my-random-model
     # we look up the base model in model_prices_and_context_window.json
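
The cost and token fields above move from bare optionals to explicit defaults, so a ModelInfo can now be constructed with no arguments. A sketch trimmed to just those three fields:

from typing import Optional
from pydantic import BaseModel

class ModelInfo(BaseModel):  # trimmed: only the fields changed in this hunk
    input_cost_per_token: Optional[float] = 0.0
    output_cost_per_token: Optional[float] = 0.0
    max_tokens: Optional[int] = 2048  # assume 2048 if not set

info = ModelInfo()
print(info.max_tokens)  # 2048 unless the caller overrides it
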
@@ -66,37 +79,40 @@ class ModelInfo(BaseModel):
         extra = Extra.allow # Allow extra fields
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("id") is None:
-    #         values.update({"id": str(uuid.uuid4())})
-    #     if values.get("mode") is None:
-    #         values.update({"mode": str(uuid.uuid4())})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("id") is None:
+            values.update({"id": str(uuid.uuid4())})
+        if values.get("mode") is None:
+            values.update({"mode": None})
+        if values.get("input_cost_per_token") is None:
+            values.update({"input_cost_per_token": None})
+        if values.get("output_cost_per_token") is None:
+            values.update({"output_cost_per_token": None})
+        if values.get("max_tokens") is None:
+            values.update({"max_tokens": None})
+        if values.get("base_model") is None:
+            values.update({"base_model": None})
+        return values
 
 
 
 
-class ModelParams(BaseModel):
+class ModelParams(LiteLLMBase):
     model_name: str
     litellm_params: dict
-    model_info: Optional[ModelInfo]=None
-
-    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
-    #     self.model_name = model_name
-    #     self.litellm_params = litellm_params
-    #     self.model_info = model_info if model_info else ModelInfo()
-    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
+    model_info: ModelInfo
 
     class Config:
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("model_info") is None:
-    #         values.update({"model_info": ModelInfo()})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("model_info") is None:
+            values.update({"model_info": ModelInfo()})
+        return values
 
-class GenerateKeyRequest(BaseModel):
+class GenerateKeyRequest(LiteLLMBase):
     duration: Optional[str] = "1h"
     models: Optional[list] = []
     aliases: Optional[dict] = {}
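
The two @root_validator(pre=True) hooks activated above work together: ModelInfo back-fills an id (and explicit None placeholders) before validation, and ModelParams guarantees model_info is always present even when the caller omits it. A self-contained sketch of that interplay, trimmed to a couple of fields (root_validator is the pydantic v1-style API; on v2 it still runs, with a deprecation warning):

import uuid
from typing import Optional
from pydantic import BaseModel, root_validator

class ModelInfo(BaseModel):
    id: Optional[str] = None
    max_tokens: Optional[int] = 2048

    @root_validator(pre=True)
    def set_model_info(cls, values):
        # runs before field validation, so it sees the raw input dict
        if values.get("id") is None:
            values.update({"id": str(uuid.uuid4())})
        return values

class ModelParams(BaseModel):
    model_name: str
    litellm_params: dict
    model_info: ModelInfo

    class Config:
        protected_namespaces = ()  # silence pydantic v2's "model_" namespace warning

    @root_validator(pre=True)
    def set_model_info(cls, values):
        if values.get("model_info") is None:
            values.update({"model_info": ModelInfo()})
        return values

params = ModelParams(model_name="gpt-4", litellm_params={"model": "azure/gpt-4"})
print(params.model_info.id)  # auto-generated uuid; model_info is never None
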
@@ -105,26 +121,19 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs):
-        try:
-            return self.model_dump() # noqa
-        except:
-            # if using pydantic v1
-            return self.dict()
-
-class GenerateKeyResponse(BaseModel):
+class GenerateKeyResponse(LiteLLMBase):
     key: str
     expires: datetime
     user_id: str
 
-class _DeleteKeyObject(BaseModel):
+class _DeleteKeyObject(LiteLLMBase):
     key: str
 
-class DeleteKeyRequest(BaseModel):
+class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]
 
 
-class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
     Return the row in the db
     """
@@ -137,7 +146,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
 
-class ConfigGeneralSettings(BaseModel):
+class ConfigGeneralSettings(LiteLLMBase):
     """
     Documents all the fields supported by `general_settings` in config.yaml
     """
@@ -153,7 +162,7 @@ class ConfigGeneralSettings(BaseModel):
     health_check_interval: int = Field(300, description="background health check interval in seconds")
 
 
-class ConfigYAML(BaseModel):
+class ConfigYAML(LiteLLMBase):
     """
     Documents all the fields supported by the config.yaml
     """

@@ -1207,10 +1207,12 @@ async def add_new_model(model_params: ModelParams):
 
         print_verbose(f"Loaded config: {config}")
         # Add the new model to the config
+        model_info = model_params.model_info.json()
+        model_info = {k: v for k, v in model_info.items() if v is not None}
         config['model_list'].append({
             'model_name': model_params.model_name,
             'litellm_params': model_params.litellm_params,
-            'model_info': model_params.model_info
+            'model_info': model_info
         })
 
         # Save the updated config
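
This hunk is the /model/new fix named in the commit title: model_params.model_info was previously appended into the config as a raw pydantic object rather than a plain dict, which does not dump cleanly back to YAML. Now it is serialized first, and None-valued entries are stripped so they don't pollute the saved config. The filter in isolation, with illustrative values:

model_info = {"id": "abc-123", "mode": None, "max_tokens": 2048, "base_model": None}
model_info = {k: v for k, v in model_info.items() if v is not None}
print(model_info)  # {'id': 'abc-123', 'max_tokens': 2048}
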
@@ -1227,7 +1229,7 @@ async def add_new_model(model_params: ModelParams):
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
 
-#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use v1/model/info
+#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
 @router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info_v1(request: Request):
     global llm_model_list, general_settings, user_config_file_path
@@ -1256,7 +1258,7 @@ async def model_info_v1(request: Request):
 
 
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
-@router.get("v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info(request: Request):
     global llm_model_list, general_settings, user_config_file_path
     # Load existing config
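
The removed route path above ("v1/model/info") lacked a leading slash, so the route was registered with a malformed path; the replacement makes the endpoint reachable at /v1/model/info. A hypothetical smoke test once the proxy is running; host, port, and key below are placeholders:

import requests

resp = requests.get(
    "http://0.0.0.0:8000/v1/model/info",
    headers={"Authorization": "Bearer sk-1234"},  # replace with a real proxy key
)
print(resp.status_code, resp.json())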