Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 19:24:27 +00:00

feat(utils.py): enable post call rules for streaming

parent 1b39362e49
commit fd6f64a4ae

3 changed files with 76 additions and 4 deletions
@@ -63,6 +63,22 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
     ):
         pass
 
+    async def async_post_call_streaming_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        """
+        Returns streaming chunks before they're returned to the user
+        """
+        pass
+
+    async def async_post_call_success_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        """
+        Returns the llm response before it's returned to the user
+        """
+        pass
+
     #### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
 
     def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
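A downstream handler would subclass CustomLogger and override these new stubs. A minimal sketch, assuming the usual litellm.integrations.custom_logger and litellm.proxy._types import paths (not shown in this diff); the handler class and its body are illustrative:

# Sketch only, not part of this commit; import paths are assumptions.
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

class MyStreamingHandler(CustomLogger):
    async def async_post_call_streaming_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        # Inspect (or veto) streamed output before it reaches the user.
        pass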
@@ -58,6 +58,18 @@ def my_post_call_rule(input: str):
     return {"decision": True}
 
 
+def my_post_call_rule_2(input: str):
+    input = input.lower()
+    print(f"input: {input}")
+    print(f"INSIDE MY POST CALL RULE, len(input) - {len(input)}")
+    if len(input) < 200 and len(input) > 0:
+        return {
+            "decision": False,
+            "message": "This violates LiteLLM Proxy Rules. Response too short",
+        }
+    return {"decision": True}
+
+
 # test_pre_call_rule()
 # Test 2: Post-call rule
 # commented out of ci/cd, since LLMs have variable output, which was causing our pipeline to fail erratically.
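A post-call rule is just a callable registered on litellm.post_call_rules: returning {"decision": False, "message": ...} rejects the response, while {"decision": True} lets it through. A hedged sketch of wiring the rule above into a non-streaming call:

# Sketch: register the rule, then call as usual. If the rule trips,
# litellm surfaces the rule's "message" as the raised exception's message.
import litellm
from litellm import completion

litellm.post_call_rules = [my_post_call_rule_2]
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "say sorry"}],
)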
@@ -94,3 +106,24 @@ def test_post_call_rule():
 
 
 # test_post_call_rule()
+
+
+def test_post_call_rule_streaming():
+    try:
+        litellm.pre_call_rules = []
+        litellm.post_call_rules = [my_post_call_rule_2]
+        ### completion
+        response = completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "say sorry"}],
+            max_tokens=2,
+            stream=True,
+        )
+        for chunk in response:
+            print(f"chunk: {chunk}")
+        pytest.fail("Completion call should have failed.")
+    except Exception as e:
+        print("Got exception", e)
+        print(type(e))
+        print(vars(e))
+        assert e.message == "This violates LiteLLM Proxy Rules. Response too short"
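To exercise just this test locally (a sketch; the test file name and path are assumptions, since the diff doesn't show them):

pytest -s litellm/tests/test_rules.py::test_post_call_rule_streaming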
@@ -7692,6 +7692,7 @@ class CustomStreamWrapper:
         self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
         self.holding_chunk = ""
         self.complete_response = ""
+        self.response_uptil_now = ""
         _model_info = (
             self.logging_obj.model_call_details.get("litellm_params", {}).get(
                 "model_info", {}
@@ -7703,6 +7704,7 @@ class CustomStreamWrapper:
         } # returned as x-litellm-model-id response header in proxy
         self.response_id = None
         self.logging_loop = None
+        self.rules = Rules()
 
     def __iter__(self):
         return self
@@ -8659,7 +8661,7 @@ class CustomStreamWrapper:
                 chunk = next(self.completion_stream)
                 if chunk is not None and chunk != b"":
                     print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
-                    response = self.chunk_creator(chunk=chunk)
+                    response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
                     print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
                     if response is None:
                         continue
@@ -8667,7 +8669,12 @@ class CustomStreamWrapper:
                     threading.Thread(
                         target=self.run_success_logging_in_thread, args=(response,)
                     ).start() # log response
+                    self.response_uptil_now += (
+                        response.choices[0].delta.get("content", "") or ""
+                    )
+                    self.rules.post_call_rules(
+                        input=self.response_uptil_now, model=self.model
+                    )
                     # RETURN RESULT
                     return response
             except StopIteration:
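Rules.post_call_rules itself isn't shown in this diff; based on the decision/message contract the test rules use, its behavior is presumably along these lines. A hedged sketch, with everything except litellm.post_call_rules assumed:

# Sketch of the enforcement step, not the actual Rules class.
import litellm

def post_call_rules_sketch(input: str, model: str):
    for rule in litellm.post_call_rules:
        decision = rule(input)
        if isinstance(decision, dict) and decision.get("decision") is False:
            # Raising here aborts the stream mid-iteration, which is what
            # the new per-chunk call sites rely on.
            raise Exception(decision.get("message", "LiteLLM Proxy rule violated"))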
@@ -8703,7 +8710,9 @@ class CustomStreamWrapper:
                 # chunk_creator() does logging/stream chunk building. We need to let it know it's being called in_async_func, so we don't double-add chunks.
                 # __anext__ also calls async_success_handler, which does logging
                 print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
-                processed_chunk = self.chunk_creator(chunk=chunk)
+                processed_chunk: Optional[ModelResponse] = self.chunk_creator(
+                    chunk=chunk
+                )
                 print_verbose(
                     f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}"
                 )
@@ -8718,6 +8727,12 @@ class CustomStreamWrapper:
                             processed_chunk,
                         )
                     )
+                self.response_uptil_now += (
+                    processed_chunk.choices[0].delta.get("content", "") or ""
+                )
+                self.rules.post_call_rules(
+                    input=self.response_uptil_now, model=self.model
+                )
                 return processed_chunk
             raise StopAsyncIteration
         else: # temporary patch for non-aiohttp async calls
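The async iterator gets the same per-chunk enforcement, so rules now fire on async streams too. A hedged usage sketch against litellm's existing async entry point, acompletion:

# Sketch: consume an async stream with the post-call rule registered.
import asyncio
import litellm

async def main():
    litellm.post_call_rules = [my_post_call_rule_2]
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "say sorry"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())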
@@ -8731,7 +8746,9 @@ class CustomStreamWrapper:
                 chunk = next(self.completion_stream)
                 if chunk is not None and chunk != b"":
                     print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
-                    processed_chunk = self.chunk_creator(chunk=chunk)
+                    processed_chunk: Optional[ModelResponse] = self.chunk_creator(
+                        chunk=chunk
+                    )
                     print_verbose(
                         f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
                     )
@@ -8748,6 +8765,12 @@ class CustomStreamWrapper:
                         )
                     )
 
+                    self.response_uptil_now += processed_chunk.choices[0].delta.get(
+                        "content", ""
+                    )
+                    self.rules.post_call_rules(
+                        input=self.response_uptil_now, model=self.model
+                    )
                     # RETURN RESULT
                     return processed_chunk
             except StopAsyncIteration:
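Note the semantics shared by all three call sites: response_uptil_now accumulates every delta, so each rule invocation sees the full response so far rather than just the latest chunk. A toy illustration (values invented for the example):

response_uptil_now = ""
for delta in ["Hel", "lo", "!"]:
    response_uptil_now += delta
    # rules run against "Hel", then "Hello", then "Hello!"

This is exactly what test_post_call_rule_streaming relies on: with max_tokens=2 the accumulated text stays under 200 characters, so the length rule trips on the first chunks and aborts the stream.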