diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 6c08758dd..23ee97a47 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -458,7 +458,7 @@ class AmazonConverseConfig: """ Abbreviations of regions AWS Bedrock supports for cross region inference """ - return ["us", "eu"] + return ["us", "eu", "apac"] def _get_base_model(self, model: str) -> str: """ diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 86ece3788..03d66351d 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -11,7 +11,11 @@ model_list: model: vertex_ai/claude-3-5-sonnet-v2 vertex_ai_project: "adroit-crow-413218" vertex_ai_location: "us-east5" - + - model_name: openai-gpt-4o-realtime-audio + litellm_params: + model: openai/gpt-4o-realtime-preview-2024-10-01 + api_key: os.environ/OPENAI_API_KEY + router_settings: routing_strategy: usage-based-routing-v2 #redis_url: "os.environ/REDIS_URL" diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 32f0c95db..c292a7dc3 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -28,6 +28,8 @@ from fastapi import ( Request, Response, UploadFile, + WebSocket, + WebSocketDisconnect, status, ) from fastapi.middleware.cors import CORSMiddleware @@ -195,6 +197,52 @@ def _is_allowed_route( ) +async def user_api_key_auth_websocket(websocket: WebSocket): + # Accept the WebSocket connection + + request = Request(scope={"type": "http"}) + request._url = websocket.url + + query_params = websocket.query_params + + model = query_params.get("model") + + async def return_body(): + return_string = f'{{"model": "{model}"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body # type: ignore + + # Extract the 
Authorization header + authorization = websocket.headers.get("authorization") + + # If no Authorization header, try the api-key header + if not authorization: + api_key = websocket.headers.get("api-key") + if not api_key: + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail="No API key provided") + else: + # Extract the API key from the Bearer token + if not authorization.startswith("Bearer "): + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException( + status_code=403, detail="Invalid Authorization header format" + ) + + api_key = authorization[len("Bearer ") :].strip() + + # Call user_api_key_auth with the extracted API key + # Note: You'll need to modify this to work with WebSocket context if needed + try: + return await user_api_key_auth(request=request, api_key=f"Bearer {api_key}") + except Exception as e: + verbose_proxy_logger.exception(e) + await websocket.close(code=status.WS_1008_POLICY_VIOLATION) + raise HTTPException(status_code=403, detail=str(e)) + + async def user_api_key_auth( # noqa: PLR0915 request: Request, api_key: str = fastapi.Security(api_key_header), diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 0a8dd86eb..deb259895 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -1,6 +1,6 @@ import ast import json -from typing import List, Optional +from typing import Dict, List, Optional from fastapi import Request, UploadFile, status @@ -8,31 +8,43 @@ from litellm._logging import verbose_proxy_logger from litellm.types.router import Deployment -async def _read_request_body(request: Optional[Request]) -> dict: +async def _read_request_body(request: Optional[Request]) -> Dict: """ - Asynchronous function to read the request body and parse it as JSON or literal data. + Safely read the request body and parse it as JSON. 
Parameters: - request: The request object to read the body from Returns: - - dict: Parsed request data as a dictionary + - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails """ try: - request_data: dict = {} if request is None: - return request_data + return {} + + # Read the request body body = await request.body() - if body == b"" or body is None: - return request_data + # Return empty dict if body is empty or None + if not body: + return {} + + # Decode the body to a string body_str = body.decode() - try: - request_data = ast.literal_eval(body_str) - except Exception: - request_data = json.loads(body_str) - return request_data - except Exception: + + # Attempt JSON parsing (safe for untrusted input) + return json.loads(body_str) + + except json.JSONDecodeError: + # Log detailed information for debugging + verbose_proxy_logger.exception("Invalid JSON payload received.") + return {} + + except Exception as e: + # Catch unexpected errors to avoid crashes + verbose_proxy_logger.exception( + "Unexpected error reading request body - {}".format(e) + ) return {} diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index afb83aa37..3f0425809 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -134,7 +134,10 @@ from litellm.proxy.auth.model_checks import ( get_key_models, get_team_models, ) -from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.auth.user_api_key_auth import ( + user_api_key_auth, + user_api_key_auth_websocket, +) ## Import All Misc routes here ## from litellm.proxy.caching_routes import router as caching_router @@ -4339,7 +4342,11 @@ from litellm import _arealtime @app.websocket("/v1/realtime") -async def websocket_endpoint(websocket: WebSocket, model: str): +async def websocket_endpoint( + websocket: WebSocket, + model: str, + user_api_key_dict=Depends(user_api_key_auth_websocket), +): import websockets await websocket.accept() diff --git 
a/litellm/tests/test_mlflow.py b/litellm/tests/test_mlflow.py deleted file mode 100644 index ec23875ea..000000000 --- a/litellm/tests/test_mlflow.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - -import litellm - - -def test_mlflow_logging(): - litellm.success_callback = ["mlflow"] - litellm.failure_callback = ["mlflow"] - - litellm.completion( - model="gpt-4o-mini", - messages=[{"role": "user", "content": "what llm are u"}], - max_tokens=10, - temperature=0.2, - user="test-user", - ) - -@pytest.mark.asyncio() -async def test_async_mlflow_logging(): - litellm.success_callback = ["mlflow"] - litellm.failure_callback = ["mlflow"] - - await litellm.acompletion( - model="gpt-4o-mini", - messages=[{"role": "user", "content": "hi test from local arize"}], - mock_response="hello", - temperature=0.1, - user="OTEL_USER", - ) diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 35a9fc276..e1bd7a9ab 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -1243,6 +1243,19 @@ def test_bedrock_cross_region_inference(model): ) +@pytest.mark.parametrize( + "model, expected_base_model", + [ + ( + "apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + "anthropic.claude-3-5-sonnet-20240620-v1:0", + ), + ], +) +def test_bedrock_get_base_model(model, expected_base_model): + assert litellm.AmazonConverseConfig()._get_base_model(model) == expected_base_model + + from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt diff --git a/tests/local_testing/test_http_parsing_utils.py b/tests/local_testing/test_http_parsing_utils.py new file mode 100644 index 000000000..2c6956c79 --- /dev/null +++ b/tests/local_testing/test_http_parsing_utils.py @@ -0,0 +1,79 @@ +import pytest +from fastapi import Request +from fastapi.testclient import TestClient +from starlette.datastructures import Headers +from starlette.requests import HTTPConnection +import 
os +import sys + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body + + +@pytest.mark.asyncio +async def test_read_request_body_valid_json(): + """Test the function with a valid JSON payload.""" + + class MockRequest: + async def body(self): + return b'{"key": "value"}' + + request = MockRequest() + result = await _read_request_body(request) + assert result == {"key": "value"} + + +@pytest.mark.asyncio +async def test_read_request_body_empty_body(): + """Test the function with an empty body.""" + + class MockRequest: + async def body(self): + return b"" + + request = MockRequest() + result = await _read_request_body(request) + assert result == {} + + +@pytest.mark.asyncio +async def test_read_request_body_invalid_json(): + """Test the function with an invalid JSON payload.""" + + class MockRequest: + async def body(self): + return b'{"key": value}' # Missing quotes around `value` + + request = MockRequest() + result = await _read_request_body(request) + assert result == {} # Should return an empty dict on failure + + +@pytest.mark.asyncio +async def test_read_request_body_large_payload(): + """Test the function with a very large payload.""" + large_payload = '{"key":' + '"a"' * 10**6 + "}" # Large payload + + class MockRequest: + async def body(self): + return large_payload.encode() + + request = MockRequest() + result = await _read_request_body(request) + assert result == {} # Large payloads could trigger errors, so validate behavior + + +@pytest.mark.asyncio +async def test_read_request_body_unexpected_error(): + """Test the function when an unexpected error occurs.""" + + class MockRequest: + async def body(self): + raise ValueError("Unexpected error") + + request = MockRequest() + result = await _read_request_body(request) + assert result == {} # Ensure fallback behavior diff --git a/tests/local_testing/test_router_init.py 
b/tests/local_testing/test_router_init.py index 3733af252..9b4e12f12 100644 --- a/tests/local_testing/test_router_init.py +++ b/tests/local_testing/test_router_init.py @@ -536,7 +536,7 @@ def test_init_clients_azure_command_r_plus(): @pytest.mark.asyncio -async def test_text_completion_with_organization(): +async def test_aaaaatext_completion_with_organization(): try: print("Testing Text OpenAI with organization") model_list = [ diff --git a/tests/local_testing/test_user_api_key_auth.py b/tests/local_testing/test_user_api_key_auth.py index 1a129489c..167809da1 100644 --- a/tests/local_testing/test_user_api_key_auth.py +++ b/tests/local_testing/test_user_api_key_auth.py @@ -415,3 +415,19 @@ def test_allowed_route_inside_route( ) == expected_result ) + + +@pytest.mark.asyncio +async def test_read_request_body(): + from litellm.proxy.common_utils.http_parsing_utils import _read_request_body + from fastapi import Request + + payload = "()" * 1000000 + request = Request(scope={"type": "http"}) + + async def return_body(): + return payload.encode() + + request.body = return_body + result = await _read_request_body(request) + assert result is not None diff --git a/tests/router_unit_tests/test_router_endpoints.py b/tests/router_unit_tests/test_router_endpoints.py index 4c9fc8f35..19949ddba 100644 --- a/tests/router_unit_tests/test_router_endpoints.py +++ b/tests/router_unit_tests/test_router_endpoints.py @@ -215,7 +215,7 @@ async def test_rerank_endpoint(model_list): @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio -async def test_text_completion_endpoint(model_list, sync_mode): +async def test_aaaaatext_completion_endpoint(model_list, sync_mode): router = Router(model_list=model_list) if sync_mode: diff --git a/ui/litellm-dashboard/src/components/view_key_table.tsx b/ui/litellm-dashboard/src/components/view_key_table.tsx index b657ed47c..3ef50bc60 100644 --- a/ui/litellm-dashboard/src/components/view_key_table.tsx +++ b/ui/litellm-dashboard/src/components/view_key_table.tsx @@ -24,6 +24,7 
@@ import { Icon, BarChart, TextInput, + Textarea, } from "@tremor/react"; import { Select as Select3, SelectItem, MultiSelect, MultiSelectItem } from "@tremor/react"; import { @@ -40,6 +41,7 @@ import { } from "antd"; import { CopyToClipboard } from "react-copy-to-clipboard"; +import TextArea from "antd/es/input/TextArea"; const { Option } = Select; const isLocal = process.env.NODE_ENV === "development"; @@ -438,6 +440,16 @@ const ViewKeyTable: React.FC = ({ > + +