diff --git a/litellm/router.py b/litellm/router.py
index 7aa2528504..065905503e 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -801,9 +801,7 @@ class Router:
             kwargs["stream"] = stream
             kwargs["original_function"] = self._acompletion
             self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
-
             request_priority = kwargs.get("priority") or self.default_priority
-
             start_time = time.time()
             if request_priority is not None and isinstance(request_priority, int):
                 response = await self.schedule_acompletion(**kwargs)
@@ -1422,7 +1420,7 @@ class Router:
             kwargs["prompt"] = prompt
             kwargs["original_function"] = self._aimage_generation
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             response = await self.async_function_with_fallbacks(**kwargs)
 
             return response
@@ -1660,13 +1658,7 @@ class Router:
                 messages=[{"role": "user", "content": "prompt"}],
                 specific_deployment=kwargs.pop("specific_deployment", None),
             )
-            kwargs.setdefault("metadata", {}).update(
-                {
-                    "deployment": deployment["litellm_params"]["model"],
-                    "model_info": deployment.get("model_info", {}),
-                }
-            )
-            kwargs["model_info"] = deployment.get("model_info", {})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             data = deployment["litellm_params"].copy()
             data["model"]
             for k, v in self.default_litellm_params.items():
@@ -1777,7 +1769,7 @@ class Router:
         messages = [{"role": "user", "content": "dummy-text"}]
         try:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
 
             # pick the one that is available (lowest TPM/RPM)
             deployment = await self.async_get_available_deployment(
@@ -2215,7 +2207,7 @@ class Router:
             kwargs["model"] = model
             kwargs["original_function"] = self._acreate_file
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             response = await self.async_function_with_fallbacks(**kwargs)
 
             return response
@@ -2320,7 +2312,7 @@ class Router:
             kwargs["model"] = model
             kwargs["original_function"] = self._acreate_batch
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             response = await self.async_function_with_fallbacks(**kwargs)
 
             return response
diff --git a/tests/router_unit_tests/test_router_endpoints.py b/tests/router_unit_tests/test_router_endpoints.py
index 0d13b9c610..e876d37662 100644
--- a/tests/router_unit_tests/test_router_endpoints.py
+++ b/tests/router_unit_tests/test_router_endpoints.py
@@ -1,6 +1,8 @@
 import sys
 import os
+import json
 import traceback
+from typing import Optional
 from dotenv import load_dotenv
 from fastapi import Request
 from datetime import datetime
@@ -9,6 +11,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm import Router, CustomLogger
+from litellm.types.utils import StandardLoggingPayload
 
 # Get the current directory of the file being run
 pwd = os.path.dirname(os.path.realpath(__file__))
@@ -76,19 +79,20 @@ class MyCustomHandler(CustomLogger):
             print("logging a transcript kwargs: ", kwargs)
             print("openai client=", kwargs.get("client"))
             self.openai_client = kwargs.get("client")
+            self.standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )
         except Exception:
             pass
 
 
-proxy_handler_instance = MyCustomHandler()
-
-
 # Set litellm.callbacks = [proxy_handler_instance] on the proxy
 # need to set litellm.callbacks = [proxy_handler_instance] # on the proxy
 
 @pytest.mark.asyncio
 @pytest.mark.flaky(retries=6, delay=10)
 async def test_transcription_on_router():
+    proxy_handler_instance = MyCustomHandler()
     litellm.set_verbose = True
     litellm.callbacks = [proxy_handler_instance]
     print("\n Testing async transcription on router\n")
@@ -150,7 +154,9 @@ async def test_transcription_on_router():
 @pytest.mark.parametrize("mode", ["iterator"])  # "file",
 @pytest.mark.asyncio
 async def test_audio_speech_router(mode):
-
+    litellm.set_verbose = True
+    test_logger = MyCustomHandler()
+    litellm.callbacks = [test_logger]
     from litellm import Router
 
     client = Router(
@@ -178,10 +184,19 @@ async def test_audio_speech_router(mode):
         optional_params={},
     )
 
+    await asyncio.sleep(3)
+
     from litellm.llms.openai.openai import HttpxBinaryResponseContent
 
     assert isinstance(response, HttpxBinaryResponseContent)
 
+    assert test_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(test_logger.standard_logging_object, indent=4),
+    )
+    assert test_logger.standard_logging_object["model_group"] == "tts"
+
 
 @pytest.mark.asyncio()
 async def test_rerank_endpoint(model_list):
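
Note: the router-side hunks above replace the per-endpoint `kwargs.setdefault("metadata", {}).update({"model_group": model})` calls with the shared `self._update_kwargs_before_fallbacks(...)` helper, whose body is not part of this diff. As a rough, hypothetical sketch (not the actual litellm implementation), inferred only from the metadata the removed lines were setting, such a helper would at minimum do something like:

```python
# Hypothetical sketch only -- the real _update_kwargs_before_fallbacks is not shown in this diff.
def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
    # Centralize the metadata the endpoints used to set inline, so every
    # router endpoint reports the same "model_group" in its logging payload.
    metadata = kwargs.setdefault("metadata", {})
    metadata["model_group"] = model
```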