diff --git a/litellm/main.py b/litellm/main.py index 5ecfcf1db..ddeff3414 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1835,7 +1835,7 @@ def embedding( try: response = None logging = litellm_logging_obj - logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding}) + logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}}) if azure == True or custom_llm_provider == "azure": # azure configs api_type = get_secret("AZURE_API_TYPE") or "azure" diff --git a/litellm/router.py b/litellm/router.py index 1721a381b..c01c7d42e 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -313,8 +313,8 @@ class Router: **kwargs) -> Union[List[float], None]: # pick the one that is available (lowest TPM/RPM) deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) - kwargs["model_info"] = deployment.get("model_info", {}) + kwargs.setdefault("model_info", {}) + kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks data = deployment["litellm_params"].copy() for k, v in self.default_litellm_params.items(): if k not in data: # prioritize model-specific params > default router params @@ -339,7 +339,7 @@ class Router: **kwargs) -> Union[List[float], None]: # pick the one that is available (lowest TPM/RPM) deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None)) - kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]}) + kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) data = deployment["litellm_params"].copy() kwargs["model_info"] = deployment.get("model_info", {}) for k, v in self.default_litellm_params.items(): diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index 865ed4040..fa484dea0 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -22,8 +22,7 @@ from litellm.integrations.custom_logger import CustomLogger # Test interfaces ## 1. litellm.completion() + litellm.embeddings() -## 2. router.completion() + router.embeddings() -## 3. proxy.completions + proxy.embeddings +## refer to test_custom_callback_input_router.py for the router + proxy tests class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class """ @@ -577,4 +576,4 @@ async def test_async_embedding_bedrock(): except Exception as e: pytest.fail(f"An exception occurred: {str(e)}") -asyncio.run(test_async_embedding_bedrock()) \ No newline at end of file +# asyncio.run(test_async_embedding_bedrock()) \ No newline at end of file diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py new file mode 100644 index 000000000..38683bc2d --- /dev/null +++ b/litellm/tests/test_custom_callback_router.py @@ -0,0 +1,398 @@ +### What this tests #### +## This test asserts the type of data passed into each method of the custom callback handler +import sys, os, time, inspect, asyncio, traceback +from datetime import datetime +import pytest +sys.path.insert(0, os.path.abspath('../..')) +from typing import Optional, Literal, List +from litellm import Router +import litellm +from litellm.integrations.custom_logger import CustomLogger + +# Test Scenarios (test across completion, streaming, embedding) +## 1: Pre-API-Call +## 2: Post-API-Call +## 3: On LiteLLM Call success +## 4: On LiteLLM Call failure +## fallbacks +## retries + +# Test cases +## 1. Simple Azure OpenAI acompletion + streaming call +## 2. Simple Azure OpenAI aembedding call +## 3. Azure OpenAI acompletion + streaming call with retries +## 4. Azure OpenAI aembedding call with retries +## 5. Azure OpenAI acompletion + streaming call with fallbacks +## 6. Azure OpenAI aembedding call with fallbacks + +# Test interfaces +## 1. router.completion() + router.embeddings() +## 2. proxy.completions + proxy.embeddings + +class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class + """ + The set of expected inputs to a custom handler for a + """ + # Class variables or attributes + def __init__(self): + self.errors = [] + self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = [] + + def log_pre_api_call(self, model, messages, kwargs): + try: + print(f'received kwargs in pre-input: {kwargs}') + self.states.append("sync_pre_api_call") + ## MODEL + assert isinstance(model, str) + ## MESSAGES + assert isinstance(messages, list) + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + ### ROUTER-SPECIFIC KWARGS + assert isinstance(kwargs["litellm_params"]["metadata"], dict) + assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) + assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) + assert isinstance(kwargs["litellm_params"]["model_info"], dict) + assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) + assert isinstance(kwargs["litellm_params"]["proxy_server_request"], Optional[str]) + assert isinstance(kwargs["litellm_params"]["preset_cache_key"], Optional[str]) + assert isinstance(kwargs["litellm_params"]["stream_response"], dict) + except Exception as e: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + def log_post_api_call(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("post_api_call") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert end_time == None + ## RESPONSE OBJECT + assert response_obj == None + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + assert isinstance(kwargs['input'], (list, dict, str)) + assert isinstance(kwargs['api_key'], Optional[str]) + assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response']) + assert isinstance(kwargs['additional_args'], Optional[dict]) + assert isinstance(kwargs['log_event_type'], str) + ### ROUTER-SPECIFIC KWARGS + assert isinstance(kwargs["litellm_params"]["metadata"], dict) + assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) + assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) + assert isinstance(kwargs["litellm_params"]["model_info"], dict) + assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) + assert isinstance(kwargs["litellm_params"]["proxy_server_request"], Optional[str]) + assert isinstance(kwargs["litellm_params"]["preset_cache_key"], Optional[str]) + assert isinstance(kwargs["litellm_params"]["stream_response"], dict) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("async_stream") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert isinstance(response_obj, litellm.ModelResponse) + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) + assert isinstance(kwargs['api_key'], Optional[str]) + assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) + assert isinstance(kwargs['additional_args'], Optional[dict]) + assert isinstance(kwargs['log_event_type'], str) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("sync_success") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert isinstance(response_obj, litellm.ModelResponse) + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) + assert isinstance(kwargs['api_key'], Optional[str]) + assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) + assert isinstance(kwargs['additional_args'], Optional[dict]) + assert isinstance(kwargs['log_event_type'], str) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("sync_failure") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert response_obj == None + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str)) + assert isinstance(kwargs['api_key'], Optional[str]) + assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None + assert isinstance(kwargs['additional_args'], Optional[dict]) + assert isinstance(kwargs['log_event_type'], str) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + async def async_log_pre_api_call(self, model, messages, kwargs): + try: + """ + No-op. + Not implemented yet. + """ + pass + except Exception as e: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + try: + self.states.append("async_success") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse)) + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + assert isinstance(kwargs['input'], (list, dict, str)) + assert isinstance(kwargs['api_key'], Optional[str]) + assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) + assert isinstance(kwargs['additional_args'], Optional[dict]) + assert isinstance(kwargs['log_event_type'], str) + ### ROUTER-SPECIFIC KWARGS + assert isinstance(kwargs["litellm_params"]["metadata"], dict) + assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str) + assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str) + assert isinstance(kwargs["litellm_params"]["model_info"], dict) + assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str) + assert isinstance(kwargs["litellm_params"]["proxy_server_request"], Optional[str]) + assert isinstance(kwargs["litellm_params"]["preset_cache_key"], Optional[str]) + assert isinstance(kwargs["litellm_params"]["stream_response"], dict) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + + async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): + try: + print(f"received original response: {kwargs['original_response']}") + self.states.append("async_failure") + ## START TIME + assert isinstance(start_time, datetime) + ## END TIME + assert isinstance(end_time, datetime) + ## RESPONSE OBJECT + assert response_obj == None + ## KWARGS + assert isinstance(kwargs['model'], str) + assert isinstance(kwargs['messages'], list) + assert isinstance(kwargs['optional_params'], dict) + assert isinstance(kwargs['litellm_params'], dict) + assert isinstance(kwargs['start_time'], Optional[datetime]) + assert isinstance(kwargs['stream'], bool) + assert isinstance(kwargs['user'], Optional[str]) + assert isinstance(kwargs['input'], (list, str, dict)) + assert isinstance(kwargs['api_key'], Optional[str]) + assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None + assert isinstance(kwargs['additional_args'], Optional[dict]) + assert isinstance(kwargs['log_event_type'], str) + except: + print(f"Assertion Error: {traceback.format_exc()}") + self.errors.append(traceback.format_exc()) + +# Simple Azure OpenAI call +## COMPLETION +@pytest.mark.asyncio +async def test_async_chat_azure(): + try: + customHandler = CompletionCustomHandler() + customHandler_streaming = CompletionCustomHandler() + customHandler_failure = CompletionCustomHandler() + litellm.callbacks = [customHandler] + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + }, + ] + router = Router(model_list=model_list) # type: ignore + response = await router.acompletion(model="gpt-3.5-turbo", + messages=[{ + "role": "user", + "content": "Hi 👋 - i'm openai" + }]) + await asyncio.sleep(2) + assert len(customHandler.errors) == 0 + assert len(customHandler.states) == 3 # pre, post, success + # streaming + litellm.callbacks = [customHandler_streaming] + router2 = Router(model_list=model_list) # type: ignore + response = await router2.acompletion(model="gpt-3.5-turbo", + messages=[{ + "role": "user", + "content": "Hi 👋 - i'm openai" + }], + stream=True) + async for chunk in response: + continue + await asyncio.sleep(1) + print(f"customHandler.states: {customHandler_streaming.states}") + assert len(customHandler_streaming.errors) == 0 + assert len(customHandler_streaming.states) >= 4 # pre, post, stream (multiple times), success + # failure + model_list = [ + { + "model_name": "gpt-3.5-turbo", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/chatgpt-v-2", + "api_key": "my-bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + }, + ] + litellm.callbacks = [customHandler_failure] + router3 = Router(model_list=model_list) # type: ignore + try: + response = await router3.acompletion(model="gpt-3.5-turbo", + messages=[{ + "role": "user", + "content": "Hi 👋 - i'm openai" + }]) + print(f"response in router3 acompletion: {response}") + except: + pass + await asyncio.sleep(1) + print(f"customHandler.states: {customHandler_failure.states}") + assert len(customHandler_failure.errors) == 0 + assert len(customHandler_failure.states) == 3 # pre, post, failure + assert "async_failure" in customHandler_failure.states + except Exception as e: + print(f"Assertion Error: {traceback.format_exc()}") + pytest.fail(f"An exception occurred - {str(e)}") +# asyncio.run(test_async_chat_azure()) +## EMBEDDING +async def test_async_embedding_azure(): + try: + customHandler = CompletionCustomHandler() + customHandler_failure = CompletionCustomHandler() + litellm.callbacks = [customHandler] + model_list = [ + { + "model_name": "azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": os.getenv("AZURE_API_KEY"), + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + }, + ] + router = Router(model_list=model_list) # type: ignore + response = await router.aembedding(model="azure-embedding-model", + input=["hello from litellm!"]) + await asyncio.sleep(2) + assert len(customHandler.errors) == 0 + assert len(customHandler.states) == 3 # pre, post, success + # failure + model_list = [ + { + "model_name": "azure-embedding-model", # openai model name + "litellm_params": { # params for litellm completion/embedding call + "model": "azure/azure-embedding-model", + "api_key": "my-bad-key", + "api_version": os.getenv("AZURE_API_VERSION"), + "api_base": os.getenv("AZURE_API_BASE") + }, + "tpm": 240000, + "rpm": 1800 + }, + ] + litellm.callbacks = [customHandler_failure] + router3 = Router(model_list=model_list) # type: ignore + try: + response = await router3.aembedding(model="azure-embedding-model", + input=["hello from litellm!"]) + print(f"response in router3 aembedding: {response}") + except: + pass + await asyncio.sleep(1) + print(f"customHandler.states: {customHandler_failure.states}") + assert len(customHandler_failure.errors) == 0 + assert len(customHandler_failure.states) == 3 # pre, post, failure + assert "async_failure" in customHandler_failure.states + except Exception as e: + print(f"Assertion Error: {traceback.format_exc()}") + pytest.fail(f"An exception occurred - {str(e)}") +asyncio.run(test_async_embedding_azure()) +# Azure OpenAI call w/ Retries +## COMPLETION +## EMBEDDING + +# Azure OpenAI call w/ fallbacks +## COMPLETION +## EMBEDDING