fix(proxy_server.py): fix /model/new adding new model issue
parent 3c8603f148
commit 22f04e3b33
2 changed files with 54 additions and 43 deletions
@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+
+    def json(self, **kwargs):
+        try:
+            return self.model_dump() # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
 ######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
+class ProxyChatCompletionRequest(LiteLLMBase):
     model: str
     messages: List[Dict[str, str]]
     temperature: Optional[float] = None
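The new LiteLLMBase class exists so every proxy pydantic object serializes the same way on pydantic v1 and v2: v2 renamed .dict() to .model_dump(), so the helper tries the v2 method first and falls back. A minimal runnable sketch of the pattern; the ExampleModel subclass is hypothetical, for illustration only, and it catches AttributeError where the diff uses a bare except:

from typing import Optional
from pydantic import BaseModel

class LiteLLMBase(BaseModel):
    """
    Implements default functions, all pydantic objects should have.
    """
    def json(self, **kwargs):
        try:
            return self.model_dump()  # pydantic v2 method
        except AttributeError:
            return self.dict()  # pydantic v1 fallback

class ExampleModel(LiteLLMBase):  # hypothetical subclass, for illustration
    name: str
    mode: Optional[str] = None

print(ExampleModel(name="gpt-3.5-turbo").json())
# -> {'name': 'gpt-3.5-turbo', 'mode': None}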
@@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
 
-class ModelInfoDelete(BaseModel):
+class ModelInfoDelete(LiteLLMBase):
     id: Optional[str]
 
 
-class ModelInfo(BaseModel):
+class ModelInfo(LiteLLMBase):
     id: Optional[str]
     mode: Optional[Literal['embedding', 'chat', 'completion']]
-    input_cost_per_token: Optional[float]
-    output_cost_per_token: Optional[float]
-    max_tokens: Optional[int]
+    input_cost_per_token: Optional[float] = 0.0
+    output_cost_per_token: Optional[float] = 0.0
+    max_tokens: Optional[int] = 2048 # assume 2048 if not set
 
     # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
     # we look up the base model in model_prices_and_context_window.json
@@ -65,38 +78,41 @@ class ModelInfo(BaseModel):
     class Config:
         extra = Extra.allow # Allow extra fields
         protected_namespaces = ()
 
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("id") is None:
-    #         values.update({"id": str(uuid.uuid4())})
-    #     if values.get("mode") is None:
-    #         values.update({"mode": str(uuid.uuid4())})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("id") is None:
+            values.update({"id": str(uuid.uuid4())})
+        if values.get("mode") is None:
+            values.update({"mode": None})
+        if values.get("input_cost_per_token") is None:
+            values.update({"input_cost_per_token": None})
+        if values.get("output_cost_per_token") is None:
+            values.update({"output_cost_per_token": None})
+        if values.get("max_tokens") is None:
+            values.update({"max_tokens": None})
+        if values.get("base_model") is None:
+            values.update({"base_model": None})
+        return values
 
 
-class ModelParams(BaseModel):
+class ModelParams(LiteLLMBase):
     model_name: str
     litellm_params: dict
-    model_info: Optional[ModelInfo]=None
+    model_info: ModelInfo
 
-    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
-    #     self.model_name = model_name
-    #     self.litellm_params = litellm_params
-    #     self.model_info = model_info if model_info else ModelInfo()
-    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
 
     class Config:
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("model_info") is None:
-    #         values.update({"model_info": ModelInfo()})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("model_info") is None:
+            values.update({"model_info": ModelInfo()})
+        return values
 
-class GenerateKeyRequest(BaseModel):
+class GenerateKeyRequest(LiteLLMBase):
     duration: Optional[str] = "1h"
     models: Optional[list] = []
     aliases: Optional[dict] = {}
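Uncommenting set_model_info makes the defaults explicit at construction time: a missing id gets a fresh uuid, and every other unset optional field is pinned to None, so the serialized dict always has predictable keys for the None-filtering that add_new_model does later (see the proxy_server.py hunk below). A self-contained sketch of the mechanism under pydantic v1 semantics, using the same v1-era @root_validator the diff uses; the field set is trimmed for brevity and ModelInfoSketch is a hypothetical stand-in for ModelInfo:

import uuid
from typing import Optional
from pydantic import BaseModel, root_validator

class ModelInfoSketch(BaseModel):  # hypothetical, trimmed stand-in for ModelInfo
    id: Optional[str] = None
    mode: Optional[str] = None
    base_model: Optional[str] = None

    @root_validator(pre=True)
    def set_model_info(cls, values):
        # pre=True: runs on the raw input dict before field validation
        if values.get("id") is None:
            values.update({"id": str(uuid.uuid4())})
        if values.get("mode") is None:
            values.update({"mode": None})
        if values.get("base_model") is None:
            values.update({"base_model": None})
        return values

m = ModelInfoSketch()
print(m.id)  # auto-generated uuid, e.g. '0f0e4b2c-...'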
@@ -105,26 +121,19 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs):
-        try:
-            return self.model_dump() # noqa
-        except:
-            # if using pydantic v1
-            return self.dict()
-
-class GenerateKeyResponse(BaseModel):
+class GenerateKeyResponse(LiteLLMBase):
     key: str
     expires: datetime
     user_id: str
 
-class _DeleteKeyObject(BaseModel):
+class _DeleteKeyObject(LiteLLMBase):
     key: str
 
-class DeleteKeyRequest(BaseModel):
+class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]
 
 
-class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
     Return the row in the db
     """
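With the base class in place, the copy of json() that GenerateKeyRequest carried is deleted and every request/response class picks the helper up by inheritance. A short illustration, reusing the LiteLLMBase sketch above; the field values are illustrative:

from datetime import datetime

class GenerateKeyResponse(LiteLLMBase):
    key: str
    expires: datetime
    user_id: str

resp = GenerateKeyResponse(key="sk-example", expires=datetime(2024, 1, 1), user_id="user-1")
print(resp.json()["user_id"])  # 'user-1', via the inherited base-class helper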
@@ -137,7 +146,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
 
-class ConfigGeneralSettings(BaseModel):
+class ConfigGeneralSettings(LiteLLMBase):
     """
     Documents all the fields supported by `general_settings` in config.yaml
     """
@@ -153,7 +162,7 @@ class ConfigGeneralSettings(BaseModel):
     health_check_interval: int = Field(300, description="background health check interval in seconds")
 
 
-class ConfigYAML(BaseModel):
+class ConfigYAML(LiteLLMBase):
     """
     Documents all the fields supported by the config.yaml
     """
@@ -1207,10 +1207,12 @@ async def add_new_model(model_params: ModelParams):
 
         print_verbose(f"Loaded config: {config}")
         # Add the new model to the config
+        model_info = model_params.model_info.json()
+        model_info = {k: v for k, v in model_info.items() if v is not None}
         config['model_list'].append({
             'model_name': model_params.model_name,
             'litellm_params': model_params.litellm_params,
-            'model_info': model_params.model_info
+            'model_info': model_info
         })
 
         # Save the updated config
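Per the commit title, this hunk is in proxy_server.py and is the actual /model/new fix: previously the raw ModelInfo pydantic object was appended to config['model_list'], which presumably does not round-trip through the YAML config writer the way a plain dict does and would carry None placeholders; serializing to a dict and dropping None values yields a clean config entry. The filtering step in isolation, with illustrative sample values:

# What model_params.model_info.json() might return after the
# set_model_info validator has filled the gaps (illustrative values):
model_info = {
    "id": "6a381903-example",   # auto-generated uuid
    "mode": None,               # fields the caller never set
    "input_cost_per_token": None,
    "output_cost_per_token": None,
    "max_tokens": 2048,
    "base_model": None,
}
# Drop the None placeholders before writing the entry to config.yaml
model_info = {k: v for k, v in model_info.items() if v is not None}
print(model_info)  # {'id': '6a381903-example', 'max_tokens': 2048}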
@@ -1227,7 +1229,7 @@ async def add_new_model(model_params: ModelParams):
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
 
-#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use v1/model/info
+#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
 @router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info_v1(request: Request):
     global llm_model_list, general_settings, user_config_file_path
@@ -1256,7 +1258,7 @@ async def model_info_v1(request: Request):
 
 
 #### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
-@router.get("v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
 async def model_info(request: Request):
     global llm_model_list, general_settings, user_config_file_path
     # Load existing config
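The second route fix is one character: the path passed to @router.get was missing its leading slash. Starlette, which FastAPI routes through, expects route paths to start with "/", so "v1/model/info" either fails at registration or is never served at /v1/model/info, depending on version. A minimal sketch of the corrected registration; the handler body is a stub, not the proxy's real implementation:

from fastapi import APIRouter

router = APIRouter()

@router.get("/v1/model/info")  # leading slash: served at GET /v1/model/info
async def model_info():
    # stub handler; the real endpoint reads llm_model_list and the config
    return {"data": []}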