(Bug fix) missing model_group field in logs for aspeech call types (#7392)

* fix use _update_kwargs_before_fallbacks

* test assert standard_logging_object includes model_group

* test_datadog_non_serializable_messages

* update test
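
The helper named in the first bullet is not itself shown in this diff. A minimal sketch of what a consolidating _update_kwargs_before_fallbacks could look like, inferred from the one-liners it replaces below — the actual method in litellm's Router may set additional fields (e.g. a trace id), and RouterSketch here is a hypothetical stand-in:

class RouterSketch:
    # Hypothetical stand-in for litellm's Router; only what the sketch needs.
    def __init__(self, num_retries: int = 2):
        self.num_retries = num_retries

    def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:
        # One place to stamp routing bookkeeping before fallbacks run,
        # instead of repeating it in every call type's entrypoint.
        kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
        # "model_group" is the field that was missing from aspeech logs.
        kwargs.setdefault("metadata", {}).update({"model_group": model})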
Ishaan Jaff 2024-12-27 17:00:39 -08:00 committed by GitHub
parent 79c783e83f
commit 5e8c64f128
2 changed files with 24 additions and 17 deletions


@@ -801,9 +801,7 @@ class Router:
         kwargs["stream"] = stream
         kwargs["original_function"] = self._acompletion
         self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         request_priority = kwargs.get("priority") or self.default_priority
         start_time = time.time()
         if request_priority is not None and isinstance(request_priority, int):
             response = await self.schedule_acompletion(**kwargs)
@@ -1422,7 +1420,7 @@ class Router:
         kwargs["prompt"] = prompt
         kwargs["original_function"] = self._aimage_generation
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.setdefault("metadata", {}).update({"model_group": model})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         response = await self.async_function_with_fallbacks(**kwargs)
         return response
@@ -1660,13 +1658,7 @@ class Router:
             messages=[{"role": "user", "content": "prompt"}],
             specific_deployment=kwargs.pop("specific_deployment", None),
         )
-        kwargs.setdefault("metadata", {}).update(
-            {
-                "deployment": deployment["litellm_params"]["model"],
-                "model_info": deployment.get("model_info", {}),
-            }
-        )
-        kwargs["model_info"] = deployment.get("model_info", {})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         data = deployment["litellm_params"].copy()
         data["model"]
         for k, v in self.default_litellm_params.items():
@@ -1777,7 +1769,7 @@ class Router:
         messages = [{"role": "user", "content": "dummy-text"}]
         try:
             kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-            kwargs.setdefault("metadata", {}).update({"model_group": model})
+            self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
             # pick the one that is available (lowest TPM/RPM)
             deployment = await self.async_get_available_deployment(
@@ -2215,7 +2207,7 @@ class Router:
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_file
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.setdefault("metadata", {}).update({"model_group": model})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         response = await self.async_function_with_fallbacks(**kwargs)
         return response
@@ -2320,7 +2312,7 @@ class Router:
         kwargs["model"] = model
         kwargs["original_function"] = self._acreate_batch
         kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
-        kwargs.setdefault("metadata", {}).update({"model_group": model})
+        self._update_kwargs_before_fallbacks(model=model, kwargs=kwargs)
         response = await self.async_function_with_fallbacks(**kwargs)
         return response
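
Each deleted one-liner had to be repeated by hand in every entrypoint, which is how aspeech ended up without model_group in its logs. The merge behavior the helper preserves, as a standalone illustration (function name here is hypothetical, not from the repo):

def update_kwargs_before_fallbacks(model: str, kwargs: dict) -> None:
    # setdefault returns the existing dict when "metadata" is already
    # present, so update() merges into caller-supplied metadata rather
    # than overwriting it.
    kwargs.setdefault("metadata", {}).update({"model_group": model})

kwargs = {"metadata": {"user_api_key": "sk-example"}}
update_kwargs_before_fallbacks("tts", kwargs)
assert kwargs["metadata"] == {"user_api_key": "sk-example", "model_group": "tts"}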


@@ -1,6 +1,8 @@
 import sys
 import os
+import json
 import traceback
+from typing import Optional
 from dotenv import load_dotenv
 from fastapi import Request
 from datetime import datetime
@@ -9,6 +11,7 @@ sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 from litellm import Router, CustomLogger
+from litellm.types.utils import StandardLoggingPayload

 # Get the current directory of the file being run
 pwd = os.path.dirname(os.path.realpath(__file__))
@@ -76,19 +79,20 @@ class MyCustomHandler(CustomLogger):
             print("logging a transcript kwargs: ", kwargs)
             print("openai client=", kwargs.get("client"))
             self.openai_client = kwargs.get("client")
+            self.standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object"
+            )
         except Exception:
             pass

-proxy_handler_instance = MyCustomHandler()

 # Set litellm.callbacks = [proxy_handler_instance] on the proxy
 # need to set litellm.callbacks = [proxy_handler_instance] # on the proxy

 @pytest.mark.asyncio
 @pytest.mark.flaky(retries=6, delay=10)
 async def test_transcription_on_router():
+    proxy_handler_instance = MyCustomHandler()
     litellm.set_verbose = True
     litellm.callbacks = [proxy_handler_instance]
     print("\n Testing async transcription on router\n")
@@ -150,7 +154,9 @@ async def test_transcription_on_router():
 @pytest.mark.parametrize("mode", ["iterator"])  # "file",
 @pytest.mark.asyncio
 async def test_audio_speech_router(mode):
+    litellm.set_verbose = True
+    test_logger = MyCustomHandler()
+    litellm.callbacks = [test_logger]
     from litellm import Router

     client = Router(
@@ -178,10 +184,19 @@ async def test_audio_speech_router(mode):
         optional_params={},
     )

+    await asyncio.sleep(3)
+
     from litellm.llms.openai.openai import HttpxBinaryResponseContent

     assert isinstance(response, HttpxBinaryResponseContent)

+    assert test_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(test_logger.standard_logging_object, indent=4),
+    )
+    assert test_logger.standard_logging_object["model_group"] == "tts"

 @pytest.mark.asyncio()
 async def test_rerank_endpoint(model_list):
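
The test changes hinge on one pattern: a CustomLogger subclass grabs standard_logging_object out of the success-hook kwargs so the test can assert on it after the router call. Reduced to a skeleton — a sketch assuming litellm passes the payload through kwargs, as the diff above indicates:

from typing import Optional

import litellm
from litellm import CustomLogger
from litellm.types.utils import StandardLoggingPayload

class CaptureLogger(CustomLogger):
    def __init__(self):
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # litellm attaches the fully built logging payload to kwargs;
        # storing it lets the test assert fields like model_group.
        self.standard_logging_object = kwargs.get("standard_logging_object")

test_logger = CaptureLogger()
litellm.callbacks = [test_logger]
# ... make a router aspeech call, then:
# assert test_logger.standard_logging_object["model_group"] == "tts"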