Compare commits

...
Sign in to create a new pull request.

6 commits

Author SHA1 Message Date
Krrish Dholakia
03a5cc364f test: run flaky tests first 2024-11-27 23:32:58 +05:30
Krrish Dholakia
fc30b20c6e fix: fix linting error 2024-11-27 22:32:12 +05:30
Krrish Dholakia
74d59d74d4 fix(user_api_key_auth.py): use model from query param 2024-11-27 17:52:04 +05:30
Krrish Dholakia
9ec6ebaeeb fix(user_api_key_auth.py): add auth check for websocket endpoint
Fixes https://github.com/BerriAI/litellm/issues/6926
2024-11-27 17:45:59 +05:30
Krrish Dholakia
037171b98b fix(converse/transformation.py): support bedrock apac cross region inference
Fixes https://github.com/BerriAI/litellm/issues/6905
2024-11-27 16:18:40 +05:30
Krrish Dholakia
bbf31346ca fix(http_parsing_utils.py): remove ast.literal_eval() from http utils
Security fix - https://huntr.com/bounties/96a32812-213c-4819-ba4e-36143d35e95b?token=bf414bbd77f8b346556e
64ab2dd9301ea44339910877ea50401c76f977e36cdd78272f5fb4ca852a88a7e832828aae1192df98680544ee24aa98f3cf6980d8
bab641a66b7ccbc02c0e7d4ddba2db4dbe7318889dc0098d8db2d639f345f574159814627bb084563bad472e2f990f825bff0878a9
e281e72c88b4bc5884d637d186c0d67c9987c57c3f0caf395aff07b89ad2b7220d1dd7d1b427fd2260b5f01090efce5250f8b56ea2
c0ec19916c24b23825d85ce119911275944c840a1340d69e23ca6a462da610
2024-11-27 13:54:59 +05:30
15 changed files with 210 additions and 52 deletions

View file

