forked from phoenix/litellm-mirror
fix(proxy_server.py): fix linting issues
parent b5751bd040
commit 7ed8f8dac8

1 changed file with 26 additions and 11 deletions

proxy_server.py
@@ -31,6 +31,7 @@ except ImportError:
             "appdirs",
             "tomli-w",
             "backoff",
+            "pyyaml"
         ]
     )
     import uvicorn
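The only dependency change here is the new "pyyaml" entry in the pip auto-install list, which lets the proxy parse a YAML config file. A minimal sketch of the kind of loading this enables, assuming a config.yaml whose name and keys are illustrative rather than the proxy's confirmed schema:

import yaml  # provided by the newly added "pyyaml" dependency

def read_config(path: str) -> dict:
    # Parse a YAML config file; an empty file yields {} rather than None.
    with open(path, "r") as f:
        return yaml.safe_load(f) or {}

config = read_config("config.yaml")
server_settings = config.get("server_settings", {})  # e.g. {"completion_model": "ollama/mistral"}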
@@ -125,7 +126,7 @@ user_config_path = os.getenv(
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
 llm_model_list: Optional[list] = None
-server_settings: Optional[dict] = None
+server_settings: dict = {}
 log_file = "api_log.json"


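Retyping server_settings as a plain dict defaulting to {} (instead of Optional[dict] = None) resolves the Optional-typing lint and lets later code call .get() on it even before any config has been loaded. A small sketch of the difference:

server_settings: dict = {}

# Safe before load_router_config() has populated anything:
default_model = server_settings.get("completion_model", None)  # -> None

# With the previous Optional[dict] = None default, the same lookup would raise:
# AttributeError: 'NoneType' object has no attribute 'get'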
@@ -197,7 +198,7 @@ def save_params_to_config(data: dict):
         tomli_w.dump(config, f)


-def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]):
+def load_router_config(router: Optional[litellm.Router], config_file_path: str):
     config = {}
     server_settings = {}
     try:
@@ -210,9 +211,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: Optional[str]):
         pass

     ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
-    server_settings = config.get("server_settings", None)
-    if server_settings:
-        server_settings = server_settings
+    _server_settings = config.get("server_settings", None)
+    if _server_settings:
+        server_settings = _server_settings

     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
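Reading the config section into a local _server_settings and only then assigning to server_settings stops one name from standing for both the raw config.get() result and the dict the rest of the module reads, which is the shadowing pattern the linter objects to. A minimal sketch of the intended flow, written as a hypothetical helper; the global declaration is assumed here and is not visible in the hunk above:

server_settings: dict = {}  # module-level default (see the earlier hunk)

def load_server_settings(config: dict) -> None:
    # Copy the optional "server_settings" section of the config into the module global.
    global server_settings  # assumed; not shown in the diff
    _server_settings = config.get("server_settings", None)
    if _server_settings:
        server_settings = _server_settings

load_server_settings({"server_settings": {"completion_model": "ollama/mistral"}})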
@@ -543,8 +544,10 @@ def litellm_completion(*args, **kwargs):
 @router.get("/v1/models")
 @router.get("/models")  # if project requires model list
 def model_list():
-    global llm_model_list
-    all_models = litellm.utils.get_valid_models()
+    global llm_model_list, server_settings
+    all_models = []
+    if server_settings.get("infer_model_from_keys", False):
+        all_models = litellm.utils.get_valid_models()
     if llm_model_list:
         all_models += llm_model_list
     if user_model is not None:
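litellm.utils.get_valid_models() infers usable models from the provider API keys present in the environment, so /models now only runs that probe when the server has opted in via infer_model_from_keys; otherwise it returns just the configured llm_model_list (plus user_model, per the surrounding context lines). A sketch of the gate, with the server_settings value shown as an in-memory dict since the on-disk config shape is an assumption:

import litellm

server_settings = {"infer_model_from_keys": True}  # illustrative opt-in

all_models = []
if server_settings.get("infer_model_from_keys", False):
    # Only probe the environment for provider keys when explicitly enabled.
    all_models = litellm.utils.get_valid_models()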
@@ -573,13 +576,19 @@ def model_list():
 @router.post("/v1/completions")
 @router.post("/completions")
 @router.post("/engines/{model:path}/completions")
-async def completion(request: Request):
+async def completion(request: Request, model: Optional[str] = None):
     body = await request.body()
     body_str = body.decode()
     try:
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
+    data["model"] = (
+        server_settings.get("completion_model", None)  # server default
+        or user_model  # model name passed via cli args
+        or model  # for azure deployments
+        or data["model"]  # default passed in http request
+    )
     if user_model:
         data["model"] = user_model
     data["call_type"] = "text_completion"
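With model now captured from the /engines/{model:path}/completions route, the request model is resolved by precedence: server default first, then the model passed via CLI args, then the path parameter, then whatever the request body carried. Note the unchanged if user_model: block just below still overwrites the result, so a CLI-supplied model effectively wins when set. A standalone sketch of the same short-circuiting chain; resolve_model and the example deployment name are hypothetical, the precedence mirrors the diff:

from typing import Optional

def resolve_model(server_settings: dict, user_model: Optional[str],
                  path_model: Optional[str], body_model: str) -> str:
    # First truthy value wins, exactly like the chained "or" expression in the diff.
    return (
        server_settings.get("completion_model", None)  # server default
        or user_model                                  # model name passed via cli args
        or path_model                                  # for azure deployments (path parameter)
        or body_model                                  # default passed in http request
    )

# No server default, no CLI model, azure-style path parameter present:
resolve_model({}, None, "my-azure-deployment", "gpt-3.5-turbo")  # -> "my-azure-deployment"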
@@ -590,15 +599,21 @@ async def completion(request: Request):

 @router.post("/v1/chat/completions")
 @router.post("/chat/completions")
-async def chat_completion(request: Request):
+@router.post("/openai/deployments/{model:path}/chat/completions")  # azure compatible endpoint
+async def chat_completion(request: Request, model: Optional[str] = None):
+    global server_settings
     body = await request.body()
     body_str = body.decode()
     try:
         data = ast.literal_eval(body_str)
     except:
         data = json.loads(body_str)
-    if user_model:
-        data["model"] = user_model
+    data["model"] = (
+        server_settings.get("completion_model", None)  # server default
+        or user_model  # model name passed via cli args
+        or model  # for azure deployments
+        or data["model"]  # default passed in http request
+    )
     data["call_type"] = "chat_completion"
     return litellm_completion(
         **data
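The new /openai/deployments/{model:path}/chat/completions route mirrors the Azure OpenAI URL scheme, so Azure-style clients can be pointed at the proxy without path rewriting, and the deployment segment feeds the same model-precedence chain as above. A hedged usage sketch; the host, port, deployment name, and request body values are assumptions, not documented defaults:

import requests

resp = requests.post(
    "http://localhost:8000/openai/deployments/my-azure-deployment/chat/completions",
    json={
        "model": "my-azure-deployment",  # fallback; the path parameter takes precedence over the body value
        "messages": [{"role": "user", "content": "Hello from the azure-compatible route"}],
    },
)
print(resp.json())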