Mirror of https://github.com/BerriAI/litellm.git
Move chat_completions before completions
so that the `chat_completions` route is defined before the `completions` route. This is necessary because the `chat_completions` route is more specific than the `completions` route, and the order of route definitions matters in FastAPI. Without this change, a request to `/openai/deployments/{model_in_url}/chat/completions` may be matched by `completions` (with `model` set to `{model_in_url}/chat`) instead of by `chat_completions`, which is the correct handler. Fixes: GH-3372
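To make the ordering issue concrete, here is a minimal, hypothetical FastAPI sketch (not taken from this commit; the handler names and return values are illustrative only). Because `{model:path}` also matches slashes, the generic `completions` route can match `.../{model}/chat/completions` as well, so whichever route is registered first wins:

```python
# Hypothetical reproduction of the route-ordering behaviour described above.
# Requires `fastapi` and `httpx` (for TestClient); not part of the litellm diff.
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()


# Registered first: the more specific chat/completions route.
@app.post("/openai/deployments/{model:path}/chat/completions")
async def chat_completion(model: str):
    return {"handler": "chat_completion", "model": model}


# Registered second: the generic completions route. Since `{model:path}`
# also matches "/", this route would capture ".../gpt-4/chat/completions"
# with model == "gpt-4/chat" if it were registered before the route above.
@app.post("/openai/deployments/{model:path}/completions")
async def completion(model: str):
    return {"handler": "completion", "model": model}


client = TestClient(app)
print(client.post("/openai/deployments/gpt-4/chat/completions").json())
# -> {'handler': 'chat_completion', 'model': 'gpt-4'}
```

Swapping the two decorators in this sketch reproduces the reported bug: the same request then lands in `completion` with `model` set to `gpt-4/chat`, which is what the commit avoids by defining `chat_completion` first.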
parent 285a3733a9
commit dd166680d1
1 changed file with 166 additions and 166 deletions
@@ -3371,172 +3371,6 @@ def model_list(
     )


-@router.post(
-    "/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
-)
-@router.post(
-    "/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
-)
-@router.post(
-    "/engines/{model:path}/completions",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["completions"],
-)
-@router.post(
-    "/openai/deployments/{model:path}/completions",
-    dependencies=[Depends(user_api_key_auth)],
-    tags=["completions"],
-)
-async def completion(
-    request: Request,
-    fastapi_response: Response,
-    model: Optional[str] = None,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-):
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    try:
-        body = await request.body()
-        body_str = body.decode()
-        try:
-            data = ast.literal_eval(body_str)
-        except:
-            data = json.loads(body_str)
-
-        data["user"] = data.get("user", user_api_key_dict.user_id)
-        data["model"] = (
-            general_settings.get("completion_model", None)  # server default
-            or user_model  # model name passed via cli args
-            or model  # for azure deployments
-            or data["model"]  # default passed in http request
-        )
-        if user_model:
-            data["model"] = user_model
-        if "metadata" not in data:
-            data["metadata"] = {}
-        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
-        data["metadata"]["user_api_key_alias"] = getattr(
-            user_api_key_dict, "key_alias", None
-        )
-        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-        data["metadata"]["user_api_key_team_id"] = getattr(
-            user_api_key_dict, "team_id", None
-        )
-        data["metadata"]["user_api_key_team_alias"] = getattr(
-            user_api_key_dict, "team_alias", None
-        )
-        _headers = dict(request.headers)
-        _headers.pop(
-            "authorization", None
-        )  # do not store the original `sk-..` api key in the db
-        data["metadata"]["headers"] = _headers
-        data["metadata"]["endpoint"] = str(request.url)
-
-        # override with user settings, these are params passed via cli
-        if user_temperature:
-            data["temperature"] = user_temperature
-        if user_request_timeout:
-            data["request_timeout"] = user_request_timeout
-        if user_max_tokens:
-            data["max_tokens"] = user_max_tokens
-        if user_api_base:
-            data["api_base"] = user_api_base
-
-        ### MODEL ALIAS MAPPING ###
-        # check if model name in model alias map
-        # get the actual model name
-        if data["model"] in litellm.model_alias_map:
-            data["model"] = litellm.model_alias_map[data["model"]]
-
-        ### CALL HOOKS ### - modify incoming data before calling the model
-        data = await proxy_logging_obj.pre_call_hook(
-            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
-        )
-
-        ### ROUTE THE REQUESTs ###
-        router_model_names = llm_router.model_names if llm_router is not None else []
-        # skip router if user passed their key
-        if "api_key" in data:
-            response = await litellm.atext_completion(**data)
-        elif (
-            llm_router is not None and data["model"] in router_model_names
-        ):  # model in router model list
-            response = await llm_router.atext_completion(**data)
-        elif (
-            llm_router is not None
-            and llm_router.model_group_alias is not None
-            and data["model"] in llm_router.model_group_alias
-        ):  # model set in model_group_alias
-            response = await llm_router.atext_completion(**data)
-        elif (
-            llm_router is not None and data["model"] in llm_router.deployment_names
-        ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(
-                **data, specific_deployment=True
-            )
-        elif (
-            llm_router is not None
-            and data["model"] not in router_model_names
-            and llm_router.default_deployment is not None
-        ):  # model in router deployments, calling a specific deployment on the router
-            response = await llm_router.atext_completion(**data)
-        elif user_model is not None:  # `litellm --model <your-model-name>`
-            response = await litellm.atext_completion(**data)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail={
-                    "error": "Invalid model name passed in model="
-                    + data.get("model", "")
-                },
-            )
-
-        if hasattr(response, "_hidden_params"):
-            model_id = response._hidden_params.get("model_id", None) or ""
-            original_response = (
-                response._hidden_params.get("original_response", None) or ""
-            )
-        else:
-            model_id = ""
-            original_response = ""
-
-        verbose_proxy_logger.debug("final response: %s", response)
-        if (
-            "stream" in data and data["stream"] == True
-        ):  # use generate_responses to stream responses
-            custom_headers = {
-                "x-litellm-model-id": model_id,
-            }
-            selected_data_generator = select_data_generator(
-                response=response, user_api_key_dict=user_api_key_dict
-            )
-
-            return StreamingResponse(
-                selected_data_generator,
-                media_type="text/event-stream",
-                headers=custom_headers,
-            )
-
-        fastapi_response.headers["x-litellm-model-id"] = model_id
-        return response
-    except Exception as e:
-        data["litellm_status"] = "fail"  # used for alerting
-        verbose_proxy_logger.debug("EXCEPTION RAISED IN PROXY MAIN.PY")
-        verbose_proxy_logger.debug(
-            "\033[1;31mAn error occurred: %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
-            e,
-        )
-        traceback.print_exc()
-        error_traceback = traceback.format_exc()
-        error_msg = f"{str(e)}"
-        raise ProxyException(
-            message=getattr(e, "message", error_msg),
-            type=getattr(e, "type", "None"),
-            param=getattr(e, "param", "None"),
-            code=getattr(e, "status_code", 500),
-        )
-
-
 @router.post(
     "/v1/chat/completions",
     dependencies=[Depends(user_api_key_auth)],
@@ -3809,6 +3643,172 @@ async def chat_completion(
         )


+@router.post(
+    "/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
+)
+@router.post(
+    "/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
+)
+@router.post(
+    "/engines/{model:path}/completions",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["completions"],
+)
+@router.post(
+    "/openai/deployments/{model:path}/completions",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["completions"],
+)
+async def completion(
+    request: Request,
+    fastapi_response: Response,
+    model: Optional[str] = None,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+    try:
+        body = await request.body()
+        body_str = body.decode()
+        try:
+            data = ast.literal_eval(body_str)
+        except:
+            data = json.loads(body_str)
+
+        data["user"] = data.get("user", user_api_key_dict.user_id)
+        data["model"] = (
+            general_settings.get("completion_model", None)  # server default
+            or user_model  # model name passed via cli args
+            or model  # for azure deployments
+            or data["model"]  # default passed in http request
+        )
+        if user_model:
+            data["model"] = user_model
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        data["metadata"]["user_api_key_alias"] = getattr(
+            user_api_key_dict, "key_alias", None
+        )
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["user_api_key_team_id"] = getattr(
+            user_api_key_dict, "team_id", None
+        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["metadata"]["headers"] = _headers
+        data["metadata"]["endpoint"] = str(request.url)
+
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+
+        ### MODEL ALIAS MAPPING ###
+        # check if model name in model alias map
+        # get the actual model name
+        if data["model"] in litellm.model_alias_map:
+            data["model"] = litellm.model_alias_map[data["model"]]
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(
+            user_api_key_dict=user_api_key_dict, data=data, call_type="completion"
+        )
+
+        ### ROUTE THE REQUESTs ###
+        router_model_names = llm_router.model_names if llm_router is not None else []
+        # skip router if user passed their key
+        if "api_key" in data:
+            response = await litellm.atext_completion(**data)
+        elif (
+            llm_router is not None and data["model"] in router_model_names
+        ):  # model in router model list
+            response = await llm_router.atext_completion(**data)
+        elif (
+            llm_router is not None
+            and llm_router.model_group_alias is not None
+            and data["model"] in llm_router.model_group_alias
+        ):  # model set in model_group_alias
+            response = await llm_router.atext_completion(**data)
+        elif (
+            llm_router is not None and data["model"] in llm_router.deployment_names
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(
+                **data, specific_deployment=True
+            )
+        elif (
+            llm_router is not None
+            and data["model"] not in router_model_names
+            and llm_router.default_deployment is not None
+        ):  # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data)
+        elif user_model is not None:  # `litellm --model <your-model-name>`
+            response = await litellm.atext_completion(**data)
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "error": "Invalid model name passed in model="
+                    + data.get("model", "")
+                },
+            )
+
+        if hasattr(response, "_hidden_params"):
+            model_id = response._hidden_params.get("model_id", None) or ""
+            original_response = (
+                response._hidden_params.get("original_response", None) or ""
+            )
+        else:
+            model_id = ""
+            original_response = ""
+
+        verbose_proxy_logger.debug("final response: %s", response)
+        if (
+            "stream" in data and data["stream"] == True
+        ):  # use generate_responses to stream responses
+            custom_headers = {
+                "x-litellm-model-id": model_id,
+            }
+            selected_data_generator = select_data_generator(
+                response=response, user_api_key_dict=user_api_key_dict
+            )
+
+            return StreamingResponse(
+                selected_data_generator,
+                media_type="text/event-stream",
+                headers=custom_headers,
+            )
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        return response
+    except Exception as e:
+        data["litellm_status"] = "fail"  # used for alerting
+        verbose_proxy_logger.debug("EXCEPTION RAISED IN PROXY MAIN.PY")
+        verbose_proxy_logger.debug(
+            "\033[1;31mAn error occurred: %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
+            e,
+        )
+        traceback.print_exc()
+        error_traceback = traceback.format_exc()
+        error_msg = f"{str(e)}"
+        raise ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
+
+
 @router.post(
     "/v1/embeddings",
     dependencies=[Depends(user_api_key_auth)],