@ -458,7 +458,7 @@ class AmazonConverseConfig:
"""
Abbreviations of regions AWS Bedrock supports for cross region inference
"""
return ["us", "eu"]
return ["us", "eu", "apac"]
def _get_base_model(self, model: str) -> str:
"""

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -11,7 +11,11 @@ model_list:
model: vertex_ai/claude-3-5-sonnet-v2
vertex_ai_project: "adroit-crow-413218"
vertex_ai_location: "us-east5"
- model_name: openai-gpt-4o-realtime-audio
litellm_params:
model: openai/gpt-4o-realtime-preview-2024-10-01
api_key: os.environ/OPENAI_API_KEY
router_settings:
routing_strategy: usage-based-routing-v2
#redis_url: "os.environ/REDIS_URL"

View file

@ -28,6 +28,8 @@ from fastapi import (
Request,
Response,
UploadFile,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi.middleware.cors import CORSMiddleware
@ -195,6 +197,52 @@ def _is_allowed_route(
)
async def user_api_key_auth_websocket(websocket: WebSocket):
    """
    Authenticate a WebSocket connection with the same credentials accepted by
    ``user_api_key_auth``.

    Builds a minimal fake HTTP ``Request`` (carrying the ``model`` query param
    as a JSON body) so the existing HTTP auth dependency can be reused.
    Accepts either an ``Authorization: Bearer <key>`` header or an ``api-key``
    header.

    Raises:
        HTTPException: 403, after closing the socket with policy-violation
            code, when credentials are missing/malformed or when
            ``user_api_key_auth`` rejects the key.
    """
    import json  # local import: only needed to serialize the fake body

    # Fake HTTP request object so user_api_key_auth (an HTTP dependency) can run.
    request = Request(scope={"type": "http"})
    request._url = websocket.url

    model = websocket.query_params.get("model")

    async def return_body():
        # Bug fix: the old f-string '{{"model": "{model}"}}' produced invalid
        # JSON (or allowed injection) whenever `model` contained quotes or
        # backslashes. json.dumps always yields a valid JSON document.
        return json.dumps({"model": model}).encode()

    request.body = return_body  # type: ignore

    # Extract the Authorization header; fall back to the `api-key` header.
    authorization = websocket.headers.get("authorization")
    if not authorization:
        api_key = websocket.headers.get("api-key")
        if not api_key:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            raise HTTPException(status_code=403, detail="No API key provided")
    else:
        # Only the Bearer scheme is accepted.
        if not authorization.startswith("Bearer "):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            raise HTTPException(
                status_code=403, detail="Invalid Authorization header format"
            )
        api_key = authorization[len("Bearer ") :].strip()

    # Delegate to the shared HTTP auth routine.
    try:
        return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}")
    except Exception as e:
        verbose_proxy_logger.exception(e)
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        raise HTTPException(status_code=403, detail=str(e))
async def user_api_key_auth( # noqa: PLR0915
request: Request,
api_key: str = fastapi.Security(api_key_header),

View file

@ -1,6 +1,6 @@
import ast
import json
from typing import List, Optional
from typing import Dict, List, Optional
from fastapi import Request, UploadFile, status
@ -8,31 +8,43 @@ from litellm._logging import verbose_proxy_logger
from litellm.types.router import Deployment
async def _read_request_body(request: Optional[Request]) -> Dict:
    """
    Safely read the request body and parse it as JSON.

    Parameters:
    - request: The request object to read the body from

    Returns:
    - Dict: Parsed request data as a dictionary, or an empty dictionary if the
      body is missing, empty, not valid JSON, or reading it fails.
    """
    try:
        if request is None:
            return {}

        # Read the raw request body.
        body = await request.body()
        # Return empty dict if body is empty or None.
        if not body:
            return {}

        body_str = body.decode()
        try:
            # Security fix: json.loads is safe for untrusted input; the old
            # ast.literal_eval() path could be abused (e.g. deeply nested
            # payloads causing excessive work / crashes).
            return json.loads(body_str)
        except json.JSONDecodeError:
            # Log detailed information for debugging.
            verbose_proxy_logger.exception("Invalid JSON payload received.")
            return {}
    except Exception as e:
        # Never let body parsing crash the request path.
        verbose_proxy_logger.exception(
            "Unexpected error reading request body - {}".format(e)
        )
        return {}

View file

@ -134,7 +134,10 @@ from litellm.proxy.auth.model_checks import (
get_key_models,
get_team_models,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.auth.user_api_key_auth import (
user_api_key_auth,
user_api_key_auth_websocket,
)
## Import All Misc routes here ##
from litellm.proxy.caching_routes import router as caching_router
@ -4394,7 +4397,11 @@ from litellm import _arealtime
@app.websocket("/v1/realtime")
async def websocket_endpoint(websocket: WebSocket, model: str):
async def websocket_endpoint(
websocket: WebSocket,
model: str,
user_api_key_dict=Depends(user_api_key_auth_websocket),
):
import websockets
await websocket.accept()

View file

@ -1,29 +0,0 @@
import pytest
import litellm
def test_mlflow_logging():
    """Smoke test: a synchronous completion call with mlflow callbacks enabled."""
    for callback_attr in ("success_callback", "failure_callback"):
        setattr(litellm, callback_attr, ["mlflow"])

    completion_kwargs = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "what llm are u"}],
        "max_tokens": 10,
        "temperature": 0.2,
        "user": "test-user",
    }
    litellm.completion(**completion_kwargs)
@pytest.mark.asyncio()
async def test_async_mlflow_logging():
    """Smoke test: an async completion call with mlflow callbacks enabled."""
    for callback_attr in ("success_callback", "failure_callback"):
        setattr(litellm, callback_attr, ["mlflow"])

    acompletion_kwargs = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "hi test from local arize"}],
        "mock_response": "hello",
        "temperature": 0.1,
        "user": "OTEL_USER",
    }
    await litellm.acompletion(**acompletion_kwargs)

View file

@ -1243,6 +1243,19 @@ def test_bedrock_cross_region_inference(model):
)
@pytest.mark.parametrize(
    "model, expected_base_model",
    [
        (
            "apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
            "anthropic.claude-3-5-sonnet-20240620-v1:0",
        ),
    ],
)
def test_bedrock_get_base_model(model, expected_base_model):
    """The APAC cross-region prefix is stripped to yield the base Bedrock model id."""
    config = litellm.AmazonConverseConfig()
    base_model = config._get_base_model(model)
    assert base_model == expected_base_model
from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt

View file

@ -0,0 +1,79 @@
import pytest
from fastapi import Request
from fastapi.testclient import TestClient
from starlette.datastructures import Headers
from starlette.requests import HTTPConnection
import os
import sys
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
@pytest.mark.asyncio
async def test_read_request_body_valid_json():
    """Test the function with a valid JSON payload."""

    class MockRequest:
        async def body(self):
            return b'{"key": "value"}'

    parsed = await _read_request_body(MockRequest())
    assert parsed == {"key": "value"}
@pytest.mark.asyncio
async def test_read_request_body_empty_body():
    """Test the function with an empty body."""

    class MockRequest:
        async def body(self):
            return b""

    parsed = await _read_request_body(MockRequest())
    assert parsed == {}
@pytest.mark.asyncio
async def test_read_request_body_invalid_json():
    """Test the function with an invalid JSON payload."""

    class MockRequest:
        async def body(self):
            # Missing quotes around `value` -> not valid JSON.
            return b'{"key": value}'

    parsed = await _read_request_body(MockRequest())
    # Parsing failure must yield an empty dict, not an exception.
    assert parsed == {}
@pytest.mark.asyncio
async def test_read_request_body_large_payload():
    """Test the function with a very large payload."""
    oversized_payload = '{"key":' + '"a"' * 10**6 + "}"  # large (and invalid) JSON

    class MockRequest:
        async def body(self):
            return oversized_payload.encode()

    parsed = await _read_request_body(MockRequest())
    # Large payloads could trigger errors, so validate the fallback behavior.
    assert parsed == {}
@pytest.mark.asyncio
async def test_read_request_body_unexpected_error():
    """Test the function when an unexpected error occurs."""

    class MockRequest:
        async def body(self):
            raise ValueError("Unexpected error")

    parsed = await _read_request_body(MockRequest())
    # Ensure the fallback behavior: errors are swallowed into an empty dict.
    assert parsed == {}

View file

@ -536,7 +536,7 @@ def test_init_clients_azure_command_r_plus():
@pytest.mark.asyncio
async def test_text_completion_with_organization():
async def test_aaaaatext_completion_with_organization():
try:
print("Testing Text OpenAI with organization")
model_list = [

View file

@ -415,3 +415,18 @@ def test_allowed_route_inside_route(
)
== expected_result
)
def test_read_request_body():
    """
    Regression test: a huge non-JSON payload must not hang or crash the parser.

    Bug fix: the original called the async ``_read_request_body()`` without
    awaiting it, so ``result`` was a coroutine object and the assertion passed
    vacuously. Drive the coroutine to completion with ``asyncio.run()``.
    """
    import asyncio

    from fastapi import Request

    from litellm.proxy.common_utils.http_parsing_utils import _read_request_body

    payload = "()" * 1000000

    request = Request(scope={"type": "http"})

    async def return_body():
        return payload

    request.body = return_body

    result = asyncio.run(_read_request_body(request))
    assert result is not None

View file

@ -215,7 +215,7 @@ async def test_rerank_endpoint(model_list):
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_text_completion_endpoint(model_list, sync_mode):
async def test_aaaaatext_completion_endpoint(model_list, sync_mode):
router = Router(model_list=model_list)
if sync_mode:

View file

@ -24,6 +24,7 @@ import {
Icon,
BarChart,
TextInput,
Textarea,
} from "@tremor/react";
import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react";
import {
@ -40,6 +41,7 @@ import {
} from "antd";
import { CopyToClipboard } from "react-copy-to-clipboard";
import TextArea from "antd/es/input/TextArea";
const { Option } = Select;
const isLocal = process.env.NODE_ENV === "development";
@ -438,6 +440,16 @@ const ViewKeyTable: React.FC<ViewKeyTableProps> = ({
>
<InputNumber step={1} precision={1} width={200} />
</Form.Item>
<Form.Item
label="Metadata"
name="metadata"
initialValue={token.metadata}
>
<TextArea
value={String(token.metadata)}
rows={10}
/>
</Form.Item>
</>
<div style={{ textAlign: "right", marginTop: "10px" }}>
<Button2 htmlType="submit">Edit Key</Button2>