diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index f760e9fcef..2677d12c57 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             response = azure_client.chat.completions.create(**data) # type: ignore
-            response.model = "azure/" + str(response.model)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=stringified_response,
+                additional_args={
+                    "headers": headers,
+                    "api_version": api_version,
+                    "api_base": api_base,
+                },
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 29934d1304..9c16708e20 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
                 else:
                     openai_client = client
                 response = openai_client.chat.completions.create(**data) # type: ignore
+                stringified_response = response.model_dump_json()
                 logging_obj.post_call(
-                    input=None,
+                    input=messages,
                     api_key=api_key,
-                    original_response=response,
+                    original_response=stringified_response,
                     additional_args={"complete_input_dict": data},
                 )
-                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+                return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except Exception as e:
             if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
                 # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
diff --git a/litellm/main.py b/litellm/main.py
index d4b7991162..e92e5e7463 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -319,7 +319,6 @@ def completion(
     ######### unpacking kwargs #####################
     args = locals()
     api_base = kwargs.get('api_base', None)
-    return_async = kwargs.get('return_async', False)
     mock_response = kwargs.get('mock_response', None)
     force_timeout= kwargs.get('force_timeout', 600) ## deprecated
     logger_fn = kwargs.get('logger_fn', None)
@@ -351,7 +350,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -449,7 +448,6 @@ def completion(
     # For logging - save the values of the litellm-specific params passed in
     litellm_params = get_litellm_params(
         acompletion=acompletion,
-        return_async=return_async,
        api_key=api_key,
        force_timeout=force_timeout,
        logger_fn=logger_fn,
@@ -526,17 +524,18 @@ def completion(
                 client=client # pass AsyncAzureOpenAI, AzureOpenAI client
             )

-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-                additional_args={
-                    "headers": headers,
-                    "api_version": api_version,
-                    "api_base": api_base,
-                },
-            )
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
         elif (
             model in litellm.open_ai_chat_completion_models
             or custom_llm_provider == "custom_openai"
@@ -606,13 +605,14 @@ def completion(
                 )
                 raise e

-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-                additional_args={"headers": headers},
-            )
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={"headers": headers},
+                )
         elif (
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
@@ -1787,7 +1787,7 @@ def embedding(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.pop("aembedding", None)
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     optional_params = {}
diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py
index c049163441..c30368ebb4 100644
--- a/litellm/proxy/custom_callbacks.py
+++ b/litellm/proxy/custom_callbacks.py
@@ -36,9 +36,12 @@ class MyCustomHandler(CustomLogger):

     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose("On Success!")

+
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose(f"On Async Success!")
+        response_cost = litellm.completion_cost(completion_response=response_obj)
+        assert response_cost > 0.0
         return

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6bab8dc8d8..2368994cbc 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -262,41 +262,43 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")

-        if prisma_client:
-            ## check for cache hit (In-Memory Cache)
-            valid_token = user_api_key_cache.get_cache(key=api_key)
-            print(f"valid_token from cache: {valid_token}")
-            if valid_token is None:
-                ## check db
-                valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
-                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
-            elif valid_token is not None:
-                print(f"API Key Cache Hit!")
-            if valid_token:
-                litellm.model_alias_map = valid_token.aliases
-                config = valid_token.config
-                if config != {}:
-                    model_list = config.get("model_list", [])
-                    llm_model_list = model_list
-                    print("\n new llm router model list", llm_model_list)
-                if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                    api_key = valid_token.token
-                    valid_token_dict = _get_pydantic_json_dict(valid_token)
-                    valid_token_dict.pop("token", None)
-                    return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-                else:
-                    data = await request.json()
-                    model = data.get("model", None)
-                    if model in litellm.model_alias_map:
-                        model = litellm.model_alias_map[model]
-                    if model and model not in valid_token.models:
-                        raise Exception(f"Token not allowed to access model")
+        if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
+            raise Exception("No connected db.")
+
+        ## check for cache hit (In-Memory Cache)
+        valid_token = user_api_key_cache.get_cache(key=api_key)
+        print(f"valid_token from cache: {valid_token}")
+        if valid_token is None:
+            ## check db
+            valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
+            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+        elif valid_token is not None:
+            print(f"API Key Cache Hit!")
+        if valid_token:
+            litellm.model_alias_map = valid_token.aliases
+            config = valid_token.config
+            if config != {}:
+                model_list = config.get("model_list", [])
+                llm_model_list = model_list
+                print("\n new llm router model list", llm_model_list)
+            if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                 api_key = valid_token.token
                 valid_token_dict = _get_pydantic_json_dict(valid_token)
                 valid_token_dict.pop("token", None)
                 return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
             else:
-                raise Exception(f"Invalid token")
+                data = await request.json()
+                model = data.get("model", None)
+                if model in litellm.model_alias_map:
+                    model = litellm.model_alias_map[model]
+                if model and model not in valid_token.models:
+                    raise Exception(f"Token not allowed to access model")
+            api_key = valid_token.token
+            valid_token_dict = _get_pydantic_json_dict(valid_token)
+            valid_token_dict.pop("token", None)
+            return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+        else:
+            raise Exception(f"Invalid token")
     except Exception as e:
         print(f"An exception occurred - {traceback.format_exc()}")
         if isinstance(e, HTTPException):
@@ -380,25 +382,14 @@ async def track_cost_callback(
        if "complete_streaming_response" in kwargs:
            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
            completion_response=kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion=output_text
-            )
+            response_cost = litellm.completion_cost(completion_response=completion_response)
            print("streaming response_cost", response_cost)
            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
            if user_api_key and prisma_client:
                await update_prisma_database(token=user_api_key, response_cost=response_cost)
        elif kwargs["stream"] == False: # for non streaming responses
-            input_text = kwargs.get("messages", "")
-            print(f"type of input_text: {type(input_text)}")
-            if isinstance(input_text, list):
-                response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
-            elif isinstance(input_text, str):
-                response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
+            response_cost = litellm.completion_cost(completion_response=completion_response)
            print(f"received completion response: {completion_response}")
            print(f"regular response_cost: {response_cost}")

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index e972eff4df..00797691a3 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -104,7 +104,7 @@ class ProxyLogging:
        2. /embeddings
        """
        # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
+        if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
            ## decrement call count if call failed
            if (hasattr(original_exception, "status_code")
                and original_exception.status_code == 429
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
new file mode 100644
index 0000000000..428d2e4d88
--- /dev/null
+++ b/litellm/tests/test_custom_callback_input.py
@@ -0,0 +1,386 @@
+### What this tests ####
+## This test asserts the type of data passed into each method of the custom callback handler
+import sys, os, time, inspect, asyncio, traceback
+from datetime import datetime
+import pytest
+sys.path.insert(0, os.path.abspath('../..'))
+from typing import Optional
+from litellm import completion, embedding
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+
+# Test Scenarios (test across completion, streaming, embedding)
+## 1: Pre-API-Call
+## 2: Post-API-Call
+## 3: On LiteLLM Call success
+## 4: On LiteLLM Call failure
+
+# Test models
+## 1. OpenAI
+## 2. Azure OpenAI
+## 3. Non-OpenAI/Azure - e.g. Bedrock
+
+# Test interfaces
+## 1. litellm.completion() + litellm.embeddings()
+## 2. router.completion() + router.embeddings()
+## 3. proxy.completions + proxy.embeddings
+
+class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    """
+    The set of expected inputs to a custom handler for a
+    """
+    # Class variables or attributes
+    def __init__(self):
+        self.errors = []
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        try:
+            ## MODEL
+            assert isinstance(model, str)
+            ## MESSAGES
+            assert isinstance(messages, list) and isinstance(messages[0], dict)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+        except Exception as e:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert end_time == None
+            ## RESPONSE OBJECT
+            assert response_obj == None
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert isinstance(response_obj, litellm.ModelResponse)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert isinstance(response_obj, litellm.ModelResponse)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert response_obj == None
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_pre_api_call(self, model, messages, kwargs):
+        try:
+            ## MODEL
+            assert isinstance(model, str)
+            ## MESSAGES
+            assert isinstance(messages, list) and isinstance(messages[0], dict)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+        except Exception as e:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert isinstance(response_obj, litellm.ModelResponse)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], str) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert response_obj == None
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+## Test OpenAI + sync
+def test_chat_openai_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = litellm.completion(model="gpt-3.5-turbo",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm sync openai"
+                              }])
+        ## test streaming
+        response = litellm.completion(model="gpt-3.5-turbo",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm openai"
+                              }],
+                              stream=True)
+        for chunk in response:
+            continue
+        ## test failure callback
+        try:
+            response = litellm.completion(model="gpt-3.5-turbo",
+                                messages=[{
+                                    "role": "user",
+                                    "content": "Hi 👋 - i'm openai"
+                                }],
+                                api_key="my-bad-key",
+                                stream=True)
+            for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# test_chat_openai_stream()
+
+## Test OpenAI + Async
+@pytest.mark.asyncio
+async def test_async_chat_openai_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(model="gpt-3.5-turbo",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm openai"
+                              }])
+        ## test streaming
+        response = await litellm.acompletion(model="gpt-3.5-turbo",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm openai"
+                              }],
+                              stream=True)
+        async for chunk in response:
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(model="gpt-3.5-turbo",
+                                messages=[{
+                                    "role": "user",
+                                    "content": "Hi 👋 - i'm openai"
+                                }],
+                                api_key="my-bad-key",
+                                stream=True)
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# asyncio.run(test_async_chat_openai_stream())
+
+## Test Azure + sync
+def test_chat_azure_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = litellm.completion(model="azure/chatgpt-v-2",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm sync azure"
+                              }])
+        # test streaming
+        response = litellm.completion(model="azure/chatgpt-v-2",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm sync azure"
+                              }],
+                              stream=True)
+        for chunk in response:
+            continue
+        # test failure callback
+        try:
+            response = litellm.completion(model="azure/chatgpt-v-2",
+                                messages=[{
+                                    "role": "user",
+                                    "content": "Hi 👋 - i'm sync azure"
+                                }],
+                                api_key="my-bad-key",
+                                stream=True)
+            for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# test_chat_azure_stream()
+
+## Test OpenAI + Async
+@pytest.mark.asyncio
+async def test_async_chat_azure_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(model="azure/chatgpt-v-2",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm async azure"
+                              }])
+        ## test streaming
+        response = await litellm.acompletion(model="azure/chatgpt-v-2",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "Hi 👋 - i'm async azure"
+                              }],
+                              stream=True)
+        async for chunk in response:
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                messages=[{
+                                    "role": "user",
+                                    "content": "Hi 👋 - i'm async azure"
+                                }],
+                                api_key="my-bad-key",
+                                stream=True)
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# asyncio.run(test_async_chat_azure_stream())
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 5494b00c7a..0c48c83a87 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -801,9 +801,6 @@ class Logging:
             end_time = datetime.datetime.now()
         self.model_call_details["log_event_type"] = "successful_api_call"
         self.model_call_details["end_time"] = end_time
-
-        if isinstance(result, OpenAIObject):
-            result = result.model_dump()

         if litellm.max_budget and self.stream:
             time_diff = (end_time - start_time).total_seconds()
@@ -857,9 +854,6 @@ class Logging:
                            call_type = self.call_type,
                            stream = self.stream,
                        )
-                    if callback == "api_manager":
-                        print_verbose("reaches api manager for updating model cost")
-                        litellm.apiManager.update_cost(completion_obj=result, user=self.user)
                    if callback == "promptlayer":
                        print_verbose("reaches promptlayer for logging!")
                        promptLayerLogger.log_event(
@@ -994,7 +988,7 @@ class Logging:
                            end_time=end_time,
                            print_verbose=print_verbose,
                        )
-                    if isinstance(callback, CustomLogger): # custom logger class
+                    if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
                        print_verbose(f"success callbacks: Running Custom Logger Class")
                        if self.stream and complete_streaming_response is None:
                            callback.log_stream_event(
@@ -1044,7 +1038,6 @@ class Logging:
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
-
        ## BUILD COMPLETE STREAMED RESPONSE
        complete_streaming_response = None
        if self.stream:
@@ -1081,6 +1074,13 @@ class Logging:
                                start_time=start_time,
                                end_time=end_time,
                            )
+                        else:
+                            await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
+                                kwargs=self.model_call_details,
+                                response_obj=result,
+                                start_time=start_time,
+                                end_time=end_time
+                            )
                    else:
                        await callback.async_log_success_event(
                            kwargs=self.model_call_details,
@@ -1103,24 +1103,29 @@ class Logging:
                    f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
                )

+    def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
+        if start_time is None:
+            start_time = self.start_time
+        if end_time is None:
+            end_time = datetime.datetime.now()
+
+        # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
+        if not hasattr(self, "model_call_details"):
+            self.model_call_details = {}
+
+        self.model_call_details["log_event_type"] = "failed_api_call"
+        self.model_call_details["exception"] = exception
+        self.model_call_details["traceback_exception"] = traceback_exception
+        self.model_call_details["end_time"] = end_time
+        self.model_call_details.setdefault("original_response", None)
+        return start_time, end_time
+
    def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
        print_verbose(
                f"Logging Details LiteLLM-Failure Call"
            )
        try:
-            if start_time is None:
-                start_time = self.start_time
-            if end_time is None:
-                end_time = datetime.datetime.now()
-
-            # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-            if not hasattr(self, "model_call_details"):
-                self.model_call_details = {}
-
-            self.model_call_details["log_event_type"] = "failed_api_call"
-            self.model_call_details["exception"] = exception
-            self.model_call_details["traceback_exception"] = traceback_exception
-            self.model_call_details["end_time"] = end_time
+            start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
            result = None # result sent to all loggers, init this to None incase it's not created
            for callback in litellm.failure_callback:
                try:
@@ -1212,16 +1217,8 @@ class Logging:
        """
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
-        # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-        if not hasattr(self, "model_call_details"):
-            self.model_call_details = {}
-
-        self.model_call_details["log_event_type"] = "failed_api_call"
-        self.model_call_details["exception"] = exception
-        self.model_call_details["traceback_exception"] = traceback_exception
-        self.model_call_details["end_time"] = end_time
-        result = {} # result sent to all loggers, init this to None incase it's not created
-
+        start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
+        result = None # result sent to all loggers, init this to None incase it's not created
        for callback in litellm._async_failure_callback:
            try:
                if isinstance(callback, CustomLogger): # custom logger class
@@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
    return model_cost

def get_litellm_params(
-    return_async=False,
    api_key=None,
    force_timeout=600,
    azure=False,
@@ -2082,7 +2078,6 @@
):
    litellm_params = {
        "acompletion": acompletion,
-        "return_async": return_async,
        "api_key": api_key,
        "force_timeout": force_timeout,
        "logger_fn": logger_fn,
@@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
        self.holding_chunk = ""
        self.complete_response = ""
-        if self.logging_obj:
-            # Log the type of the received item
-            self.logging_obj.post_call(str(type(completion_stream)))

    def __iter__(self):
        return self
@@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
        except Exception as e:
            raise e

-    def logging(self, text):
-        if self.logging_obj:
-            self.logging_obj.post_call(text)
-
    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
        hold = False
        if finish_reason:
@@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
                        completion_obj["role"] = "assistant"
                        self.sent_first_chunk = True
                    model_response.choices[0].delta = Delta(**completion_obj)
-                    # LOGGING
-                    threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                    print_verbose(f"model_response: {model_response}")
                    return model_response
                else:
                    return
            elif model_response.choices[0].finish_reason:
                model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                return model_response
            elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
                # enter this branch when no content has been passed in response
@@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
                if self.sent_first_chunk == False:
                    model_response.choices[0].delta["role"] = "assistant"
                    self.sent_first_chunk = True
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
                return model_response
            else:
                return
@@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
        except Exception as e:
            traceback_exception = traceback.format_exc()
            e.message = str(e)
-            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
            raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)

    ## needs to handle the empty string case (even starting chunk can be an empty string)
@@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
                chunk = next(self.completion_stream)
                if chunk is not None and chunk != b'':
                    response = self.chunk_creator(chunk=chunk)
-                    if response is not None:
-                        return response
+                    if response is None:
+                        continue
+                    ## LOGGING
+                    threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
+                    return response
        except StopIteration:
            raise # Re-raise StopIteration
        except Exception as e:
-            # Handle other exceptions if needed
+            traceback_exception = traceback.format_exc()
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
            raise e


@@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
                asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
                return processed_chunk
        except Exception as e:
+            traceback_exception = traceback.format_exc()
            # Handle any exceptions that might occur during streaming
+            asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
            raise StopAsyncIteration

class TextCompletionStreamWrapper:
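
For context, the minimal sketch below (not part of the patch) shows how a custom callback handler of the kind exercised by the new litellm/tests/test_custom_callback_input.py test can be registered. The class name InputCheckingHandler and the prompt are illustrative assumptions; CustomLogger, litellm.callbacks, the callback method signatures, and litellm.completion_cost(completion_response=...) all come from the diff above.

# Illustrative sketch (assumptions noted above), not part of the patch.
import litellm
from litellm.integrations.custom_logger import CustomLogger


class InputCheckingHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # sync success callback: per the new test, kwargs["original_response"] is the
        # stringified provider response (or a CustomStreamWrapper for streaming calls)
        assert isinstance(kwargs["original_response"], (str, litellm.CustomStreamWrapper))
        print("sync cost:", litellm.completion_cost(completion_response=response_obj))

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # async success callback: fired for acompletion()/async streaming calls
        print("async cost:", litellm.completion_cost(completion_response=response_obj))


if __name__ == "__main__":
    litellm.callbacks = [InputCheckingHandler()]
    # requires a valid OPENAI_API_KEY in the environment
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
    )
    print(response)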