fix(proxy_server.py): fix linting issues

Krrish Dholakia 2023-11-03 13:44:24 -07:00
parent b5751bd040
commit 7ed8f8dac8

@@ -31,6 +31,7 @@ except ImportError:
             "appdirs",
             "tomli-w",
             "backoff",
+            "pyyaml"
         ]
     )
     import uvicorn
@ -125,7 +126,7 @@ user_config_path = os.getenv(
#### GLOBAL VARIABLES #### #### GLOBAL VARIABLES ####
llm_router: Optional[litellm.Router] = None llm_router: Optional[litellm.Router] = None
llm_model_list: Optional[list] = None llm_model_list: Optional[list] = None
server_settings: Optional[dict] = None server_settings: dict = {}
log_file = "api_log.json" log_file = "api_log.json"
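
Note: a likely motivation for this hunk (inferred, not stated in the commit message): with server_settings typed Optional[dict] and defaulting to None, the server_settings.get(...) calls added further down are a type-checker error and a potential AttributeError at runtime; defaulting to {} keeps those call sites safe with no None checks. A minimal sketch:

    from typing import Optional

    server_settings_old: Optional[dict] = None
    # server_settings_old.get("completion_model")  # type checkers flag this: None has no .get()

    server_settings_new: dict = {}
    assert server_settings_new.get("completion_model") is None  # safe even when unset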
@@ -197,7 +198,7 @@ def save_params_to_config(data: dict):
         tomli_w.dump(config, f)

-def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]):
+def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     config = {}
     server_settings = {}
     try:
@ -210,9 +211,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optio
pass pass
## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral') ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
server_settings = config.get("server_settings", None) _server_settings = config.get("server_settings", None)
if server_settings: if _server_settings:
server_settings = server_settings server_settings = _server_settings
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..) ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get('litellm_settings', None) litellm_settings = config.get('litellm_settings', None)
@@ -543,7 +544,9 @@ def litellm_completion(*args, **kwargs):
 @router.get("/v1/models")
 @router.get("/models")  # if project requires model list
 def model_list():
-    global llm_model_list
-    all_models = litellm.utils.get_valid_models()
+    global llm_model_list, server_settings
+    all_models = []
+    if server_settings.get("infer_model_from_keys", False):
+        all_models = litellm.utils.get_valid_models()
     if llm_model_list:
         all_models += llm_model_list
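
Note: this changes /models from always calling litellm.utils.get_valid_models() (which infers usable models from API keys present in the environment) to doing so only when infer_model_from_keys is set in server_settings. A standalone sketch of the new gating, with get_valid_models stubbed out so it runs on its own:

    def list_models(server_settings: dict, llm_model_list, get_valid_models=lambda: []):
        all_models = []
        if server_settings.get("infer_model_from_keys", False):
            all_models = get_valid_models()  # env-key-based inference, now opt-in
        if llm_model_list:
            all_models += llm_model_list
        return all_models

    assert list_models({}, ["my-model"]) == ["my-model"]
    assert list_models({"infer_model_from_keys": True}, None, lambda: ["gpt-4"]) == ["gpt-4"]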
@@ -573,13 +576,19 @@ def model_list():
 @router.post("/v1/completions")
 @router.post("/completions")
 @router.post("/engines/{model:path}/completions")
-async def completion(request: Request):
+async def completion(request: Request, model: Optional[str] = None):
     body = await request.body()
     body_str = body.decode()
     try:
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    data["model"] = (
+        server_settings.get("completion_model", None)  # server default
+        or user_model  # model name passed via cli args
+        or model  # for azure deployments
+        or data["model"]  # default passed in http request
+    )
     if user_model:
         data["model"] = user_model
     data["call_type"] = "text_completion"
@@ -590,15 +599,21 @@ async def completion(request: Request):
 @router.post("/v1/chat/completions")
 @router.post("/chat/completions")
-async def chat_completion(request: Request):
+@router.post("/openai/deployments/{model:path}/chat/completions")  # azure compatible endpoint
+async def chat_completion(request: Request, model: Optional[str] = None):
+    global server_settings
     body = await request.body()
     body_str = body.decode()
     try:
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
-    if user_model:
-        data["model"] = user_model
+    data["model"] = (
+        server_settings.get("completion_model", None)  # server default
+        or user_model  # model name passed via cli args
+        or model  # for azure deployments
+        or data["model"]  # default passed in http request
+    )
     data["call_type"] = "chat_completion"
     return litellm_completion(
         **data
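
Note: a hypothetical request against the new azure-compatible route; the host, port, and deployment name are assumptions, only the path shape comes from the decorator above. The model in the URL applies when neither a server-level completion_model nor a cli model is set:

    import requests

    resp = requests.post(
        "http://0.0.0.0:8000/openai/deployments/my-deployment/chat/completions",
        json={"model": "my-deployment",  # body fallback, lowest precedence
              "messages": [{"role": "user", "content": "hello"}]},
    )
    print(resp.json())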