Merge pull request #4065 from BerriAI/litellm_use_common_func

[Refactor] - Refactor proxy_server.py to use common function for `add_litellm_data_to_request`
Ishaan Jaff · 2024-06-07 14:02:17 -07:00 · committed by GitHub
commit d9dacc1f43
8 changed files with 264 additions and 806 deletions

@@ -1,14 +1,15 @@
 repos:
-- repo: https://github.com/astral-sh/ruff-pre-commit
-  # Ruff version.
-  rev: v0.4.8
+- repo: https://github.com/psf/black
+  rev: 24.2.0
   hooks:
-    # Run the linter.
-    - id: ruff
-      exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
-    # Run the formatter.
-    - id: ruff-format
+    - id: black
+- repo: https://github.com/pycqa/flake8
+  rev: 7.0.0  # The version of flake8 to use
+  hooks:
+    - id: flake8
       exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
+      additional_dependencies: [flake8-print]
+      files: litellm/.*\.py
 - repo: local
   hooks:
     - id: check-files-match

@@ -165,7 +165,8 @@ class OpenTelemetry(CustomLogger):
         proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
         headers = proxy_server_request.get("headers", {}) or {}
         traceparent = headers.get("traceparent", None)
-        parent_otel_span = litellm_params.get("litellm_parent_otel_span", None)
+        _metadata = litellm_params.get("metadata", {})
+        parent_otel_span = _metadata.get("litellm_parent_otel_span", None)
         """
         Two way to use parents in opentelemetry
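With this PR the parent span no longer sits at the top level of `litellm_params`; it travels inside the `metadata` dict, which is why the logger's lookup changes above. A minimal sketch of the new lookup pattern, using a placeholder string instead of a real span object:

```python
from typing import Any, Optional

def get_parent_otel_span(litellm_params: dict) -> Optional[Any]:
    # The proxy now stores the parent span under metadata, so read it
    # from there rather than from a top-level litellm_params key.
    _metadata = litellm_params.get("metadata", {}) or {}
    return _metadata.get("litellm_parent_otel_span", None)

# A request shaped the way the proxy produces it after this PR:
litellm_params = {"metadata": {"litellm_parent_otel_span": "placeholder-span"}}
assert get_parent_otel_span(litellm_params) == "placeholder-span"
```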

@@ -607,7 +607,6 @@ def completion(
     client = kwargs.get("client", None)
     ### Admin Controls ###
     no_log = kwargs.get("no-log", False)
-    litellm_parent_otel_span = kwargs.get("litellm_parent_otel_span", None)
     ######## end of unpacking kwargs ###########
     openai_params = [
         "functions",
@@ -697,7 +696,6 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
-        "litellm_parent_otel_span",
     ]
     default_params = openai_params + litellm_params
@@ -882,7 +880,6 @@ def completion(
             input_cost_per_token=input_cost_per_token,
             output_cost_per_second=output_cost_per_second,
             output_cost_per_token=output_cost_per_token,
-            litellm_parent_otel_span=litellm_parent_otel_span,
         )
         logging.update_environment_variables(
             model=model,
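Since `completion()` no longer unpacks a dedicated `litellm_parent_otel_span` kwarg, callers that want to thread a parent span through would nest it under `metadata` instead. A hedged before/after sketch (the span value is a stand-in, and `litellm` must be installed):

```python
import litellm

# Before this PR (top-level kwarg, now removed):
#   litellm.completion(model=..., messages=..., litellm_parent_otel_span=span)

def call_with_span(span, model: str, messages: list):
    # After this PR the span rides along inside `metadata`, where the
    # OpenTelemetry logger looks it up.
    return litellm.completion(
        model=model,
        messages=messages,
        metadata={"litellm_parent_otel_span": span},
    )
```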

@@ -0,0 +1,130 @@
import copy
from fastapi import Request
from typing import Any, Dict, Optional, TYPE_CHECKING

from litellm.proxy._types import UserAPIKeyAuth
from litellm._logging import verbose_proxy_logger, verbose_logger

if TYPE_CHECKING:
    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig

    ProxyConfig = _ProxyConfig
else:
    ProxyConfig = Any


def parse_cache_control(cache_control):
    cache_dict = {}
    directives = cache_control.split(", ")

    for directive in directives:
        if "=" in directive:
            key, value = directive.split("=")
            cache_dict[key] = value
        else:
            cache_dict[directive] = True

    return cache_dict


async def add_litellm_data_to_request(
    data: dict,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    proxy_config: ProxyConfig,
    general_settings: Optional[Dict[str, Any]] = None,
    version: Optional[str] = None,
):
    """
    Adds LiteLLM-specific data to the request.

    Args:
        data (dict): The data dictionary to be modified.
        request (Request): The incoming request.
        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
        version (Optional[str], optional): Version. Defaults to None.

    Returns:
        dict: The modified data dictionary.
    """
    query_params = dict(request.query_params)
    if "api-version" in query_params:
        data["api_version"] = query_params["api-version"]

    # Include original request and headers in the data
    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": dict(request.headers),
        "body": copy.copy(data),  # use copy instead of deepcopy
    }

    ## Cache Controls
    headers = request.headers
    verbose_proxy_logger.debug("Request Headers: %s", headers)
    cache_control_header = headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)

    # users can pass in 'user' param to /chat/completions. Don't override it
    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
        # if users are using user_api_key_auth, set `user` in `data`
        data["user"] = user_api_key_dict.user_id

    if "metadata" not in data:
        data["metadata"] = {}
    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
    data["metadata"]["user_api_key_alias"] = getattr(
        user_api_key_dict, "key_alias", None
    )
    data["metadata"]["user_api_end_user_max_budget"] = getattr(
        user_api_key_dict, "end_user_max_budget", None
    )
    data["metadata"]["litellm_api_version"] = version

    if general_settings is not None:
        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )

    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
    data["metadata"]["user_api_key_team_id"] = getattr(
        user_api_key_dict, "team_id", None
    )
    data["metadata"]["user_api_key_team_alias"] = getattr(
        user_api_key_dict, "team_alias", None
    )
    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
    _headers = dict(request.headers)
    _headers.pop(
        "authorization", None
    )  # do not store the original `sk-..` api key in the db
    data["metadata"]["headers"] = _headers
    data["metadata"]["endpoint"] = str(request.url)

    # Add the OTEL Parent Trace before sending it LiteLLM
    data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span

    ### END-USER SPECIFIC PARAMS ###
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region

    ### TEAM-SPECIFIC PARAMS ###
    if user_api_key_dict.team_id is not None:
        team_config = await proxy_config.load_team_config(
            team_id=user_api_key_dict.team_id
        )
        if len(team_config) == 0:
            pass
        else:
            team_id = team_config.pop("team_id", None)
            data["metadata"]["team_id"] = team_id
            data = {
                **team_config,
                **data,
            }  # add the team-specific configs to the completion call

    return data
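For context, a usage sketch of the new shared helper, assuming it sits in the same module as the code above; the route and the stub `UserAPIKeyAuth` values are illustrative, not the proxy's actual wiring (in the real proxy they come from the auth middleware and loaded config):

```python
from fastapi import FastAPI

app = FastAPI()

@app.post("/chat/completions")
async def chat_completion(request: Request):
    data = await request.json()
    # Stubbed for illustration; the proxy derives this from the caller's key.
    user_api_key_dict = UserAPIKeyAuth(api_key="hashed-key", user_id="user-123")
    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=None,  # only consulted when the key carries a team_id
        general_settings={"global_max_parallel_requests": 100},
        version="1.0.0",
    )
    # `data` now carries proxy_server_request, metadata, user, etc.
    return data
```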

File diff suppressed because it is too large.

@@ -73,7 +73,8 @@ def print_verbose(print_statement):
 def safe_deep_copy(data):
     if isinstance(data, dict):
         # remove litellm_parent_otel_span since this is not picklable
-        data.pop("litellm_parent_otel_span", None)
+        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
+            data["metadata"].pop("litellm_parent_otel_span")
     new_data = copy.deepcopy(data)
     return new_data
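The pop matters because `copy.deepcopy` falls back to the pickle protocol for objects it cannot clone directly, and a live OTel span, like other handle-type objects, is not picklable. A small illustration using `threading.Lock` as a stand-in for the span:

```python
import copy
import threading

data = {"metadata": {"litellm_parent_otel_span": threading.Lock(), "team_id": "t1"}}

try:
    copy.deepcopy(data)
except TypeError as err:
    print(f"deepcopy fails: {err}")  # cannot pickle '_thread.lock' object

# safe_deep_copy's approach: drop the span before copying
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
    data["metadata"].pop("litellm_parent_otel_span")
print(copy.deepcopy(data))  # {'metadata': {'team_id': 't1'}}
```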

@@ -152,7 +152,6 @@ def test_chat_completion(mock_acompletion, client_no_auth):
         specific_deployment=True,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
     print(f"response - {response.text}")
     assert response.status_code == 200
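`mock.ANY` in the assertion above is the stdlib wildcard that compares equal to any value, so the test only pins the kwargs it cares about; dropping the `litellm_parent_otel_span=mock.ANY` line mirrors the kwarg's removal. A quick reminder of the behavior:

```python
from unittest import mock

m = mock.MagicMock()
m(model="gpt-3.5-turbo", metadata={"user_api_key": "hashed"})

# mock.ANY matches anything, so metadata's exact contents don't matter here.
m.assert_called_with(model="gpt-3.5-turbo", metadata=mock.ANY)
```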

@@ -4927,7 +4927,6 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
-    litellm_parent_otel_span=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -4950,7 +4949,6 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
-        "litellm_parent_otel_span": litellm_parent_otel_span,
     }
     return litellm_params