forked from phoenix/litellm-mirror
Merge pull request #4701 from BerriAI/litellm_rpm_support_passthrough
Support key-rpm limits on pass-through endpoints
This commit is contained in:
commit
bc58e44d8f
11 changed files with 179 additions and 29 deletions
|
@ -156,6 +156,8 @@ POST /api/public/ingestion HTTP/1.1" 207 Multi-Status
|
||||||
|
|
||||||
Use this if you want the pass through endpoint to honour LiteLLM keys/authentication
|
Use this if you want the pass through endpoint to honour LiteLLM keys/authentication
|
||||||
|
|
||||||
|
This also enforces the key's rpm limits on pass-through endpoints.
|
||||||
|
|
||||||
Usage - set `auth: true` on the config
|
Usage - set `auth: true` on the config
|
||||||
```yaml
|
```yaml
|
||||||
general_settings:
|
general_settings:
|
||||||
|
@ -361,4 +363,5 @@ curl --location 'http://0.0.0.0:4000/v1/messages' \
|
||||||
{"role": "user", "content": "Hello, world"}
|
{"role": "user", "content": "Hello, world"}
|
||||||
]
|
]
|
||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -99,6 +99,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
"audio_transcription",
|
"audio_transcription",
|
||||||
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
) -> Optional[
|
) -> Optional[
|
||||||
Union[Exception, str, dict]
|
Union[Exception, str, dict]
|
||||||
|
|
|
@ -12,11 +12,12 @@ model_list:
|
||||||
|
|
||||||
|
|
||||||
general_settings:
|
general_settings:
|
||||||
alerting: ["slack"]
|
|
||||||
alerting_threshold: 10
|
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
pass_through_endpoints:
|
pass_through_endpoints:
|
||||||
- path: "/v1/test-messages" # route you want to add to LiteLLM Proxy Server
|
- path: "/v1/rerank"
|
||||||
target: litellm.adapters.anthropic_adapter.anthropic_adapter # URL this route should forward requests to
|
target: "https://api.cohere.com/v1/rerank"
|
||||||
headers: # headers to forward to this URL
|
auth: true # 👈 Key change to use LiteLLM Auth / Keys
|
||||||
litellm_user_api_key: "x-my-test-key"
|
headers:
|
||||||
|
Authorization: "bearer os.environ/COHERE_API_KEY"
|
||||||
|
content-type: application/json
|
||||||
|
accept: application/json
|
|
@ -96,7 +96,6 @@ async def user_api_key_auth(
|
||||||
anthropic_api_key_header
|
anthropic_api_key_header
|
||||||
),
|
),
|
||||||
) -> UserAPIKeyAuth:
|
) -> UserAPIKeyAuth:
|
||||||
|
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
allowed_routes_check,
|
allowed_routes_check,
|
||||||
common_checks,
|
common_checks,
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.proxy.proxy_server import UserAPIKeyAuth, DualCache
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
from typing import Optional, Literal
|
from litellm.proxy.proxy_server import DualCache, UserAPIKeyAuth
|
||||||
|
|
||||||
|
|
||||||
# This file includes the custom callbacks for LiteLLM Proxy
|
# This file includes the custom callbacks for LiteLLM Proxy
|
||||||
|
@ -27,6 +28,7 @@ class MyCustomHandler(
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
"audio_transcription",
|
"audio_transcription",
|
||||||
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
return data
|
return data
|
||||||
|
|
|
@ -198,6 +198,7 @@ class _PROXY_DynamicRateLimitHandler(CustomLogger):
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
"audio_transcription",
|
"audio_transcription",
|
||||||
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
) -> Optional[
|
) -> Optional[
|
||||||
Union[Exception, str, dict]
|
Union[Exception, str, dict]
|
||||||
|
|
|
@ -1,12 +1,16 @@
|
||||||
from typing import Optional
|
import sys
|
||||||
import litellm, traceback, sys
|
import traceback
|
||||||
from litellm.caching import DualCache
|
|
||||||
from litellm.proxy._types import UserAPIKeyAuth
|
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
|
||||||
from fastapi import HTTPException
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
|
||||||
from litellm import ModelResponse
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm import ModelResponse
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
|
||||||
|
|
||||||
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
|
|
|
@ -3,6 +3,7 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
|
@ -22,8 +23,6 @@ from litellm.integrations.custom_logger import CustomLogger
|
||||||
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
|
from litellm.proxy._types import ProxyException, UserAPIKeyAuth
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
|
||||||
async_client = httpx.AsyncClient()
|
|
||||||
|
|
||||||
|
|
||||||
async def set_env_variables_in_header(custom_headers: dict):
|
async def set_env_variables_in_header(custom_headers: dict):
|
||||||
"""
|
"""
|
||||||
|
@ -240,14 +239,28 @@ async def chat_completion_pass_through_endpoint(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def pass_through_request(request: Request, target: str, custom_headers: dict):
|
async def pass_through_request(
|
||||||
|
request: Request,
|
||||||
|
target: str,
|
||||||
|
custom_headers: dict,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||||
|
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||||
|
|
||||||
url = httpx.URL(target)
|
url = httpx.URL(target)
|
||||||
headers = custom_headers
|
headers = custom_headers
|
||||||
|
|
||||||
request_body = await request.body()
|
request_body = await request.body()
|
||||||
_parsed_body = ast.literal_eval(request_body.decode("utf-8"))
|
body_str = request_body.decode()
|
||||||
|
try:
|
||||||
|
_parsed_body = ast.literal_eval(body_str)
|
||||||
|
except:
|
||||||
|
_parsed_body = json.loads(body_str)
|
||||||
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
|
"Pass through endpoint sending request to \nURL {}\nheaders: {}\nbody: {}\n".format(
|
||||||
|
@ -255,6 +268,15 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||||
|
_parsed_body = await proxy_logging_obj.pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
data=_parsed_body,
|
||||||
|
call_type="pass_through_endpoint",
|
||||||
|
)
|
||||||
|
|
||||||
|
async_client = httpx.AsyncClient()
|
||||||
|
|
||||||
response = await async_client.request(
|
response = await async_client.request(
|
||||||
method=request.method,
|
method=request.method,
|
||||||
url=url,
|
url=url,
|
||||||
|
@ -267,6 +289,47 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
||||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||||
|
|
||||||
content = await response.aread()
|
content = await response.aread()
|
||||||
|
|
||||||
|
## LOG SUCCESS
|
||||||
|
start_time = time.time()
|
||||||
|
end_time = time.time()
|
||||||
|
# create logging object
|
||||||
|
logging_obj = Logging(
|
||||||
|
model="unknown",
|
||||||
|
messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
|
||||||
|
stream=False,
|
||||||
|
call_type="pass_through_endpoint",
|
||||||
|
start_time=start_time,
|
||||||
|
litellm_call_id=str(uuid.uuid4()),
|
||||||
|
function_id="1245",
|
||||||
|
)
|
||||||
|
# done for supporting 'parallel_request_limiter.py' with pass-through endpoints
|
||||||
|
kwargs = {
|
||||||
|
"litellm_params": {
|
||||||
|
"metadata": {
|
||||||
|
"user_api_key": user_api_key_dict.api_key,
|
||||||
|
"user_api_key_user_id": user_api_key_dict.user_id,
|
||||||
|
"user_api_key_team_id": user_api_key_dict.team_id,
|
||||||
|
"user_api_key_end_user_id": user_api_key_dict.user_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"call_type": "pass_through_endpoint",
|
||||||
|
}
|
||||||
|
logging_obj.update_environment_variables(
|
||||||
|
model="unknown",
|
||||||
|
user="unknown",
|
||||||
|
optional_params={},
|
||||||
|
litellm_params=kwargs["litellm_params"],
|
||||||
|
call_type="pass_through_endpoint",
|
||||||
|
)
|
||||||
|
|
||||||
|
await logging_obj.async_success_handler(
|
||||||
|
result="",
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
cache_hit=False,
|
||||||
|
)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
content=content,
|
content=content,
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
|
@ -274,8 +337,8 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(
|
||||||
"litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format(
|
"litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}\n{}".format(
|
||||||
str(e)
|
str(e), traceback.format_exc()
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
verbose_proxy_logger.debug(traceback.format_exc())
|
verbose_proxy_logger.debug(traceback.format_exc())
|
||||||
|
@ -296,7 +359,9 @@ async def pass_through_request(request: Request, target: str, custom_headers: di
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_pass_through_route(endpoint, target: str, custom_headers=None):
|
def create_pass_through_route(
|
||||||
|
endpoint, target: str, custom_headers: Optional[dict] = None
|
||||||
|
):
|
||||||
# check if target is an adapter.py or a url
|
# check if target is an adapter.py or a url
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
@ -325,8 +390,17 @@ def create_pass_through_route(endpoint, target: str, custom_headers=None):
|
||||||
except Exception:
|
except Exception:
|
||||||
verbose_proxy_logger.warning("Defaulting to target being a url.")
|
verbose_proxy_logger.warning("Defaulting to target being a url.")
|
||||||
|
|
||||||
async def endpoint_func(request: Request): # type: ignore
|
async def endpoint_func(
|
||||||
return await pass_through_request(request, target, custom_headers)
|
request: Request,
|
||||||
|
fastapi_response: Response,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
return await pass_through_request(
|
||||||
|
request=request,
|
||||||
|
target=target,
|
||||||
|
custom_headers=custom_headers or {},
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
)
|
||||||
|
|
||||||
return endpoint_func
|
return endpoint_func
|
||||||
|
|
||||||
|
@ -349,7 +423,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
|
||||||
if _auth is not None and str(_auth).lower() == "true":
|
if _auth is not None and str(_auth).lower() == "true":
|
||||||
if premium_user is not True:
|
if premium_user is not True:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Error Setting Authentication on Pass Through Endpoint: {CommonProxyErrors.not_premium_user}"
|
"Error Setting Authentication on Pass Through Endpoint: {}".format(
|
||||||
|
CommonProxyErrors.not_premium_user.value
|
||||||
|
)
|
||||||
)
|
)
|
||||||
_dependencies = [Depends(user_api_key_auth)]
|
_dependencies = [Depends(user_api_key_auth)]
|
||||||
LiteLLMRoutes.openai_routes.value.append(_path)
|
LiteLLMRoutes.openai_routes.value.append(_path)
|
||||||
|
|
|
@ -299,6 +299,7 @@ class ProxyLogging:
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
"audio_transcription",
|
"audio_transcription",
|
||||||
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -85,6 +85,67 @@ async def test_pass_through_endpoint_rerank(client):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"auth, rpm_limit, expected_error_code",
|
||||||
|
[(True, 0, 429), (True, 1, 200), (False, 0, 401)],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_limit):
|
||||||
|
client = TestClient(app)
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
|
||||||
|
|
||||||
|
mock_api_key = "sk-my-test-key"
|
||||||
|
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
|
||||||
|
|
||||||
|
_cohere_api_key = os.environ.get("COHERE_API_KEY")
|
||||||
|
|
||||||
|
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
|
||||||
|
|
||||||
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||||
|
proxy_logging_obj._init_litellm_callbacks()
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
|
||||||
|
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
|
||||||
|
|
||||||
|
# Define a pass-through endpoint
|
||||||
|
pass_through_endpoints = [
|
||||||
|
{
|
||||||
|
"path": "/v1/rerank",
|
||||||
|
"target": "https://api.cohere.com/v1/rerank",
|
||||||
|
"auth": auth,
|
||||||
|
"headers": {"Authorization": f"bearer {_cohere_api_key}"},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Initialize the pass-through endpoint
|
||||||
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
|
||||||
|
_json_data = {
|
||||||
|
"model": "rerank-english-v3.0",
|
||||||
|
"query": "What is the capital of the United States?",
|
||||||
|
"top_n": 3,
|
||||||
|
"documents": [
|
||||||
|
"Carson City is the capital city of the American state of Nevada."
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make a request to the pass-through endpoint
|
||||||
|
response = client.post(
|
||||||
|
"/v1/rerank",
|
||||||
|
json=_json_data,
|
||||||
|
headers={"Authorization": "Bearer {}".format(mock_api_key)},
|
||||||
|
)
|
||||||
|
|
||||||
|
print("JSON response: ", _json_data)
|
||||||
|
|
||||||
|
# Assert the response
|
||||||
|
assert response.status_code == expected_error_code
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pass_through_endpoint_anthropic(client):
|
async def test_pass_through_endpoint_anthropic(client):
|
||||||
import litellm
|
import litellm
|
||||||
|
|
|
@ -66,6 +66,7 @@ class testLogger(CustomLogger):
|
||||||
"image_generation",
|
"image_generation",
|
||||||
"moderation",
|
"moderation",
|
||||||
"audio_transcription",
|
"audio_transcription",
|
||||||
|
"pass_through_endpoint",
|
||||||
],
|
],
|
||||||
):
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue