mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
feat(proxy_server.py): add model router to proxy
This commit is contained in:
parent
c8f8686d7c
commit
3a8c8f56d6
3 changed files with 46 additions and 7 deletions
|
@ -101,6 +101,7 @@ user_telemetry = True
|
|||
user_config = None
|
||||
user_headers = None
|
||||
local_logging = True # writes logs to a local api_log.json file for debugging
|
||||
model_router = litellm.Router()
|
||||
config_filename = "litellm.secrets.toml"
|
||||
config_dir = os.getcwd()
|
||||
config_dir = appdirs.user_config_dir("litellm")
|
||||
|
@ -213,6 +214,12 @@ def load_config():
|
|||
if "model" in user_config:
|
||||
if user_model in user_config["model"]:
|
||||
model_config = user_config["model"][user_model]
|
||||
model_list = []
|
||||
for model in user_config["model"]:
|
||||
if "model_list" in user_config["model"][model]:
|
||||
model_list.extend(user_config["model"][model]["model_list"])
|
||||
if len(model_list) > 0:
|
||||
model_router.set_model_list(model_list=model_list)
|
||||
|
||||
print_verbose(f"user_config: {user_config}")
|
||||
print_verbose(f"model_config: {model_config}")
|
||||
|
@ -423,7 +430,7 @@ async def completion(request: Request):
|
|||
data = await request.json()
|
||||
return litellm_completion(data=data, type="completion", user_model=user_model, user_temperature=user_temperature,
|
||||
user_max_tokens=user_max_tokens, user_api_base=user_api_base, user_headers=user_headers,
|
||||
user_debug=user_debug)
|
||||
user_debug=user_debug, model_router=model_router)
|
||||
|
||||
|
||||
@router.post("/v1/chat/completions")
|
||||
|
@ -433,7 +440,7 @@ async def chat_completion(request: Request):
|
|||
print_verbose(f"data passed in: {data}")
|
||||
return litellm_completion(data, type="chat_completion", user_model=user_model,
|
||||
user_temperature=user_temperature, user_max_tokens=user_max_tokens,
|
||||
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug)
|
||||
user_api_base=user_api_base, user_headers=user_headers, user_debug=user_debug, model_router=model_router)
|
||||
|
||||
|
||||
def print_cost_logs():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue