make max budget work for openai streaming

This commit is contained in:
Krrish Dholakia 2023-09-14 16:22:49 -07:00
parent 519f29a4b8
commit f7e92bb0db
6 changed files with 38 additions and 16 deletions

View file

@ -180,6 +180,10 @@ class Logging:
# Log the exact input to the LLM API
print_verbose(f"Logging Details Pre-API Call for call id {self.litellm_call_id}")
try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
# print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
self.model_call_details["input"] = input
self.model_call_details["api_key"] = api_key
@ -202,6 +206,11 @@ class Logging:
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
)
if litellm.max_budget and self.stream:
time_diff = (end_time - start_time).total_seconds()
float_diff = float(time_diff)
litellm._current_cost += litellm.completion_cost(model=self.model, prompt="".join(message["content"] for message in self.messages), completion="", total_time=float_diff)
# Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
for callback in litellm.input_callback:
try:
@ -314,6 +323,12 @@ class Logging:
if end_time is None:
end_time = datetime.datetime.now()
print_verbose(f"success callbacks: {litellm.success_callback}")
if litellm.max_budget and self.stream:
time_diff = (end_time - start_time).total_seconds()
float_diff = float(time_diff)
litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)
for callback in litellm.success_callback:
try:
if callback == "lite_debugger":
@ -574,10 +589,6 @@ def client(original_function):
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# [OPTIONAL] UPDATE BUDGET
if litellm.max_budget:
litellm._current_cost += litellm.completion_cost(completion_response=result)
# [OPTIONAL] Return LiteLLM call_id
if litellm.use_client == True:
result['litellm_call_id'] = litellm_call_id
@ -2383,7 +2394,6 @@ class CustomStreamWrapper:
def handle_cohere_chunk(self, chunk):
chunk = chunk.decode("utf-8")
print(f"cohere chunk: {chunk}")
data_json = json.loads(chunk)
try:
print(f"data json: {data_json}")
@ -2474,7 +2484,8 @@ class CustomStreamWrapper:
completion_obj["content"] = self.handle_cohere_chunk(chunk)
else: # openai chat/azure models
chunk = next(self.completion_stream)
return chunk # open ai returns finish_reason, we should just return the openai chunk
completion_obj["content"] = chunk["choices"][0]["delta"]["content"]
# return chunk # open ai returns finish_reason, we should just return the openai chunk
#completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
# LOGGING