forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (11/27/2024) (#6943)
* fix(http_parsing_utils.py): remove `ast.literal_eval()` from http utils
  Security fix - https://huntr.com/bounties/96a32812-213c-4819-ba4e-36143d35e95b?token=bf414bbd77f8b346556e64ab2dd9301ea44339910877ea50401c76f977e36cdd78272f5fb4ca852a88a7e832828aae1192df98680544ee24aa98f3cf6980d8bab641a66b7ccbc02c0e7d4ddba2db4dbe7318889dc0098d8db2d639f345f574159814627bb084563bad472e2f990f825bff0878a9e281e72c88b4bc5884d637d186c0d67c9987c57c3f0caf395aff07b89ad2b7220d1dd7d1b427fd2260b5f01090efce5250f8b56ea2c0ec19916c24b23825d85ce119911275944c840a1340d69e23ca6a462da610
* fix(converse/transformation.py): support bedrock apac cross region inference
  Fixes https://github.com/BerriAI/litellm/issues/6905
* fix(user_api_key_auth.py): add auth check for websocket endpoint
  Fixes https://github.com/BerriAI/litellm/issues/6926
* fix(user_api_key_auth.py): use `model` from query param
* fix: fix linting error
* test: run flaky tests first
This commit is contained in:
parent
2d2931a215
commit
21156ff5d0
12 changed files with 210 additions and 49 deletions
|
@ -458,7 +458,7 @@ class AmazonConverseConfig:
|
||||||
"""
|
"""
|
||||||
Abbreviations of regions AWS Bedrock supports for cross region inference
|
Abbreviations of regions AWS Bedrock supports for cross region inference
|
||||||
"""
|
"""
|
||||||
return ["us", "eu"]
|
return ["us", "eu", "apac"]
|
||||||
|
|
||||||
def _get_base_model(self, model: str) -> str:
|
def _get_base_model(self, model: str) -> str:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -11,7 +11,11 @@ model_list:
|
||||||
model: vertex_ai/claude-3-5-sonnet-v2
|
model: vertex_ai/claude-3-5-sonnet-v2
|
||||||
vertex_ai_project: "adroit-crow-413218"
|
vertex_ai_project: "adroit-crow-413218"
|
||||||
vertex_ai_location: "us-east5"
|
vertex_ai_location: "us-east5"
|
||||||
|
- model_name: openai-gpt-4o-realtime-audio
|
||||||
|
litellm_params:
|
||||||
|
model: openai/gpt-4o-realtime-preview-2024-10-01
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
router_settings:
|
router_settings:
|
||||||
routing_strategy: usage-based-routing-v2
|
routing_strategy: usage-based-routing-v2
|
||||||
#redis_url: "os.environ/REDIS_URL"
|
#redis_url: "os.environ/REDIS_URL"
|
||||||
|
|
|
@ -28,6 +28,8 @@ from fastapi import (
|
||||||
Request,
|
Request,
|
||||||
Response,
|
Response,
|
||||||
UploadFile,
|
UploadFile,
|
||||||
|
WebSocket,
|
||||||
|
WebSocketDisconnect,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
@ -195,6 +197,52 @@ def _is_allowed_route(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def user_api_key_auth_websocket(websocket: WebSocket):
    """
    Authenticate a WebSocket connection by reusing the HTTP API key auth flow.

    Builds a stand-in HTTP `Request` that mirrors the WebSocket URL and carries a
    JSON body holding the `model` query parameter, then delegates to
    `user_api_key_auth`. The WebSocket is NOT accepted here.

    Parameters:
    - websocket: The incoming WebSocket connection to authenticate

    Returns:
    - The result of `user_api_key_auth` for the extracted API key

    Raises:
    - HTTPException (403): if no API key is provided, the Authorization header is
      not a Bearer token, or `user_api_key_auth` raises. In each case the socket
      is first closed with WS_1008_POLICY_VIOLATION.
    """
    import json

    # Stand-in HTTP request so the existing auth logic can be reused as-is
    request = Request(scope={"type": "http"})
    request._url = websocket.url

    model = websocket.query_params.get("model")

    async def return_body():
        # Serialize with json.dumps: `model` is user-controlled, and formatting it
        # into a JSON string by hand lets quotes/backslashes break the payload or
        # inject additional keys into the body the auth layer parses.
        return json.dumps({"model": model}).encode()

    request.body = return_body  # type: ignore

    # Extract the Authorization header
    authorization = websocket.headers.get("authorization")

    if not authorization:
        # If no Authorization header, fall back to the api-key header
        api_key = websocket.headers.get("api-key")
        if not api_key:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            raise HTTPException(status_code=403, detail="No API key provided")
    else:
        # Extract the API key from the Bearer token
        if not authorization.startswith("Bearer "):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            raise HTTPException(
                status_code=403, detail="Invalid Authorization header format"
            )

        api_key = authorization[len("Bearer ") :].strip()

    # user_api_key_auth is handed the key in "Bearer <key>" form
    try:
        return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}")
    except Exception as e:
        verbose_proxy_logger.exception(e)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        raise HTTPException(status_code=403, detail=str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
async def user_api_key_auth( # noqa: PLR0915
|
async def user_api_key_auth( # noqa: PLR0915
|
||||||
request: Request,
|
request: Request,
|
||||||
api_key: str = fastapi.Security(api_key_header),
|
api_key: str = fastapi.Security(api_key_header),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import ast
|
import ast
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import Request, UploadFile, status
|
from fastapi import Request, UploadFile, status
|
||||||
|
|
||||||
|
@ -8,31 +8,43 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.types.router import Deployment
|
from litellm.types.router import Deployment
|
||||||
|
|
||||||
|
|
||||||
async def _read_request_body(request: Optional[Request]) -> dict:
|
async def _read_request_body(request: Optional[Request]) -> Dict:
|
||||||
"""
|
"""
|
||||||
Asynchronous function to read the request body and parse it as JSON or literal data.
|
Safely read the request body and parse it as JSON.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- request: The request object to read the body from
|
- request: The request object to read the body from
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- dict: Parsed request data as a dictionary
|
- dict: Parsed request data as a dictionary or an empty dictionary if parsing fails
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
request_data: dict = {}
|
|
||||||
if request is None:
|
if request is None:
|
||||||
return request_data
|
return {}
|
||||||
|
|
||||||
|
# Read the request body
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
|
||||||
if body == b"" or body is None:
|
# Return empty dict if body is empty or None
|
||||||
return request_data
|
if not body:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Decode the body to a string
|
||||||
body_str = body.decode()
|
body_str = body.decode()
|
||||||
try:
|
|
||||||
request_data = ast.literal_eval(body_str)
|
# Attempt JSON parsing (safe for untrusted input)
|
||||||
except Exception:
|
return json.loads(body_str)
|
||||||
request_data = json.loads(body_str)
|
|
||||||
return request_data
|
except json.JSONDecodeError:
|
||||||
except Exception:
|
# Log detailed information for debugging
|
||||||
|
verbose_proxy_logger.exception("Invalid JSON payload received.")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Catch unexpected errors to avoid crashes
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
"Unexpected error reading request body - {}".format(e)
|
||||||
|
)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -134,7 +134,10 @@ from litellm.proxy.auth.model_checks import (
|
||||||
get_key_models,
|
get_key_models,
|
||||||
get_team_models,
|
get_team_models,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import (
|
||||||
|
user_api_key_auth,
|
||||||
|
user_api_key_auth_websocket,
|
||||||
|
)
|
||||||
|
|
||||||
## Import All Misc routes here ##
|
## Import All Misc routes here ##
|
||||||
from litellm.proxy.caching_routes import router as caching_router
|
from litellm.proxy.caching_routes import router as caching_router
|
||||||
|
@ -4339,7 +4342,11 @@ from litellm import _arealtime
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/v1/realtime")
|
@app.websocket("/v1/realtime")
|
||||||
async def websocket_endpoint(websocket: WebSocket, model: str):
|
async def websocket_endpoint(
|
||||||
|
websocket: WebSocket,
|
||||||
|
model: str,
|
||||||
|
user_api_key_dict=Depends(user_api_key_auth_websocket),
|
||||||
|
):
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
|
@ -1,29 +0,0 @@
|
||||||
import pytest
|
|
||||||
|
|
||||||
import litellm
|
|
||||||
|
|
||||||
|
|
||||||
def test_mlflow_logging():
    """Run a sync completion with "mlflow" set as success and failure callback."""
    litellm.success_callback = ["mlflow"]
    litellm.failure_callback = ["mlflow"]

    completion_kwargs = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "what llm are u"}],
        "max_tokens": 10,
        "temperature": 0.2,
        "user": "test-user",
    }
    litellm.completion(**completion_kwargs)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio()
async def test_async_mlflow_logging():
    """Run an async, mocked completion with "mlflow" set as success and failure callback."""
    litellm.success_callback = ["mlflow"]
    litellm.failure_callback = ["mlflow"]

    completion_kwargs = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "hi test from local arize"}],
        "mock_response": "hello",
        "temperature": 0.1,
        "user": "OTEL_USER",
    }
    await litellm.acompletion(**completion_kwargs)
