feat(utils.py): enable post call rules for streaming

This commit is contained in:
Krrish Dholakia 2024-02-12 22:08:04 -08:00
parent f3a950705c
commit 7600c8f41d
3 changed files with 76 additions and 4 deletions

View file

@@ -7692,6 +7692,7 @@ class CustomStreamWrapper:
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = ""
self.complete_response = ""
self.response_uptil_now = ""
_model_info = (
self.logging_obj.model_call_details.get("litellm_params", {}).get(
"model_info", {}
@@ -7703,6 +7704,7 @@ class CustomStreamWrapper:
} # returned as x-litellm-model-id response header in proxy
self.response_id = None
self.logging_loop = None
self.rules = Rules()
def __iter__(self):
return self
@@ -8659,7 +8661,7 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
response = self.chunk_creator(chunk=chunk)
response: Optional[ModelResponse] = self.chunk_creator(chunk=chunk)
print_verbose(f"PROCESSED CHUNK POST CHUNK CREATOR: {response}")
if response is None:
continue
@@ -8667,7 +8669,12 @@ class CustomStreamWrapper:
threading.Thread(
target=self.run_success_logging_in_thread, args=(response,)
).start() # log response
self.response_uptil_now += (
response.choices[0].delta.get("content", "") or ""
)
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
# RETURN RESULT
return response
except StopIteration:
@@ -8703,7 +8710,9 @@
# chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
# __anext__ also calls async_success_handler, which does logging
print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk = self.chunk_creator(chunk=chunk)
processed_chunk: Optional[ModelResponse] = self.chunk_creator(
chunk=chunk
)
print_verbose(
f"PROCESSED ASYNC CHUNK POST CHUNK CREATOR: {processed_chunk}"
)
@@ -8718,6 +8727,12 @@
processed_chunk,
)
)
self.response_uptil_now += (
processed_chunk.choices[0].delta.get("content", "") or ""
)
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
return processed_chunk
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
@@ -8731,7 +8746,9 @@
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk = self.chunk_creator(chunk=chunk)
processed_chunk: Optional[ModelResponse] = self.chunk_creator(
chunk=chunk
)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)
@@ -8748,6 +8765,12 @@
)
)
self.response_uptil_now += processed_chunk.choices[0].delta.get(
"content", ""
)
self.rules.post_call_rules(
input=self.response_uptil_now, model=self.model
)
# RETURN RESULT
return processed_chunk
except StopAsyncIteration: