diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 316e48aed..3bd0f40e6 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -63,6 +63,22 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
     ):
         pass
 
+    async def async_post_call_streaming_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        """
+        Returns streaming chunks before they're returned to the user
+        """
+        pass
+
+    async def async_post_call_success_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        """
+        Returns the LLM response before it's returned to the user
+        """
+        pass
+
     #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
 
     def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
diff --git a/litellm/tests/test_rules.py b/litellm/tests/test_rules.py
index 0c1b573b8..7e7a984a9 100644
--- a/litellm/tests/test_rules.py
+++ b/litellm/tests/test_rules.py
@@ -58,6 +58,18 @@ def my_post_call_rule(input: str):
     return {"decision": True}
 
 
+def my_post_call_rule_2(input: str):
+    input = input.lower()
+    print(f"input: {input}")
+    print(f"INSIDE MY POST CALL RULE, len(input) - {len(input)}")
+    if len(input) < 200 and len(input) > 0:
+        return {
+            "decision": False,
+            "message": "This violates LiteLLM Proxy Rules. Response too short",
+        }
+    return {"decision": True}
+
+
 # test_pre_call_rule()
 # Test 2: Post-call rule
 # commenting out of ci/cd since llm's have variable output which was causing our pipeline to fail erratically.
@@ -94,3 +106,24 @@ def test_post_call_rule():
 
 
 # test_post_call_rule()
+
+
+def test_post_call_rule_streaming():
+    try:
+        litellm.pre_call_rules = []
+        litellm.post_call_rules = [my_post_call_rule_2]
+        ### completion
+        response = completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "say sorry"}],
+            max_tokens=2,
+            stream=True,
+        )
+        for chunk in response:
+            print(f"chunk: {chunk}")
+        pytest.fail("Completion call should have failed.")
+    except Exception as e:
+        print("Got exception", e)
+        print(type(e))
+        print(vars(e))
+        assert e.message == "This violates LiteLLM Proxy Rules. Response too short"
diff --git a/litellm/utils.py b/litellm/utils.py
index ac26e0864..916f0c793 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7692,6 +7692,7 @@ class CustomStreamWrapper:
         self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
         self.holding_chunk = ""
         self.complete_response = ""
+        self.response_uptil_now = ""
         _model_info = (
             self.logging_obj.model_call_details.get("litellm_params", {}).get(
                 "model_info", {}
@@ -7703,6 +7704,7 @@ class CustomStreamWrapper:
         }  # returned as x-litellm-model-id response header in proxy
         self.response_id = None
         self.logging_loop = None
+        self.rules = Rules()
 
     def __iter__(self):
         return self
@@ -8659,7 +8661,7 @@ class CustomStreamWrapper:
                     chunk = next(self.completion_stream)
                 if chunk is not None and chunk != b"":
                     print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
-                    response = self.chunk_creator(chunk=chunk)
+                    response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
                     print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
                     if response is None:
                         continue
@@ -8667,7 +8669,12 @@ class CustomStreamWrapper:
                     threading.Thread(
                         target=self.run_success_logging_in_thread, args=(response,)
                     ).start()  # log response
-
+                    self.response_uptil_now += (
+                        response.choices[0].delta.get("content", "") or ""
+                    )
+                    self.rules.post_call_rules(
+                        input=self.response_uptil_now, model=self.model
+                    )
                     # RETURN RESULT
                     return response
         except StopIteration:
@@ -8703,7 +8710,9 @@ class CustomStreamWrapper:
                    # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
                    # __anext__ also calls async_success_handler, which does logging
                    print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
-                    processed_chunk = self.chunk_creator(chunk=chunk)
+                    processed_chunk: Optional[ModelResponse] = self.chunk_creator(
+                        chunk=chunk
+                    )
                    print_verbose(
                        f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}"
                    )
@@ -8718,6 +8727,12 @@ class CustomStreamWrapper:
                            processed_chunk,
                        )
                    )
+                    self.response_uptil_now += (
+                        processed_chunk.choices[0].delta.get("content", "") or ""
+                    )
+                    self.rules.post_call_rules(
+                        input=self.response_uptil_now, model=self.model
+                    )
                    return processed_chunk
                raise StopAsyncIteration
            else:  # temporary patch for non-aiohttp async calls
@@ -8731,7 +8746,9 @@ class CustomStreamWrapper:
                        chunk = next(self.completion_stream)
                    if chunk is not None and chunk != b"":
                        print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
-                        processed_chunk = self.chunk_creator(chunk=chunk)
+                        processed_chunk: Optional[ModelResponse] = self.chunk_creator(
+                            chunk=chunk
+                        )
                        print_verbose(
                            f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
                        )
@@ -8748,6 +8765,12 @@ class CustomStreamWrapper:
                            )
                        )
 
+                        self.response_uptil_now += (
+                            processed_chunk.choices[0].delta.get("content", "") or ""
+                        )
+                        self.rules.post_call_rules(
+                            input=self.response_uptil_now, model=self.model
+                        )
                        # RETURN RESULT
                        return processed_chunk
            except StopAsyncIteration:
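
Usage sketch (not part of the patch): a minimal example of how the streaming post-call check added above surfaces to a caller, mirroring test_post_call_rule_streaming. CustomStreamWrapper now feeds each registered rule the response text accumulated so far (response_uptil_now) on every chunk, so a {"decision": False, ...} result aborts the stream with an exception carrying the rule's message. The rule name and length threshold below are illustrative, and the example assumes a configured OpenAI API key.

import litellm
from litellm import completion


def block_short_streams(input: str):
    # Receives the text streamed so far, once per chunk (illustrative threshold).
    if 0 < len(input) < 200:
        return {
            "decision": False,
            "message": "This violates LiteLLM Proxy Rules. Response too short",
        }
    return {"decision": True}


litellm.post_call_rules = [block_short_streams]

try:
    for chunk in completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say sorry"}],
        max_tokens=2,
        stream=True,
    ):
        print(chunk)
except Exception as e:
    # The rule violation is raised mid-stream; the test asserts on e.message.
    print("stream rejected:", getattr(e, "message", str(e)))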