From 58eb352ddb1c8b76389c53cd864d380ccad1b2d4 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 12:18:53 -0700
Subject: [PATCH] feat - refactor /chat/completions to have a common helper

---
 .pre-commit-config.yaml                 |  19 ++---
 litellm/proxy/litellm_pre_call_utils.py |  93 +++++++++++++++++++++
 litellm/proxy/proxy_server.py           | 102 ++----------------------
 3 files changed, 111 insertions(+), 103 deletions(-)
 create mode 100644 litellm/proxy/litellm_pre_call_utils.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 41bff6d84..cc41d85f1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,14 +1,15 @@
 repos:
-- repo: https://github.com/astral-sh/ruff-pre-commit
-  # Ruff version.
-  rev: v0.4.8
+- repo: https://github.com/psf/black
+  rev: 24.2.0
   hooks:
-  # Run the linter.
-  - id: ruff
-    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
-  # Run the formatter.
-  - id: ruff-format
-    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
+  - id: black
+- repo: https://github.com/pycqa/flake8
+  rev: 7.0.0  # The version of flake8 to use
+  hooks:
+  - id: flake8
+    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
+    additional_dependencies: [flake8-print]
+    files: litellm/.*\.py
 - repo: local
   hooks:
   - id: check-files-match
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
new file mode 100644
index 000000000..d4736f933
--- /dev/null
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -0,0 +1,93 @@
+import copy
+from fastapi import Request
+from typing import Any, Dict, Optional
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm._logging import verbose_proxy_logger, verbose_logger
+
+
+def parse_cache_control(cache_control):
+    cache_dict = {}
+    directives = cache_control.split(", ")
+
+    for directive in directives:
+        if "=" in directive:
+            key, value = directive.split("=")
+            cache_dict[key] = value
+        else:
+            cache_dict[directive] = True
+
+    return cache_dict
+
+
+async def add_litellm_data_to_request(
+    data: dict,
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth,
+    general_settings: Optional[Dict[str, Any]] = None,
+    version: Optional[str] = None,
+):
+    # Azure OpenAI only: check if user passed api-version
+    query_params = dict(request.query_params)
+    if "api-version" in query_params:
+        data["api_version"] = query_params["api-version"]
+
+    # Include original request and headers in the data
+    data["proxy_server_request"] = {
+        "url": str(request.url),
+        "method": request.method,
+        "headers": dict(request.headers),
+        "body": copy.copy(data),  # use copy instead of deepcopy
+    }
+
+    ## Cache Controls
+    headers = request.headers
+    verbose_proxy_logger.debug("Request Headers: %s", headers)
+    cache_control_header = headers.get("Cache-Control", None)
+    if cache_control_header:
+        cache_dict = parse_cache_control(cache_control_header)
+        data["ttl"] = cache_dict.get("s-maxage")
+
+    verbose_proxy_logger.debug("receiving data: %s", data)
+    # users can pass in 'user' param to /chat/completions. Don't override it
+    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+        # if users are using user_api_key_auth, set `user` in `data`
+        data["user"] = user_api_key_dict.user_id
+
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_alias"] = getattr(
+        user_api_key_dict, "key_alias", None
+    )
+    data["metadata"]["user_api_end_user_max_budget"] = getattr(
+        user_api_key_dict, "end_user_max_budget", None
+    )
+    data["metadata"]["litellm_api_version"] = version
+
+    if general_settings is not None:
+        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
+    data["metadata"]["user_api_key_team_id"] = getattr(
+        user_api_key_dict, "team_id", None
+    )
+    data["metadata"]["user_api_key_team_alias"] = getattr(
+        user_api_key_dict, "team_alias", None
+    )
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    _headers = dict(request.headers)
+    _headers.pop(
+        "authorization", None
+    )  # do not store the original `sk-..` api key in the db
+    data["metadata"]["headers"] = _headers
+    data["metadata"]["endpoint"] = str(request.url)
+    # Add the OTEL Parent Trace before sending it to LiteLLM
+    data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span
+
+    ### END-USER SPECIFIC PARAMS ###
+    if user_api_key_dict.allowed_model_region is not None:
+        data["allowed_model_region"] = user_api_key_dict.allowed_model_region
+    return data
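
The extracted parse_cache_control keeps the existing Cache-Control semantics: directive values stay strings, and bare directives map to True. A minimal sketch of that behavior, assuming the new module above is importable:

    from litellm.proxy.litellm_pre_call_utils import parse_cache_control

    cache_dict = parse_cache_control("s-maxage=60, no-store")
    assert cache_dict == {"s-maxage": "60", "no-store": True}
    # add_litellm_data_to_request then sets data["ttl"] = cache_dict.get("s-maxage"),
    # i.e. the TTL is the string "60", not an int
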
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 25e46269e..15846ad4d 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -89,6 +89,7 @@ import litellm
 from litellm.types.llms.openai import (
     HttpxBinaryResponseContent,
 )
+from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
 from litellm.proxy.utils import (
     PrismaClient,
     DBClient,
@@ -3827,20 +3828,6 @@ def get_litellm_model_info(model: dict = {}):
     return {}
 
 
-def parse_cache_control(cache_control):
-    cache_dict = {}
-    directives = cache_control.split(", ")
-
-    for directive in directives:
-        if "=" in directive:
-            key, value = directive.split("=")
-            cache_dict[key] = value
-        else:
-            cache_dict[directive] = True
-
-    return cache_dict
-
-
 def on_backoff(details):
     # The 'tries' key in the details dictionary contains the number of completed tries
     verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])
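
Because the enrichment logic now lives in one place, other proxy routes can pick it up with the same five-argument call. A hedged sketch of that adoption pattern (the route body and wiring below are hypothetical, not part of this patch):

    from fastapi import Request
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    async def completion(request: Request, data: dict, user_api_key_dict,
                         general_settings: dict, version: str):
        # same enrichment /chat/completions gets below: api-version, cache TTL,
        # user_api_key metadata, sanitized headers, OTEL parent span
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            user_api_key_dict=user_api_key_dict,
            general_settings=general_settings,
            version=version,
        )
        return data
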
@@ -4153,28 +4140,14 @@ async def chat_completion(
     except:
         data = json.loads(body_str)
 
-    # Azure OpenAI only: check if user passed api-version
-    query_params = dict(request.query_params)
-    if "api-version" in query_params:
-        data["api_version"] = query_params["api-version"]
+    data = await add_litellm_data_to_request(
+        data=data,
+        request=request,
+        general_settings=general_settings,
+        user_api_key_dict=user_api_key_dict,
+        version=version,
+    )
 
-    # Include original request and headers in the data
-    data["proxy_server_request"] = {
-        "url": str(request.url),
-        "method": request.method,
-        "headers": dict(request.headers),
-        "body": copy.copy(data),  # use copy instead of deepcopy
-    }
-
-    ## Cache Controls
-    headers = request.headers
-    verbose_proxy_logger.debug("Request Headers: %s", headers)
-    cache_control_header = headers.get("Cache-Control", None)
-    if cache_control_header:
-        cache_dict = parse_cache_control(cache_control_header)
-        data["ttl"] = cache_dict.get("s-maxage")
-
-    verbose_proxy_logger.debug("receiving data: %s", data)
     data["model"] = (
         general_settings.get("completion_model", None)  # server default
         or user_model  # model name passed via cli args
@@ -4182,65 +4155,6 @@ async def chat_completion(
         or data["model"]  # default passed in http request
     )
 
-    # users can pass in 'user' param to /chat/completions. Don't override it
-    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
-        # if users are using user_api_key_auth, set `user` in `data`
-        data["user"] = user_api_key_dict.user_id
-
-    if "metadata" not in data:
-        data["metadata"] = {}
-    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
-    data["metadata"]["user_api_key_alias"] = getattr(
-        user_api_key_dict, "key_alias", None
-    )
-    data["metadata"]["user_api_end_user_max_budget"] = getattr(
-        user_api_key_dict, "end_user_max_budget", None
-    )
-    data["metadata"]["litellm_api_version"] = version
-
-    data["metadata"]["global_max_parallel_requests"] = general_settings.get(
-        "global_max_parallel_requests", None
-    )
-    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
-    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
-    data["metadata"]["user_api_key_team_id"] = getattr(
-        user_api_key_dict, "team_id", None
-    )
-    data["metadata"]["user_api_key_team_alias"] = getattr(
-        user_api_key_dict, "team_alias", None
-    )
-    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
-    _headers = dict(request.headers)
-    _headers.pop(
-        "authorization", None
-    )  # do not store the original `sk-..` api key in the db
-    data["metadata"]["headers"] = _headers
-    data["metadata"]["endpoint"] = str(request.url)
-    # Add the OTEL Parent Trace before sending it LiteLLM
-    data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span
-
-    ### TEAM-SPECIFIC PARAMS ###
-    if user_api_key_dict.team_id is not None:
-        team_config = await proxy_config.load_team_config(
-            team_id=user_api_key_dict.team_id
-        )
-        if len(team_config) == 0:
-            pass
-        else:
-            team_id = team_config.pop("team_id", None)
-            _is_valid_team_configs(
-                team_id=team_id, team_config=team_config, request_data=data
-            )
-            data["metadata"]["team_id"] = team_id
-            data = {
-                **team_config,
-                **data,
-            }  # add the team-specific configs to the completion call
-
-    ### END-USER SPECIFIC PARAMS ###
-    if user_api_key_dict.allowed_model_region is not None:
-        data["allowed_model_region"] = user_api_key_dict.allowed_model_region
-
     global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     # override with user settings, these are params passed via cli
     if user_temperature:
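
For context, a rough end-to-end sketch of what the extracted helper produces. The request scope, key values, and version string below are hypothetical test inputs, and the remaining UserAPIKeyAuth fields are assumed to have defaults:

    import asyncio
    from fastapi import Request
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request

    # build a bare ASGI request carrying an api-version query param and a
    # Cache-Control header, the two request inputs the helper inspects besides auth
    request = Request(
        scope={
            "type": "http",
            "method": "POST",
            "scheme": "http",
            "server": ("localhost", 4000),
            "path": "/chat/completions",
            "query_string": b"api-version=2024-02-01",
            "headers": [(b"cache-control", b"s-maxage=60")],
        }
    )

    data = asyncio.run(
        add_litellm_data_to_request(
            data={"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]},
            request=request,
            user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key", user_id="user-123"),
            general_settings={},
            version="1.40.0",
        )
    )

    assert data["api_version"] == "2024-02-01"  # Azure api-version passthrough
    assert data["ttl"] == "60"                  # s-maxage from Cache-Control
    assert data["user"] == "user-123"           # filled in from the authenticated key
    # data["metadata"] now carries user_api_key, team/org ids, sanitized headers, etc.
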