feat - refactor /chat/completions to have a common helper

Ishaan Jaff 2024-06-07 12:18:53 -07:00
parent 923cbed6ab
commit 58eb352ddb
3 changed files with 111 additions and 103 deletions

.pre-commit-config.yaml

@@ -1,14 +1,15 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
  # Ruff version.
  rev: v0.4.8
- repo: https://github.com/psf/black
  rev: 24.2.0
  hooks:
  # Run the linter.
  - id: ruff
    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
  # Run the formatter.
  - id: ruff-format
    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
  - id: black
- repo: https://github.com/pycqa/flake8
  rev: 7.0.0  # The version of flake8 to use
  hooks:
  - id: flake8
    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
    additional_dependencies: [flake8-print]
    files: litellm/.*\.py
- repo: local
  hooks:
  - id: check-files-match
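
Of the hooks above, flake8-print is the one that fails commits on stray debug output: it reports any print call as error T201. A hypothetical file that would trip it:

# hypothetical_module.py -- would fail the flake8 hook configured above,
# since flake8-print reports `print` calls as T201.
def add(a: int, b: int) -> int:
    print(f"adding {a} + {b}")  # T201 print found -- use a logger instead
    return a + b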

litellm/proxy/litellm_pre_call_utils.py

@@ -0,0 +1,93 @@
import copy
from fastapi import Request
from typing import Any, Dict, Optional

from litellm.proxy._types import UserAPIKeyAuth
from litellm._logging import verbose_proxy_logger, verbose_logger


def parse_cache_control(cache_control):
    cache_dict = {}
    directives = cache_control.split(", ")

    for directive in directives:
        if "=" in directive:
            key, value = directive.split("=")
            cache_dict[key] = value
        else:
            cache_dict[directive] = True

    return cache_dict
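
For a sense of the output shape, a quick sketch with a made-up header value. Note the parser keeps directive values as strings:

sample = parse_cache_control("s-maxage=600, no-store")
print(sample)  # {'s-maxage': '600', 'no-store': True}
# So the "ttl" derived from s-maxage below is the string "600", not an int.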
async def add_litellm_data_to_request(
    data: dict,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    general_settings: Optional[Dict[str, Any]] = None,
    version: Optional[str] = None,
):
    # Azure OpenAI only: check if user passed api-version
    query_params = dict(request.query_params)
    if "api-version" in query_params:
        data["api_version"] = query_params["api-version"]

    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": dict(request.headers),
        "body": copy.copy(data),  # use copy instead of deepcopy
    }

    ## Cache Controls
    headers = request.headers
    verbose_proxy_logger.debug("Request Headers: %s", headers)
    cache_control_header = headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)

    # users can pass in 'user' param to /chat/completions. Don't override it
    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
        # if users are using user_api_key_auth, set `user` in `data`
        data["user"] = user_api_key_dict.user_id

    if "metadata" not in data:
        data["metadata"] = {}
    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
    data["metadata"]["user_api_key_alias"] = getattr(
        user_api_key_dict, "key_alias", None
    )
    data["metadata"]["user_api_end_user_max_budget"] = getattr(
        user_api_key_dict, "end_user_max_budget", None
    )
    data["metadata"]["litellm_api_version"] = version

    if general_settings is not None:
        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )

    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
    data["metadata"]["user_api_key_team_id"] = getattr(
        user_api_key_dict, "team_id", None
    )
    data["metadata"]["user_api_key_team_alias"] = getattr(
        user_api_key_dict, "team_alias", None
    )
    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
    _headers = dict(request.headers)
    _headers.pop(
        "authorization", None
    )  # do not store the original `sk-..` api key in the db
    data["metadata"]["headers"] = _headers
    data["metadata"]["endpoint"] = str(request.url)

    # Add the OTEL Parent Trace before sending it LiteLLM
    data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span

    ### END-USER SPECIFIC PARAMS ###
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region

    return data
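
A minimal sketch of exercising the new helper outside the proxy, say from a unit test. The ASGI scope, key values, and version string are invented for illustration, and it assumes UserAPIKeyAuth tolerates being constructed from just api_key and user_id:

import asyncio
from fastapi import Request
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

# Bare ASGI scope: just enough for request.url, query params, and headers.
scope = {
    "type": "http",
    "method": "POST",
    "scheme": "http",
    "server": ("localhost", 4000),
    "path": "/chat/completions",
    "query_string": b"api-version=2024-02-01",
    "headers": [(b"cache-control", b"s-maxage=600")],
}

data = asyncio.run(
    add_litellm_data_to_request(
        data={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
        request=Request(scope),
        user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key", user_id="user-123"),
        general_settings={},
        version="1.40.0",  # invented version string
    )
)
print(data["api_version"], data["ttl"], data["user"])  # 2024-02-01 600 user-123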

litellm/proxy/proxy_server.py

@@ -89,6 +89,7 @@ import litellm
from litellm.types.llms.openai import (
    HttpxBinaryResponseContent,
)
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.utils import (
    PrismaClient,
    DBClient,
@@ -3827,20 +3828,6 @@ def get_litellm_model_info(model: dict = {}):
    return {}


def parse_cache_control(cache_control):
    cache_dict = {}
    directives = cache_control.split(", ")

    for directive in directives:
        if "=" in directive:
            key, value = directive.split("=")
            cache_dict[key] = value
        else:
            cache_dict[directive] = True

    return cache_dict


def on_backoff(details):
    # The 'tries' key in the details dictionary contains the number of completed tries
    verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])
@@ -4153,28 +4140,14 @@ async def chat_completion(
    except:
        data = json.loads(body_str)

    # Azure OpenAI only: check if user passed api-version
    query_params = dict(request.query_params)
    if "api-version" in query_params:
        data["api_version"] = query_params["api-version"]

    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        general_settings=general_settings,
        user_api_key_dict=user_api_key_dict,
        version=version,
    )

    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": dict(request.headers),
        "body": copy.copy(data),  # use copy instead of deepcopy
    }

    ## Cache Controls
    headers = request.headers
    verbose_proxy_logger.debug("Request Headers: %s", headers)
    cache_control_header = headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)

    data["model"] = (
        general_settings.get("completion_model", None)  # server default
        or user_model  # model name passed via cli args
@@ -4182,65 +4155,6 @@ async def chat_completion(
        or data["model"]  # default passed in http request
    )

    # users can pass in 'user' param to /chat/completions. Don't override it
    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
        # if users are using user_api_key_auth, set `user` in `data`
        data["user"] = user_api_key_dict.user_id

    if "metadata" not in data:
        data["metadata"] = {}
    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
    data["metadata"]["user_api_key_alias"] = getattr(
        user_api_key_dict, "key_alias", None
    )
    data["metadata"]["user_api_end_user_max_budget"] = getattr(
        user_api_key_dict, "end_user_max_budget", None
    )
    data["metadata"]["litellm_api_version"] = version

    data["metadata"]["global_max_parallel_requests"] = general_settings.get(
        "global_max_parallel_requests", None
    )

    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
    data["metadata"]["user_api_key_team_id"] = getattr(
        user_api_key_dict, "team_id", None
    )
    data["metadata"]["user_api_key_team_alias"] = getattr(
        user_api_key_dict, "team_alias", None
    )
    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
    _headers = dict(request.headers)
    _headers.pop(
        "authorization", None
    )  # do not store the original `sk-..` api key in the db
    data["metadata"]["headers"] = _headers
    data["metadata"]["endpoint"] = str(request.url)

    # Add the OTEL Parent Trace before sending it LiteLLM
    data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span

    ### TEAM-SPECIFIC PARAMS ###
    if user_api_key_dict.team_id is not None:
        team_config = await proxy_config.load_team_config(
            team_id=user_api_key_dict.team_id
        )
        if len(team_config) == 0:
            pass
        else:
            team_id = team_config.pop("team_id", None)
            _is_valid_team_configs(
                team_id=team_id, team_config=team_config, request_data=data
            )
            data["metadata"]["team_id"] = team_id
            data = {
                **team_config,
                **data,
            }  # add the team-specific configs to the completion call

    ### END-USER SPECIFIC PARAMS ###
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region

    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
    # override with user settings, these are params passed via cli
    if user_temperature:
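
One behavior worth calling out in the removed team-config block above: in the {**team_config, **data} merge, later keys win, so per-request values override team defaults. A standalone illustration with made-up values:

team_config = {"temperature": 0.2, "max_tokens": 256}
data = {"model": "gpt-4", "temperature": 0.9}
print({**team_config, **data})
# {'temperature': 0.9, 'max_tokens': 256, 'model': 'gpt-4'}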