|
|
|
@ -1243,6 +1243,19 @@ def test_bedrock_cross_region_inference(model):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "model, expected_base_model",
    [
        (
            "apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
        ),
    ],
)
def test_bedrock_get_base_model(model, expected_base_model):
    """Check that `_get_base_model` maps the given model id to the expected base model."""
    converse_config = litellm.AmazonConverseConfig()
    base_model = converse_config._get_base_model(model)
    assert base_model == expected_base_model
|
||||||
|
|
||||||
|
|
||||||
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
|
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
|
||||||
|
|
||||||
|
|
||||||
|
|
79
tests/local_testing/test_http_parsing_utils.py
Normal file
79
tests/local_testing/test_http_parsing_utils.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import pytest
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from starlette.datastructures import Headers
|
||||||
|
from starlette.requests import HTTPConnection
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
|
||||||
|
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_read_request_body_valid_json():
    """A well-formed JSON body is returned as the parsed dictionary."""
    raw_body = b'{"key": "value"}'

    class MockRequest:
        async def body(self):
            return raw_body

    parsed = await _read_request_body(MockRequest())
    assert parsed == {"key": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_read_request_body_empty_body():
    """An empty body yields an empty dictionary."""
    raw_body = b""

    class MockRequest:
        async def body(self):
            return raw_body

    parsed = await _read_request_body(MockRequest())
    assert parsed == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_read_request_body_invalid_json():
    """A malformed JSON body falls back to an empty dictionary."""
    # `value` is intentionally left unquoted so the payload is not valid JSON
    raw_body = b'{"key": value}'

    class MockRequest:
        async def body(self):
            return raw_body

    parsed = await _read_request_body(MockRequest())
    assert parsed == {}  # Parsing failure must produce an empty dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_read_request_body_large_payload():
    """Test the function with a very large, valid JSON payload."""
    import json

    # Build a genuinely valid ~1 MB document. Repeating the `"a"` token
    # (`{"key":"a""a"...}`) is malformed JSON and would only re-test the
    # invalid-JSON path instead of large-payload handling.
    large_value = "a" * 10**6
    large_payload = json.dumps({"key": large_value})

    class MockRequest:
        async def body(self):
            return large_payload.encode()

    request = MockRequest()
    result = await _read_request_body(request)
    assert result == {"key": large_value}  # Large valid payloads must parse intact
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_read_request_body_unexpected_error():
    """An exception raised while reading the body falls back to an empty dictionary."""

    class MockRequest:
        async def body(self):
            raise ValueError("Unexpected error")

    parsed = await _read_request_body(MockRequest())
    assert parsed == {}  # Errors from body() must not propagate
|
|
@ -536,7 +536,7 @@ def test_init_clients_azure_command_r_plus():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_text_completion_with_organization():
|
async def test_aaaaatext_completion_with_organization():
|
||||||
try:
|
try:
|
||||||
print("Testing Text OpenAI with organization")
|
print("Testing Text OpenAI with organization")
|
||||||
model_list = [
|
model_list = [
|
||||||
|
|
|
@ -415,3 +415,18 @@ def test_allowed_route_inside_route(
|
||||||
)
|
)
|
||||||
== expected_result
|
== expected_result
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_request_body():
    """
    Ensure `_read_request_body` handles a huge non-JSON body without crashing.

    The body is one million repetitions of "()", which is not valid JSON, so the
    parser is expected to fall back to an empty dictionary.
    """
    import asyncio

    from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
    from fastapi import Request

    # Request.body() returns bytes, so the mock must as well
    payload = b"()" * 1000000
    request = Request(scope={"type": "http"})

    async def return_body():
        return payload

    request.body = return_body

    # `_read_request_body` is a coroutine function: it has to be run to
    # completion, otherwise `result` is just an un-awaited coroutine object and
    # any assertion on it is vacuous.
    result = asyncio.run(_read_request_body(request))
    assert result == {}
|
||||||
|
|
|
@ -215,7 +215,7 @@ async def test_rerank_endpoint(model_list):
|
||||||
|
|
||||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_text_completion_endpoint(model_list, sync_mode):
|
async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
|
||||||
router = Router(model_list=model_list)
|
router = Router(model_list=model_list)
|
||||||
|
|
||||||
if sync_mode:
|
if sync_mode:
|
||||||
|
|
|
@ -24,6 +24,7 @@ import {
|
||||||
Icon,
|
Icon,
|
||||||
BarChart,
|
BarChart,
|
||||||
TextInput,
|
TextInput,
|
||||||
|
Textarea,
|
||||||
} from "@tremor/react";
|
} from "@tremor/react";
|
||||||
import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
|
import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
|
||||||
import {
|
import {
|
||||||
|
@ -40,6 +41,7 @@ import {
|
||||||
} from "antd";
|
} from "antd";
|
||||||
|
|
||||||
import { CopyToClipboard } from "react-copy-to-clipboard";
|
import { CopyToClipboard } from "react-copy-to-clipboard";
|
||||||
|
import TextArea from "antd/es/input/TextArea";
|
||||||
|
|
||||||
const { Option } = Select;
|
const { Option } = Select;
|
||||||
const isLocal = process.env.NODE_ENV === "development";
|
const isLocal = process.env.NODE_ENV === "development";
|
||||||
|
@ -438,6 +440,16 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
|
||||||
>
|
>
|
||||||
<InputNumber step={1} precision={1} width={200} />
|
<InputNumber step={1} precision={1} width={200} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
label="Metadata"
|
||||||
|
name="metadata"
|
||||||
|
initialValue={token.metadata}
|
||||||
|
>
|
||||||
|
<TextArea
|
||||||
|
value={String(token.metadata)}
|
||||||
|
rows={10}
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
</>
|
</>
|
||||||
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
<div style={{ textAlign: "right", marginTop: "10px" }}>
|
||||||
<Button2 htmlType="submit">Edit Key</Button2>
|
<Button2 htmlType="submit">Edit Key</Button2>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue