Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Merge pull request #4065 from BerriAI/litellm_use_common_func
[Refactor] - Refactor proxy_server.py to use common function for `add_litellm_data_to_request`

Commit d9dacc1f43
8 changed files with 264 additions and 806 deletions
.pre-commit-config.yaml

@@ -1,14 +1,15 @@
 repos:
-- repo: https://github.com/psf/black
-  rev: 24.2.0
-  hooks:
-  - id: black
-- repo: https://github.com/pycqa/flake8
-  rev: 7.0.0  # The version of flake8 to use
-  hooks:
-  - id: flake8
-    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
-    additional_dependencies: [flake8-print]
-    files: litellm/.*\.py
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.4.8
+  hooks:
+  # Run the linter.
+  - id: ruff
+    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
+  # Run the formatter.
+  - id: ruff-format
+    exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/
 - repo: local
   hooks:
   - id: check-files-match
@@ -165,7 +165,8 @@ class OpenTelemetry(CustomLogger):
         proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
         headers = proxy_server_request.get("headers", {}) or {}
         traceparent = headers.get("traceparent", None)
-        parent_otel_span = litellm_params.get("litellm_parent_otel_span", None)
+        _metadata = litellm_params.get("metadata", {})
+        parent_otel_span = _metadata.get("litellm_parent_otel_span", None)

         """
         Two way to use parents in opentelemetry
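With this change, a parent span reaches the OpenTelemetry logger through the request's `metadata` rather than through a dedicated top-level litellm param. A minimal caller-side sketch, assuming a tracer is already configured (the tracer name and span variable are illustrative, not from this PR):

import litellm
from opentelemetry import trace

# Assumption: an OTel tracer provider is configured elsewhere; litellm's
# OTel logger reads the parent span back out of `metadata`.
tracer = trace.get_tracer("my-app")
with tracer.start_as_current_span("handle-request") as parent_otel_span:
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello"}],
        metadata={"litellm_parent_otel_span": parent_otel_span},
    )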
@@ -607,7 +607,6 @@ def completion(
     client = kwargs.get("client", None)
     ### Admin Controls ###
     no_log = kwargs.get("no-log", False)
-    litellm_parent_otel_span = kwargs.get("litellm_parent_otel_span", None)
     ######## end of unpacking kwargs ###########
     openai_params = [
         "functions",

@@ -697,7 +696,6 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
-        "litellm_parent_otel_span",
     ]

     default_params = openai_params + litellm_params

@@ -882,7 +880,6 @@ def completion(
         input_cost_per_token=input_cost_per_token,
         output_cost_per_second=output_cost_per_second,
         output_cost_per_token=output_cost_per_token,
-        litellm_parent_otel_span=litellm_parent_otel_span,
     )
     logging.update_environment_variables(
         model=model,
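Taken together, these hunks delete the dedicated `litellm_parent_otel_span` kwarg from `completion()`: the span now rides inside `metadata` instead. A runnable toy trace of the new path (a plain `object()` stands in for a real span; the glue between the steps is paraphrased from the hunks, not copied from the PR):

span = object()  # stand-in for a live OpenTelemetry span

# 1. Proxy side: the shared helper (new file below) stashes the span.
data = {"model": "gpt-3.5-turbo", "metadata": {}}
data["metadata"]["litellm_parent_otel_span"] = span

# 2. completion() no longer unpacks the span from kwargs; `metadata` flows
#    through litellm_params unchanged.
litellm_params = {"metadata": data["metadata"]}

# 3. Logger side: the OpenTelemetry integration reads it back out.
_metadata = litellm_params.get("metadata", {})
assert _metadata.get("litellm_parent_otel_span", None) is span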
litellm/proxy/litellm_pre_call_utils.py (new file, 130 lines)
@@ -0,0 +1,130 @@
+import copy
+from fastapi import Request
+from typing import Any, Dict, Optional, TYPE_CHECKING
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm._logging import verbose_proxy_logger, verbose_logger
+
+if TYPE_CHECKING:
+    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
+
+    ProxyConfig = _ProxyConfig
+else:
+    ProxyConfig = Any
+
+
+def parse_cache_control(cache_control):
+    cache_dict = {}
+    directives = cache_control.split(", ")
+
+    for directive in directives:
+        if "=" in directive:
+            key, value = directive.split("=")
+            cache_dict[key] = value
+        else:
+            cache_dict[directive] = True
+
+    return cache_dict
+
+
+async def add_litellm_data_to_request(
+    data: dict,
+    request: Request,
+    user_api_key_dict: UserAPIKeyAuth,
+    proxy_config: ProxyConfig,
+    general_settings: Optional[Dict[str, Any]] = None,
+    version: Optional[str] = None,
+):
+    """
+    Adds LiteLLM-specific data to the request.
+
+    Args:
+        data (dict): The data dictionary to be modified.
+        request (Request): The incoming request.
+        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
+        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
+        version (Optional[str], optional): Version. Defaults to None.
+
+    Returns:
+        dict: The modified data dictionary.
+    """
+    query_params = dict(request.query_params)
+    if "api-version" in query_params:
+        data["api_version"] = query_params["api-version"]
+
+    # Include original request and headers in the data
+    data["proxy_server_request"] = {
+        "url": str(request.url),
+        "method": request.method,
+        "headers": dict(request.headers),
+        "body": copy.copy(data),  # use copy instead of deepcopy
+    }
+
+    ## Cache Controls
+    headers = request.headers
+    verbose_proxy_logger.debug("Request Headers: %s", headers)
+    cache_control_header = headers.get("Cache-Control", None)
+    if cache_control_header:
+        cache_dict = parse_cache_control(cache_control_header)
+        data["ttl"] = cache_dict.get("s-maxage")
+
+    verbose_proxy_logger.debug("receiving data: %s", data)
+    # users can pass in 'user' param to /chat/completions. Don't override it
+    if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+        # if users are using user_api_key_auth, set `user` in `data`
+        data["user"] = user_api_key_dict.user_id
+
+    if "metadata" not in data:
+        data["metadata"] = {}
+    data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+    data["metadata"]["user_api_key_alias"] = getattr(
+        user_api_key_dict, "key_alias", None
+    )
+    data["metadata"]["user_api_end_user_max_budget"] = getattr(
+        user_api_key_dict, "end_user_max_budget", None
+    )
+    data["metadata"]["litellm_api_version"] = version
+
+    if general_settings is not None:
+        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+
+    data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+    data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id
+    data["metadata"]["user_api_key_team_id"] = getattr(
+        user_api_key_dict, "team_id", None
+    )
+    data["metadata"]["user_api_key_team_alias"] = getattr(
+        user_api_key_dict, "team_alias", None
+    )
+    data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+    _headers = dict(request.headers)
+    _headers.pop(
+        "authorization", None
+    )  # do not store the original `sk-..` api key in the db
+    data["metadata"]["headers"] = _headers
+    data["metadata"]["endpoint"] = str(request.url)
+    # Add the OTEL Parent Trace before sending it LiteLLM
+    data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span
+
+    ### END-USER SPECIFIC PARAMS ###
+    if user_api_key_dict.allowed_model_region is not None:
+        data["allowed_model_region"] = user_api_key_dict.allowed_model_region
+
+    ### TEAM-SPECIFIC PARAMS ###
+    if user_api_key_dict.team_id is not None:
+        team_config = await proxy_config.load_team_config(
+            team_id=user_api_key_dict.team_id
+        )
+        if len(team_config) == 0:
+            pass
+        else:
+            team_id = team_config.pop("team_id", None)
+            data["metadata"]["team_id"] = team_id
+            data = {
+                **team_config,
+                **data,
+            }  # add the team-specific configs to the completion call
+
+    return data
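The `Cache-Control` handling in the new helper is easiest to see with a concrete header. A small usage sketch (the header value is made up for illustration):

from litellm.proxy.litellm_pre_call_utils import parse_cache_control

# "s-maxage=3600, public" -> {"s-maxage": "3600", "public": True}
cache_dict = parse_cache_control("s-maxage=3600, public")
assert cache_dict == {"s-maxage": "3600", "public": True}

# add_litellm_data_to_request then sets data["ttl"] = cache_dict.get("s-maxage"),
# i.e. "3600" here (note: the value stays a string, it is not cast to int).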
File diff suppressed because it is too large.
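The suppressed diff is the proxy_server.py refactor named in the PR title: each route now delegates its pre-call request mutation to the shared helper. A hedged sketch, assuming the module-level `proxy_config`, `general_settings`, and `version` objects from proxy_server.py and standing in for the real auth dependency; the route wiring below is illustrative, not copied from the diff:

from fastapi import FastAPI, Request

from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.proxy_server import general_settings, proxy_config, version

app = FastAPI()

@app.post("/chat/completions")
async def chat_completion(request: Request):
    # In the real proxy, key auth runs as a FastAPI dependency; a default
    # UserAPIKeyAuth stands in for it in this sketch.
    user_api_key_dict = UserAPIKeyAuth()
    data = await request.json()
    # One call replaces the metadata/header/team plumbing each endpoint
    # previously duplicated inline.
    data = await add_litellm_data_to_request(
        data=data,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
        general_settings=general_settings,
        version=version,
    )
    return data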
@@ -73,7 +73,8 @@ def print_verbose(print_statement):
 def safe_deep_copy(data):
     if isinstance(data, dict):
         # remove litellm_parent_otel_span since this is not picklable
-        data.pop("litellm_parent_otel_span", None)
+        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
+            data["metadata"].pop("litellm_parent_otel_span")
     new_data = copy.deepcopy(data)
     return new_data
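The pop matters because `copy.deepcopy` raises on objects it cannot copy or pickle, which includes a live OpenTelemetry span; since the span now lives under `metadata`, that is where it must be stripped. A toy reproduction (the `Unpicklable` class is a stand-in, not litellm code):

import copy

class Unpicklable:
    """Stand-in for a live OTel span that deepcopy cannot handle."""
    def __deepcopy__(self, memo):
        raise TypeError("cannot deepcopy a live span")

data = {"model": "gpt-4", "metadata": {"litellm_parent_otel_span": Unpicklable()}}

# Without the new guard, deepcopy would hit the span and raise.
if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
    data["metadata"].pop("litellm_parent_otel_span")

snapshot = copy.deepcopy(data)  # succeeds once the span is stripped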
@@ -152,7 +152,6 @@ def test_chat_completion(mock_acompletion, client_no_auth):
         specific_deployment=True,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
     print(f"response - {response.text}")
     assert response.status_code == 200
@@ -4927,7 +4927,6 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
-    litellm_parent_otel_span=None,
 ):
     litellm_params = {
         "acompletion": acompletion,

@@ -4950,7 +4949,6 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
-        "litellm_parent_otel_span": litellm_parent_otel_span,
     }

     return litellm_params