forked from phoenix/litellm-mirror

make max budget work for openai streaming

parent 519f29a4b8
commit f7e92bb0db

6 changed files with 38 additions and 16 deletions
|
@ -269,6 +269,7 @@ from .exceptions import (
|
||||||
ServiceUnavailableError,
|
ServiceUnavailableError,
|
||||||
OpenAIError,
|
OpenAIError,
|
||||||
ContextWindowExceededError,
|
ContextWindowExceededError,
|
||||||
|
BudgetExceededError
|
||||||
|
|
||||||
)
|
)
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,6 +1,6 @@
-#### What this tests ####
-# This tests calling litellm.max_budget by making back-to-back gpt-4 calls
-# commenting out this test for circle ci, as it causes other tests to fail, since litellm.max_budget would impact other litellm imports
+# #### What this tests ####
+# # This tests calling litellm.max_budget by making back-to-back gpt-4 calls
+# # commenting out this test for circle ci, as it causes other tests to fail, since litellm.max_budget would impact other litellm imports
 # import sys, os, json
 # import traceback
 # import pytest
@@ -9,13 +9,23 @@
 # 0, os.path.abspath("../..")
 # ) # Adds the parent directory to the system path
 # import litellm
-# litellm.set_verbose = True
-# from litellm import completion
+# # litellm.set_verbose = True
+# from litellm import completion, BudgetExceededError

+# def test_max_budget():
+# try:
 # litellm.max_budget = 0.001 # sets a max budget of $0.001

 # messages = [{"role": "user", "content": "Hey, how's it going"}]
-# completion(model="gpt-4", messages=messages)
-# completion(model="gpt-4", messages=messages)
+# response = completion(model="gpt-4", messages=messages, stream=True)
+# for chunk in response:
+# continue
 # print(litellm._current_cost)
+# completion(model="gpt-4", messages=messages, stream=True)
+# litellm.max_budget = float('inf')
+# except BudgetExceededError as e:
+# pass
+# except Exception as e:
+# pytest.fail(f"An error occured: {str(e)}")
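For reference, the commented-out test above corresponds roughly to the following standalone sketch (not part of the commit; it assumes pytest is installed, OPENAI_API_KEY is set, and a real gpt-4 call is acceptable):

# Illustrative, runnable version of the commented-out test above.
import pytest
import litellm
from litellm import completion, BudgetExceededError

def test_max_budget():
    try:
        litellm.max_budget = 0.001  # sets a max budget of $0.001
        messages = [{"role": "user", "content": "Hey, how's it going"}]
        response = completion(model="gpt-4", messages=messages, stream=True)
        for chunk in response:
            continue
        print(litellm._current_cost)
        # a second streamed call should now push the accrued cost past the budget
        completion(model="gpt-4", messages=messages, stream=True)
        litellm.max_budget = float("inf")  # reset so other tests are not affected
    except BudgetExceededError:
        pass  # expected once the budget is exhausted
    except Exception as e:
        pytest.fail(f"An error occurred: {str(e)}")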
@@ -180,6 +180,10 @@ class Logging:
         # Log the exact input to the LLM API
         print_verbose(f"Logging Details Pre-API Call for call id {self.litellm_call_id}")
         try:
+            if start_time is None:
+                start_time = self.start_time
+            if end_time is None:
+                end_time = datetime.datetime.now()
             # print_verbose(f"logging pre call for model: {self.model} with call type: {self.call_type}")
             self.model_call_details["input"] = input
             self.model_call_details["api_key"] = api_key
@@ -202,6 +206,11 @@ class Logging:
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
                 )

+            if litellm.max_budget and self.stream:
+                time_diff = (end_time - start_time).total_seconds()
+                float_diff = float(time_diff)
+                litellm._current_cost += litellm.completion_cost(model=self.model, prompt="".join(message["content"] for message in self.messages), completion="", total_time=float_diff)
+
             # Input Integration Logging -> If you want to log the fact that an attempt to call the model was made
             for callback in litellm.input_callback:
                 try:
@@ -314,6 +323,12 @@ class Logging:
             if end_time is None:
                 end_time = datetime.datetime.now()
             print_verbose(f"success callbacks: {litellm.success_callback}")
+
+            if litellm.max_budget and self.stream:
+                time_diff = (end_time - start_time).total_seconds()
+                float_diff = float(time_diff)
+                litellm._current_cost += litellm.completion_cost(model=self.model, prompt="", completion=result["content"], total_time=float_diff)
+
             for callback in litellm.success_callback:
                 try:
                     if callback == "lite_debugger":
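Taken together, the two Logging hooks above split the cost of a streamed request in two: the pre-call hook prices the prompt (completion=""), and the success handler prices the streamed completion text (prompt=""), both accumulating into litellm._current_cost. The following is a rough, illustrative sketch of that accounting outside the Logging class, using the same completion_cost keyword arguments as the diff; the BudgetExceededError constructor arguments and the enforcement point are assumptions for illustration, not what this commit implements:

# Illustrative sketch of the two-sided cost accounting for a streamed call.
import litellm
from litellm import BudgetExceededError

def charge_streamed_call(model, messages, streamed_text, start_time, end_time):
    elapsed = float((end_time - start_time).total_seconds())
    # prompt side (mirrors the pre-call hook): completion="" so only the prompt is priced
    litellm._current_cost += litellm.completion_cost(
        model=model,
        prompt="".join(message["content"] for message in messages),
        completion="",
        total_time=elapsed,
    )
    # completion side (mirrors the success handler): prompt="" so only the streamed text is priced
    litellm._current_cost += litellm.completion_cost(
        model=model,
        prompt="",
        completion=streamed_text,
        total_time=elapsed,
    )
    if litellm.max_budget and litellm._current_cost > litellm.max_budget:
        # assumed signature: BudgetExceededError(current_cost, max_budget)
        raise BudgetExceededError(litellm._current_cost, litellm.max_budget)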
@@ -574,10 +589,6 @@ def client(original_function):
             if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                 litellm.cache.add_cache(result, *args, **kwargs)

-            # [OPTIONAL] UPDATE BUDGET
-            if litellm.max_budget:
-                litellm._current_cost += litellm.completion_cost(completion_response=result)
-
             # [OPTIONAL] Return LiteLLM call_id
             if litellm.use_client == True:
                 result['litellm_call_id'] = litellm_call_id
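The block removed above priced a finished, non-streaming response in one shot via completion_cost(completion_response=...). For comparison with the streaming path, that style of accounting looks roughly like this (illustrative only; the budget value is made up and the check mirrors the deleted lines):

# Illustrative: one-shot cost accounting for a non-streaming response.
import litellm
from litellm import completion

litellm.max_budget = 0.10  # example budget in USD, chosen for illustration
response = completion(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hey, how's it going"}],
)
if litellm.max_budget:
    litellm._current_cost += litellm.completion_cost(completion_response=response)
print(litellm._current_cost)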
@@ -2383,7 +2394,6 @@ class CustomStreamWrapper:

     def handle_cohere_chunk(self, chunk):
         chunk = chunk.decode("utf-8")
-        print(f"cohere chunk: {chunk}")
         data_json = json.loads(chunk)
         try:
             print(f"data json: {data_json}")
@@ -2474,7 +2484,8 @@ class CustomStreamWrapper:
                 completion_obj["content"] = self.handle_cohere_chunk(chunk)
             else: # openai chat/azure models
                 chunk = next(self.completion_stream)
-                return chunk # open ai returns finish_reason, we should just return the openai chunk
+                completion_obj["content"] = chunk["choices"][0]["delta"]["content"]
+                # return chunk # open ai returns finish_reason, we should just return the openai chunk

                 #completion_obj["content"] = self.handle_openai_chat_completion_chunk(chunk)
             # LOGGING
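The last hunk makes the OpenAI chat/Azure branch feed the chunk's delta text through completion_obj like the other providers, instead of returning the raw chunk. For reference, pulling that text out of an OpenAI streaming chunk looks roughly like this (a sketch; the final chunk's delta usually carries only finish_reason and no "content" key, hence the .get guards):

# Illustrative: extract the incremental text from an OpenAI streaming chunk,
# matching the chunk["choices"][0]["delta"]["content"] access used above.
def extract_delta_content(chunk):
    delta = chunk["choices"][0].get("delta", {})
    return delta.get("content", "") or ""

# usage sketch:
# for chunk in openai_stream:
#     text = extract_delta_content(chunk)
#     if text:
#         print(text, end="")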