mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(user_api_key_auth.py): support calling langfuse with litellm user_api_key_auth
This commit is contained in:
parent
66d77f177f
commit
742e3cbccf
4 changed files with 160 additions and 7 deletions
|
@ -3,5 +3,11 @@ model_list:
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: "*"
|
model: "*"
|
||||||
|
|
||||||
litellm_settings:
|
general_settings:
|
||||||
success_callback: ["langsmith"]
|
master_key: sk-1234
|
||||||
|
pass_through_endpoints:
|
||||||
|
- path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server
|
||||||
|
target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward requests to
|
||||||
|
headers:
|
||||||
|
LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_PUBLIC_KEY" # your langfuse account public key
|
||||||
|
LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_SECRET_KEY" # your langfuse account secret key
|
|
@ -86,7 +86,7 @@ def _get_bearer_token(
|
||||||
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
if api_key.startswith("Bearer "): # ensure Bearer token passed in
|
||||||
api_key = api_key.replace("Bearer ", "") # extract the token
|
api_key = api_key.replace("Bearer ", "") # extract the token
|
||||||
else:
|
else:
|
||||||
api_key = ""
|
api_key = api_key
|
||||||
return api_key
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,7 +138,6 @@ async def user_api_key_auth(
|
||||||
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
|
||||||
"pass_through_endpoints", None
|
"pass_through_endpoints", None
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
passed_in_key = api_key
|
passed_in_key = api_key
|
||||||
api_key = _get_bearer_token(api_key=api_key)
|
api_key = _get_bearer_token(api_key=api_key)
|
||||||
|
@ -367,6 +366,40 @@ async def user_api_key_auth(
|
||||||
parent_otel_span=parent_otel_span,
|
parent_otel_span=parent_otel_span,
|
||||||
)
|
)
|
||||||
#### ELSE ####
|
#### ELSE ####
|
||||||
|
|
||||||
|
## CHECK PASS-THROUGH ENDPOINTS ##
|
||||||
|
if pass_through_endpoints is not None:
|
||||||
|
for endpoint in pass_through_endpoints:
|
||||||
|
if endpoint.get("path", "") == route:
|
||||||
|
## IF AUTH DISABLED
|
||||||
|
if endpoint.get("auth") is not True:
|
||||||
|
return UserAPIKeyAuth()
|
||||||
|
## IF AUTH ENABLED
|
||||||
|
### IF CUSTOM PARSER REQUIRED
|
||||||
|
if (
|
||||||
|
endpoint.get("custom_auth_parser") is not None
|
||||||
|
and endpoint.get("custom_auth_parser") == "langfuse"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
- langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'}
|
||||||
|
- check whether the langfuse public key contains the litellm api key
|
||||||
|
"""
|
||||||
|
import base64
|
||||||
|
|
||||||
|
api_key = api_key.replace("Basic ", "").strip()
|
||||||
|
decoded_bytes = base64.b64decode(api_key)
|
||||||
|
decoded_str = decoded_bytes.decode("utf-8")
|
||||||
|
api_key = decoded_str.split(":")[0]
|
||||||
|
else:
|
||||||
|
headers = endpoint.get("headers", None)
|
||||||
|
if headers is not None:
|
||||||
|
header_key = headers.get("litellm_user_api_key", "")
|
||||||
|
if (
|
||||||
|
isinstance(request.headers, dict)
|
||||||
|
and request.headers.get(key=header_key) is not None
|
||||||
|
):
|
||||||
|
api_key = request.headers.get(key=header_key)
|
||||||
|
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
if isinstance(api_key, str):
|
if isinstance(api_key, str):
|
||||||
return UserAPIKeyAuth(
|
return UserAPIKeyAuth(
|
||||||
|
@ -533,7 +566,11 @@ async def user_api_key_auth(
|
||||||
if isinstance(
|
if isinstance(
|
||||||
api_key, str
|
api_key, str
|
||||||
): # if generated token, make sure it starts with sk-.
|
): # if generated token, make sure it starts with sk-.
|
||||||
assert api_key.startswith("sk-") # prevent token hashes from being used
|
assert api_key.startswith(
|
||||||
|
"sk-"
|
||||||
|
), "LiteLLM Virtual Key expected. Received={}, expected to start with 'sk-'.".format(
|
||||||
|
api_key
|
||||||
|
) # prevent token hashes from being used
|
||||||
else:
|
else:
|
||||||
verbose_logger.warning(
|
verbose_logger.warning(
|
||||||
"litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key={} is not a string.".format(
|
"litellm.proxy.proxy_server.user_api_key_auth(): Warning - Key={} is not a string.".format(
|
||||||
|
|
|
@ -309,7 +309,7 @@ async def pass_through_request(
|
||||||
json=_parsed_body,
|
json=_parsed_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code >= 300:
|
||||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||||
|
|
||||||
content = await response.aread()
|
content = await response.aread()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
@ -30,6 +31,7 @@ def client():
|
||||||
async def test_pass_through_endpoint(client, monkeypatch):
|
async def test_pass_through_endpoint(client, monkeypatch):
|
||||||
# Mock the httpx.AsyncClient.request method
|
# Mock the httpx.AsyncClient.request method
|
||||||
monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
|
monkeypatch.setattr("httpx.AsyncClient.request", mock_request)
|
||||||
|
import litellm
|
||||||
|
|
||||||
# Define a pass-through endpoint
|
# Define a pass-through endpoint
|
||||||
pass_through_endpoints = [
|
pass_through_endpoints = [
|
||||||
|
@ -42,6 +44,11 @@ async def test_pass_through_endpoint(client, monkeypatch):
|
||||||
|
|
||||||
# Initialize the pass-through endpoint
|
# Initialize the pass-through endpoint
|
||||||
await initialize_pass_through_endpoints(pass_through_endpoints)
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
general_settings: Optional[dict] = (
|
||||||
|
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
|
||||||
|
)
|
||||||
|
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
|
||||||
# Make a request to the pass-through endpoint
|
# Make a request to the pass-through endpoint
|
||||||
response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
|
response = client.post("/test-endpoint", json={"prompt": "Hello, world!"})
|
||||||
|
@ -54,6 +61,7 @@ async def test_pass_through_endpoint(client, monkeypatch):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pass_through_endpoint_rerank(client):
|
async def test_pass_through_endpoint_rerank(client):
|
||||||
_cohere_api_key = os.environ.get("COHERE_API_KEY")
|
_cohere_api_key = os.environ.get("COHERE_API_KEY")
|
||||||
|
import litellm
|
||||||
|
|
||||||
# Define a pass-through endpoint
|
# Define a pass-through endpoint
|
||||||
pass_through_endpoints = [
|
pass_through_endpoints = [
|
||||||
|
@ -66,6 +74,11 @@ async def test_pass_through_endpoint_rerank(client):
|
||||||
|
|
||||||
# Initialize the pass-through endpoint
|
# Initialize the pass-through endpoint
|
||||||
await initialize_pass_through_endpoints(pass_through_endpoints)
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
general_settings: Optional[dict] = (
|
||||||
|
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
|
||||||
|
)
|
||||||
|
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
|
||||||
_json_data = {
|
_json_data = {
|
||||||
"model": "rerank-english-v3.0",
|
"model": "rerank-english-v3.0",
|
||||||
|
@ -87,7 +100,7 @@ async def test_pass_through_endpoint_rerank(client):
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"auth, rpm_limit, expected_error_code",
|
"auth, rpm_limit, expected_error_code",
|
||||||
[(True, 0, 429), (True, 1, 200), (False, 0, 401)],
|
[(True, 0, 429), (True, 1, 200), (False, 0, 200)],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_limit):
|
async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_limit):
|
||||||
|
@ -123,6 +136,11 @@ async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_li
|
||||||
|
|
||||||
# Initialize the pass-through endpoint
|
# Initialize the pass-through endpoint
|
||||||
await initialize_pass_through_endpoints(pass_through_endpoints)
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
general_settings: Optional[dict] = (
|
||||||
|
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
|
||||||
|
)
|
||||||
|
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
|
||||||
_json_data = {
|
_json_data = {
|
||||||
"model": "rerank-english-v3.0",
|
"model": "rerank-english-v3.0",
|
||||||
|
@ -146,6 +164,93 @@ async def test_pass_through_endpoint_rpm_limit(auth, expected_error_code, rpm_li
|
||||||
assert response.status_code == expected_error_code
|
assert response.status_code == expected_error_code
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"auth, rpm_limit, expected_error_code",
|
||||||
|
[(True, 0, 429), (True, 1, 207), (False, 0, 207)],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pass_through_endpoint_pass_through_keys_langfuse(
|
||||||
|
auth, expected_error_code, rpm_limit
|
||||||
|
):
|
||||||
|
client = TestClient(app)
|
||||||
|
import litellm
|
||||||
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.proxy.proxy_server import ProxyLogging, hash_token, user_api_key_cache
|
||||||
|
|
||||||
|
mock_api_key = "sk-my-test-key"
|
||||||
|
cache_value = UserAPIKeyAuth(token=hash_token(mock_api_key), rpm_limit=rpm_limit)
|
||||||
|
|
||||||
|
_cohere_api_key = os.environ.get("COHERE_API_KEY")
|
||||||
|
|
||||||
|
user_api_key_cache.set_cache(key=hash_token(mock_api_key), value=cache_value)
|
||||||
|
|
||||||
|
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||||
|
proxy_logging_obj._init_litellm_callbacks()
|
||||||
|
|
||||||
|
setattr(litellm.proxy.proxy_server, "user_api_key_cache", user_api_key_cache)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", "FAKE-VAR")
|
||||||
|
setattr(litellm.proxy.proxy_server, "proxy_logging_obj", proxy_logging_obj)
|
||||||
|
|
||||||
|
# Define a pass-through endpoint
|
||||||
|
pass_through_endpoints = [
|
||||||
|
{
|
||||||
|
"path": "/api/public/ingestion",
|
||||||
|
"target": "https://us.cloud.langfuse.com/api/public/ingestion",
|
||||||
|
"auth": auth,
|
||||||
|
"custom_auth_parser": "langfuse",
|
||||||
|
"headers": {
|
||||||
|
"LANGFUSE_PUBLIC_KEY": "os.environ/LANGFUSE_PUBLIC_KEY",
|
||||||
|
"LANGFUSE_SECRET_KEY": "os.environ/LANGFUSE_SECRET_KEY",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Initialize the pass-through endpoint
|
||||||
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
general_settings: Optional[dict] = (
|
||||||
|
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
|
||||||
|
)
|
||||||
|
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
|
||||||
|
_json_data = {
|
||||||
|
"batch": [
|
||||||
|
{
|
||||||
|
"id": "80e2141f-0ca6-47b7-9c06-dde5e97de690",
|
||||||
|
"type": "trace-create",
|
||||||
|
"body": {
|
||||||
|
"id": "0687af7b-4a75-4de8-a4f6-cba1cdc00865",
|
||||||
|
"timestamp": "2024-08-14T02:38:56.092950Z",
|
||||||
|
"name": "test-trace-litellm-proxy-passthrough",
|
||||||
|
},
|
||||||
|
"timestamp": "2024-08-14T02:38:56.093352Z",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"batch_size": 1,
|
||||||
|
"sdk_integration": "default",
|
||||||
|
"sdk_name": "python",
|
||||||
|
"sdk_version": "2.27.0",
|
||||||
|
"public_key": "anything",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Make a request to the pass-through endpoint
|
||||||
|
response = client.post(
|
||||||
|
"/api/public/ingestion",
|
||||||
|
json=_json_data,
|
||||||
|
headers={"Authorization": "Basic c2stbXktdGVzdC1rZXk6YW55dGhpbmc="},
|
||||||
|
)
|
||||||
|
|
||||||
|
print("JSON response: ", _json_data)
|
||||||
|
|
||||||
|
print("RESPONSE RECEIVED - {}".format(response.text))
|
||||||
|
|
||||||
|
# Assert the response
|
||||||
|
assert response.status_code == expected_error_code
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pass_through_endpoint_anthropic(client):
|
async def test_pass_through_endpoint_anthropic(client):
|
||||||
import litellm
|
import litellm
|
||||||
|
@ -178,6 +283,11 @@ async def test_pass_through_endpoint_anthropic(client):
|
||||||
|
|
||||||
# Initialize the pass-through endpoint
|
# Initialize the pass-through endpoint
|
||||||
await initialize_pass_through_endpoints(pass_through_endpoints)
|
await initialize_pass_through_endpoints(pass_through_endpoints)
|
||||||
|
general_settings: Optional[dict] = (
|
||||||
|
getattr(litellm.proxy.proxy_server, "general_settings", {}) or {}
|
||||||
|
)
|
||||||
|
general_settings.update({"pass_through_endpoints": pass_through_endpoints})
|
||||||
|
setattr(litellm.proxy.proxy_server, "general_settings", general_settings)
|
||||||
|
|
||||||
_json_data = {
|
_json_data = {
|
||||||
"model": "gpt-3.5-turbo",
|
"model": "gpt-3.5-turbo",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue