diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 0fd174440..5fa6c0457 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -528,16 +528,19 @@ async def pass_through_request( # noqa: PLR0915 response_body: Optional[dict] = get_response_body(response) passthrough_logging_payload["response_body"] = response_body end_time = datetime.now() - await pass_through_endpoint_logging.pass_through_async_success_handler( - httpx_response=response, - response_body=response_body, - url_route=str(url), - result="", - start_time=start_time, - end_time=end_time, - logging_obj=logging_obj, - cache_hit=False, - **kwargs, + + asyncio.create_task( + pass_through_endpoint_logging.pass_through_async_success_handler( + httpx_response=response, + response_body=response_body, + url_route=str(url), + result="", + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + cache_hit=False, + **kwargs, + ) ) return Response( diff --git a/litellm/proxy/pass_through_endpoints/streaming_handler.py b/litellm/proxy/pass_through_endpoints/streaming_handler.py index dc6aae3af..15bc7127e 100644 --- a/litellm/proxy/pass_through_endpoints/streaming_handler.py +++ b/litellm/proxy/pass_through_endpoints/streaming_handler.py @@ -58,22 +58,24 @@ class PassThroughStreamingHandler: # After all chunks are processed, handle post-processing end_time = datetime.now() - await PassThroughStreamingHandler._route_streaming_logging_to_handler( - litellm_logging_obj=litellm_logging_obj, - passthrough_success_handler_obj=passthrough_success_handler_obj, - url_route=url_route, - request_body=request_body or {}, - endpoint_type=endpoint_type, - start_time=start_time, - raw_bytes=raw_bytes, - end_time=end_time, + asyncio.create_task( + PassThroughStreamingHandler.handle_logging_collected_stream_response( + litellm_logging_obj=litellm_logging_obj, + passthrough_success_handler_obj=passthrough_success_handler_obj, + url_route=url_route, + request_body=request_body or {}, + endpoint_type=endpoint_type, + start_time=start_time, + raw_bytes=raw_bytes, + end_time=end_time, + ) ) except Exception as e: verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}") raise @staticmethod - async def _route_streaming_logging_to_handler( + async def handle_logging_collected_stream_response( litellm_logging_obj: LiteLLMLoggingObj, passthrough_success_handler_obj: PassThroughEndpointLogging, url_route: str, @@ -108,9 +110,9 @@ class PassThroughStreamingHandler: all_chunks=all_chunks, end_time=end_time, ) - standard_logging_response_object = anthropic_passthrough_logging_handler_result[ - "result" - ] + standard_logging_response_object = ( + anthropic_passthrough_logging_handler_result["result"] + ) kwargs = anthropic_passthrough_logging_handler_result["kwargs"] elif endpoint_type == EndpointType.VERTEX_AI: vertex_passthrough_logging_handler_result = ( @@ -125,9 +127,9 @@ class PassThroughStreamingHandler: end_time=end_time, ) ) - standard_logging_response_object = vertex_passthrough_logging_handler_result[ - "result" - ] + standard_logging_response_object = ( + vertex_passthrough_logging_handler_result["result"] + ) kwargs = vertex_passthrough_logging_handler_result["kwargs"] if standard_logging_response_object is None: @@ -168,4 +170,4 @@ class PassThroughStreamingHandler: # Split by newlines and filter out empty lines lines = [line.strip() for line in combined_str.split("\n") if line.strip()] - return lines \ No newline at end of file + return lines diff --git a/tests/local_testing/test_router_provider_budgets.py b/tests/local_testing/test_router_provider_budgets.py index 46b9ee29e..541592ddb 100644 --- a/tests/local_testing/test_router_provider_budgets.py +++ b/tests/local_testing/test_router_provider_budgets.py @@ -142,24 +142,8 @@ async def test_provider_budgets_e2e_test_expect_to_fail(): assert "Exceeded budget for provider" in str(exc_info.value) -def test_get_ttl_seconds(): - """ - Test the get_ttl_seconds helper method" - - """ - provider_budget = ProviderBudgetLimiting( - router_cache=DualCache(), provider_budget_config={} - ) - - assert provider_budget.get_ttl_seconds("1d") == 86400 # 1 day in seconds - assert provider_budget.get_ttl_seconds("7d") == 604800 # 7 days in seconds - assert provider_budget.get_ttl_seconds("30d") == 2592000 # 30 days in seconds - - with pytest.raises(ValueError, match="Unsupported time period format"): - provider_budget.get_ttl_seconds("1h") - - -def test_get_llm_provider_for_deployment(): +@pytest.mark.asyncio +async def test_get_llm_provider_for_deployment(): """ Test the _get_llm_provider_for_deployment helper method @@ -189,7 +173,8 @@ def test_get_llm_provider_for_deployment(): assert provider_budget._get_llm_provider_for_deployment(unknown_deployment) is None -def test_get_budget_config_for_provider(): +@pytest.mark.asyncio +async def test_get_budget_config_for_provider(): """ Test the _get_budget_config_for_provider helper method