diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 60050fbeb2..2ea3c18ea8 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -108,7 +108,13 @@ class ProxyBaseLLMRequestProcessing: user_api_key_dict: UserAPIKeyAuth, proxy_logging_obj: ProxyLogging, proxy_config: ProxyConfig, - route_type: Literal["acompletion", "aresponses", "_arealtime"], + route_type: Literal[ + "acompletion", + "aresponses", + "_arealtime", + "aget_responses", + "adelete_responses", + ], version: Optional[str] = None, user_model: Optional[str] = None, user_temperature: Optional[float] = None, @@ -178,7 +184,13 @@ class ProxyBaseLLMRequestProcessing: request: Request, fastapi_response: Response, user_api_key_dict: UserAPIKeyAuth, - route_type: Literal["acompletion", "aresponses", "_arealtime"], + route_type: Literal[ + "acompletion", + "aresponses", + "_arealtime", + "aget_responses", + "adelete_responses", + ], proxy_logging_obj: ProxyLogging, general_settings: dict, proxy_config: ProxyConfig, diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index ca4b5a0588..948233ad3e 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import orjson from fastapi import Request, UploadFile, status @@ -147,10 +147,10 @@ def check_file_size_under_limit( if llm_router is not None and request_data["model"] in router_model_names: try: - deployment: Optional[ - Deployment - ] = llm_router.get_deployment_by_model_group_name( - model_group_name=request_data["model"] + deployment: Optional[Deployment] = ( + llm_router.get_deployment_by_model_group_name( + model_group_name=request_data["model"] + ) ) if ( deployment @@ -185,3 +185,23 @@ def check_file_size_under_limit( ) return True + + +async def get_form_data(request: Request) -> Dict[str, Any]: + """ + Read form data from request + + Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]` + """ + form = await request.form() + form_data = dict(form) + parsed_form_data: dict[str, Any] = {} + for key, value in form_data.items(): + + # OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]` + if key.endswith("[]"): + clean_key = key[:-2] + parsed_form_data.setdefault(clean_key, []).append(value) + else: + parsed_form_data[key] = value + return parsed_form_data diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 63d0a3ffe2..75c49211e7 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,16 +1,8 @@ model_list: - - model_name: azure-computer-use-preview + - model_name: openai/* litellm_params: - model: azure/computer-use-preview - api_key: mock-api-key - api_version: mock-api-version - api_base: https://mock-endpoint.openai.azure.com - - model_name: azure-computer-use-preview - litellm_params: - model: azure/computer-use-preview-2 - api_key: mock-api-key-2 - api_version: mock-api-version-2 - api_base: https://mock-endpoint-2.openai.azure.com + model: openai/* + api_key: os.environ/OPENAI_API_KEY router_settings: optional_pre_call_checks: ["responses_api_deployment_check"] diff --git a/litellm/proxy/proxy_server.py 
b/litellm/proxy/proxy_server.py index fd32a62ee4..8c5b88be4d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -179,6 +179,7 @@ from litellm.proxy.common_utils.html_forms.ui_login import html_form from litellm.proxy.common_utils.http_parsing_utils import ( _read_request_body, check_file_size_under_limit, + get_form_data, ) from litellm.proxy.common_utils.load_config_utils import ( get_config_file_contents_from_gcs, @@ -804,9 +805,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter( dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[ - RedisCache -] = None # redis cache used for tracking spend, tpm/rpm limits +redis_usage_cache: Optional[RedisCache] = ( + None # redis cache used for tracking spend, tpm/rpm limits +) user_custom_auth = None user_custom_key_generate = None user_custom_sso = None @@ -1132,9 +1133,9 @@ async def update_cache( # noqa: PLR0915 _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[ - LiteLLM_TeamTable - ] = await user_api_key_cache.async_get_cache(key=_id) + existing_spend_obj: Optional[LiteLLM_TeamTable] = ( + await user_api_key_cache.async_get_cache(key=_id) + ) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -2806,9 +2807,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ[ - "AZURE_API_VERSION" - ] = api_version # set this for azure - litellm can read this from the env + os.environ["AZURE_API_VERSION"] = ( + api_version # set this for azure - litellm can read this from the env + ) if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -4120,7 +4121,7 @@ async def audio_transcriptions( data: Dict = {} try: # Use orjson to parse JSON data, orjson speeds up requests significantly - form_data = await request.form() + form_data = await get_form_data(request) data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data @@ -7758,9 +7759,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[ - idx - ].field_description = sub_field_info.description + nested_fields[idx].field_description = ( + sub_field_info.description + ) idx += 1 _stored_in_db = None diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py index f9ddf306a7..bef5b4b807 100644 --- a/litellm/proxy/response_api_endpoints/endpoints.py +++ b/litellm/proxy/response_api_endpoints/endpoints.py @@ -106,8 +106,50 @@ async def get_response( -H "Authorization: Bearer sk-1234" ``` """ - # TODO: Implement response retrieval logic - pass + from litellm.proxy.proxy_server import ( + _read_request_body, + general_settings, + llm_router, + proxy_config, + proxy_logging_obj, + select_data_generator, + user_api_base, + user_max_tokens, + user_model, + user_request_timeout, + user_temperature, + version, + ) + + data = await _read_request_body(request=request) + data["response_id"] = response_id + processor = ProxyBaseLLMRequestProcessing(data=data) + try: + return await processor.base_process_llm_request( + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + 
route_type="aget_responses", + proxy_logging_obj=proxy_logging_obj, + llm_router=llm_router, + general_settings=general_settings, + proxy_config=proxy_config, + select_data_generator=select_data_generator, + model=None, + user_model=user_model, + user_temperature=user_temperature, + user_request_timeout=user_request_timeout, + user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + version=version, + ) + except Exception as e: + raise await processor._handle_llm_api_exception( + e=e, + user_api_key_dict=user_api_key_dict, + proxy_logging_obj=proxy_logging_obj, + version=version, + ) @router.delete( @@ -136,8 +178,50 @@ async def delete_response( -H "Authorization: Bearer sk-1234" ``` """ - # TODO: Implement response deletion logic - pass + from litellm.proxy.proxy_server import ( + _read_request_body, + general_settings, + llm_router, + proxy_config, + proxy_logging_obj, + select_data_generator, + user_api_base, + user_max_tokens, + user_model, + user_request_timeout, + user_temperature, + version, + ) + + data = await _read_request_body(request=request) + data["response_id"] = response_id + processor = ProxyBaseLLMRequestProcessing(data=data) + try: + return await processor.base_process_llm_request( + request=request, + fastapi_response=fastapi_response, + user_api_key_dict=user_api_key_dict, + route_type="adelete_responses", + proxy_logging_obj=proxy_logging_obj, + llm_router=llm_router, + general_settings=general_settings, + proxy_config=proxy_config, + select_data_generator=select_data_generator, + model=None, + user_model=user_model, + user_temperature=user_temperature, + user_request_timeout=user_request_timeout, + user_max_tokens=user_max_tokens, + user_api_base=user_api_base, + version=version, + ) + except Exception as e: + raise await processor._handle_llm_api_exception( + e=e, + user_api_key_dict=user_api_key_dict, + proxy_logging_obj=proxy_logging_obj, + version=version, + ) @router.get( diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py index ac9332b219..ef5edaff99 100644 --- a/litellm/proxy/route_llm_request.py +++ b/litellm/proxy/route_llm_request.py @@ -47,6 +47,8 @@ async def route_request( "amoderation", "arerank", "aresponses", + "aget_responses", + "adelete_responses", "_arealtime", # private function for realtime API ], ): diff --git a/litellm/responses/utils.py b/litellm/responses/utils.py index 9fa455de71..cf51c40b2a 100644 --- a/litellm/responses/utils.py +++ b/litellm/responses/utils.py @@ -176,6 +176,16 @@ class ResponsesAPIRequestUtils: response_id=response_id, ) + @staticmethod + def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]: + """Get the model_id from the response_id""" + if response_id is None: + return None + decoded_response_id = ( + ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id) + ) + return decoded_response_id.get("model_id") or None + class ResponseAPILoggingUtils: @staticmethod diff --git a/litellm/router.py b/litellm/router.py index dba886b856..e7f98fab48 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -739,6 +739,12 @@ class Router: litellm.afile_content, call_type="afile_content" ) self.responses = self.factory_function(litellm.responses, call_type="responses") + self.aget_responses = self.factory_function( + litellm.aget_responses, call_type="aget_responses" + ) + self.adelete_responses = self.factory_function( + litellm.adelete_responses, call_type="adelete_responses" + ) def validate_fallbacks(self, fallback_param: Optional[List]): """ @@ 
-3081,6 +3087,8 @@ class Router: "anthropic_messages", "aresponses", "responses", + "aget_responses", + "adelete_responses", "afile_delete", "afile_content", ] = "assistants", @@ -3135,6 +3143,11 @@ class Router: original_function=original_function, **kwargs, ) + elif call_type in ("aget_responses", "adelete_responses"): + return await self._init_responses_api_endpoints( + original_function=original_function, + **kwargs, + ) elif call_type in ("afile_delete", "afile_content"): return await self._ageneric_api_call_with_fallbacks( original_function=original_function, @@ -3145,6 +3158,28 @@ class Router: return async_wrapper + async def _init_responses_api_endpoints( + self, + original_function: Callable, + **kwargs, + ): + """ + Initialize the Responses API endpoints on the router. + + GET, DELETE Responses API Requests encode the model_id in the response_id, this function decodes the response_id and sets the model to the model_id. + """ + from litellm.responses.utils import ResponsesAPIRequestUtils + + model_id = ResponsesAPIRequestUtils.get_model_id_from_response_id( + kwargs.get("response_id") + ) + if model_id is not None: + kwargs["model"] = model_id + return await self._ageneric_api_call_with_fallbacks( + original_function=original_function, + **kwargs, + ) + async def _pass_through_assistants_endpoint_factory( self, original_function: Callable, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 55052761c7..760a7c7842 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -7058,6 +7058,17 @@ "source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models", "supports_tool_choice": true }, + "command-a-03-2025": { + "max_tokens": 8000, + "max_input_tokens": 256000, + "max_output_tokens": 8000, + "input_cost_per_token": 0.0000025, + "output_cost_per_token": 0.00001, + "litellm_provider": "cohere_chat", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": true + }, "command-r": { "max_tokens": 4096, "max_input_tokens": 128000, diff --git a/tests/litellm/proxy/common_utils/test_http_parsing_utils.py b/tests/litellm/proxy/common_utils/test_http_parsing_utils.py index 38624422c6..b98a1f2657 100644 --- a/tests/litellm/proxy/common_utils/test_http_parsing_utils.py +++ b/tests/litellm/proxy/common_utils/test_http_parsing_utils.py @@ -18,6 +18,7 @@ from litellm.proxy.common_utils.http_parsing_utils import ( _read_request_body, _safe_get_request_parsed_body, _safe_set_request_parsed_body, + get_form_data, ) @@ -147,3 +148,53 @@ async def test_circular_reference_handling(): assert ( "proxy_server_request" not in result2 ) # This will pass, showing the cache pollution + + +@pytest.mark.asyncio +async def test_get_form_data(): + """ + Test that get_form_data correctly handles form data with array notation. + Tests audio transcription parameters as a specific example. 
+ """ + # Create a mock request with transcription form data + mock_request = MagicMock() + + # Create mock form data with array notation for timestamp_granularities + mock_form_data = { + "file": "file_object", # In a real request this would be an UploadFile + "model": "gpt-4o-transcribe", + "include[]": "logprobs", # Array notation + "language": "en", + "prompt": "Transcribe this audio file", + "response_format": "json", + "stream": "false", + "temperature": "0.2", + "timestamp_granularities[]": "word", # First array item + "timestamp_granularities[]": "segment", # Second array item (would overwrite in dict, but handled by the function) + } + + # Mock the form method to return the test data + mock_request.form = AsyncMock(return_value=mock_form_data) + + # Call the function being tested + result = await get_form_data(mock_request) + + # Verify regular form fields are preserved + assert result["file"] == "file_object" + assert result["model"] == "gpt-4o-transcribe" + assert result["language"] == "en" + assert result["prompt"] == "Transcribe this audio file" + assert result["response_format"] == "json" + assert result["stream"] == "false" + assert result["temperature"] == "0.2" + + # Verify array fields are correctly parsed + assert "include" in result + assert isinstance(result["include"], list) + assert "logprobs" in result["include"] + + assert "timestamp_granularities" in result + assert isinstance(result["timestamp_granularities"], list) + # Note: In a real MultiDict, both values would be present + # But in our mock dictionary the second value overwrites the first + assert "segment" in result["timestamp_granularities"] diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py index 230781c636..6a2cacd20a 100644 --- a/tests/llm_translation/base_llm_unit_tests.py +++ b/tests/llm_translation/base_llm_unit_tests.py @@ -947,6 +947,8 @@ class BaseLLMChatTest(ABC): second_response.choices[0].message.content is not None or second_response.choices[0].message.tool_calls is not None ) + except litellm.ServiceUnavailableError: + pytest.skip("Model is overloaded") except litellm.InternalServerError: pytest.skip("Model is overloaded") except litellm.RateLimitError: diff --git a/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py b/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py index 1dde8ebae6..f87444141e 100644 --- a/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py +++ b/tests/openai_endpoints_tests/test_e2e_openai_responses_api.py @@ -73,15 +73,31 @@ def validate_stream_chunk(chunk): def test_basic_response(): client = get_test_client() response = client.responses.create( - model="gpt-4o", input="just respond with the word 'ping'" + model="gpt-4.0", input="just respond with the word 'ping'" ) print("basic response=", response) + # get the response + response = client.responses.retrieve(response.id) + print("GET response=", response) + + + # delete the response + delete_response = client.responses.delete(response.id) + print("DELETE response=", delete_response) + + # try getting the response again, we should not get it back + get_response = client.responses.retrieve(response.id) + print("GET response after delete=", get_response) + + with pytest.raises(Exception): + get_response = client.responses.retrieve(response.id) + def test_streaming_response(): client = get_test_client() stream = client.responses.create( - model="gpt-4o", input="just respond with the word 'ping'", stream=True + model="gpt-4.0", input="just 
respond with the word 'ping'", stream=True ) collected_chunks = [] @@ -104,5 +120,5 @@ def test_bad_request_bad_param_error(): with pytest.raises(BadRequestError): # Trigger error with invalid model name client.responses.create( - model="gpt-4o", input="This should fail", temperature=2000 + model="gpt-4.0", input="This should fail", temperature=2000 ) diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py index 782f0d8fbb..b17f0c0a5e 100644 --- a/tests/router_unit_tests/test_router_helper_utils.py +++ b/tests/router_unit_tests/test_router_helper_utils.py @@ -1157,3 +1157,14 @@ def test_cached_get_model_group_info(model_list): # Verify the cache info shows hits cache_info = router._cached_get_model_group_info.cache_info() assert cache_info.hits > 0 # Should have at least one cache hit + + +def test_init_responses_api_endpoints(model_list): + """Test if the '_init_responses_api_endpoints' function is working correctly""" + from typing import Callable + router = Router(model_list=model_list) + + assert router.aget_responses is not None + assert isinstance(router.aget_responses, Callable) + assert router.adelete_responses is not None + assert isinstance(router.adelete_responses, Callable) diff --git a/ui/litellm-dashboard/src/app/page.tsx b/ui/litellm-dashboard/src/app/page.tsx index 1fea83d054..2a3b7422ce 100644 --- a/ui/litellm-dashboard/src/app/page.tsx +++ b/ui/litellm-dashboard/src/app/page.tsx @@ -151,7 +151,7 @@ export default function CreateKeyPage() { if (redirectToLogin) { window.location.href = (proxyBaseUrl || "") + "/sso/key/generate" } - }, [token, authLoading]) + }, [redirectToLogin]) useEffect(() => { if (!token) { @@ -223,7 +223,7 @@ export default function CreateKeyPage() { } }, [accessToken, userID, userRole]); - if (authLoading) { + if (authLoading || redirectToLogin) { return } diff --git a/ui/litellm-dashboard/src/components/all_keys_table.tsx b/ui/litellm-dashboard/src/components/all_keys_table.tsx index 403b131c65..6ecd802b65 100644 --- a/ui/litellm-dashboard/src/components/all_keys_table.tsx +++ b/ui/litellm-dashboard/src/components/all_keys_table.tsx @@ -450,6 +450,8 @@ export function AllKeysTable({ columns={columns.filter(col => col.id !== 'expander') as any} data={filteredKeys as any} isLoading={isLoading} + loadingMessage="🚅 Loading keys..." + noDataMessage="No keys found" getRowCanExpand={() => false} renderSubComponent={() => <>} /> diff --git a/ui/litellm-dashboard/src/components/mcp_tools/index.tsx b/ui/litellm-dashboard/src/components/mcp_tools/index.tsx index ae3d4fac62..c403b49668 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/index.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/index.tsx @@ -26,6 +26,8 @@ function DataTableWrapper({ isLoading={isLoading} renderSubComponent={renderSubComponent} getRowCanExpand={getRowCanExpand} + loadingMessage="🚅 Loading tools..." 
+ noDataMessage="No tools found" /> ); } diff --git a/ui/litellm-dashboard/src/components/view_logs/table.tsx b/ui/litellm-dashboard/src/components/view_logs/table.tsx index ed728de074..51131f0dd7 100644 --- a/ui/litellm-dashboard/src/components/view_logs/table.tsx +++ b/ui/litellm-dashboard/src/components/view_logs/table.tsx @@ -26,6 +26,8 @@ interface DataTableProps { expandedRequestId?: string | null; onRowExpand?: (requestId: string | null) => void; setSelectedKeyIdInfoView?: (keyId: string | null) => void; + loadingMessage?: string; + noDataMessage?: string; } export function DataTable({ @@ -36,6 +38,8 @@ export function DataTable({ isLoading = false, expandedRequestId, onRowExpand, + loadingMessage = "🚅 Loading logs...", + noDataMessage = "No logs found", }: DataTableProps) { const table = useReactTable({ data, @@ -114,7 +118,7 @@ export function DataTable({
-                🚅 Loading logs...
+                {loadingMessage}
@@ -147,7 +151,7 @@ export function DataTable({
          :
-                No logs found
+                {noDataMessage}
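
With the new `aget_responses` / `adelete_responses` route types wired through `base_process_llm_request`, the proxy's `GET` and `DELETE` `/v1/responses/{response_id}` endpoints can be exercised with the stock OpenAI SDK pointed at the proxy. A minimal sketch of that flow, assuming a proxy running at `http://localhost:4000`, a virtual key `sk-1234`, and a deployment answering to `gpt-4o` (all illustrative values, not taken from the test config in this diff):

```python
# Minimal sketch: exercising the new GET/DELETE Responses API routes through the proxy.
# The base_url, api_key, and model name below are assumptions for illustration only.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# POST /v1/responses -> route_type="aresponses"; the router encodes the chosen
# deployment's model_id into the returned response id.
created = client.responses.create(model="gpt-4o", input="just respond with the word 'ping'")

# GET /v1/responses/{response_id} -> route_type="aget_responses"
fetched = client.responses.retrieve(created.id)
print("retrieved:", fetched.id, fetched.status)

# DELETE /v1/responses/{response_id} -> route_type="adelete_responses"
deleted = client.responses.delete(created.id)
print("deleted:", deleted)
```

Because both new verbs funnel through `ProxyBaseLLMRequestProcessing.base_process_llm_request`, auth, logging callbacks, and exception handling match the existing `aresponses` path.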
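
The router half of the change works because the response id round-trips the deployment: `_init_responses_api_endpoints` calls `ResponsesAPIRequestUtils.get_model_id_from_response_id` to decode the embedded `model_id` and pins the follow-up call to the same deployment, so no `model` argument is needed on retrieve/delete. A sketch of that flow against the `Router` directly, assuming a single OpenAI deployment and `OPENAI_API_KEY` set in the environment (the model_list entry mirrors the `os.environ/OPENAI_API_KEY` convention used in the updated proxy_config.yaml, but is otherwise illustrative):

```python
# Sketch of the router-level flow added in this PR; deployment config is illustrative.
import asyncio

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o", "api_key": "os.environ/OPENAI_API_KEY"},
        }
    ]
)

async def main() -> None:
    created = await router.aresponses(model="gpt-4o", input="just respond with the word 'ping'")

    # aget_responses / adelete_responses take no model argument:
    # _init_responses_api_endpoints recovers the model_id encoded into created.id
    # and routes the call to the same deployment.
    fetched = await router.aget_responses(response_id=created.id)
    deleted = await router.adelete_responses(response_id=created.id)
    print(fetched.status, deleted)

asyncio.run(main())
```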