Merge pull request #1618 from BerriAI/litellm_sagemaker_cost_tracking_fixes

fix(utils.py): fix sagemaker cost tracking for streaming
Krish Dholakia 2024-01-25 19:01:57 -08:00 committed by GitHub
commit 612f74a426
11 changed files with 282 additions and 79 deletions

View file

@@ -147,6 +147,9 @@ jobs:
           -e AZURE_API_KEY=$AZURE_API_KEY \
           -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
           -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
+          -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
+          -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
+          -e AWS_REGION_NAME=$AWS_REGION_NAME \
           --name my-app \
           -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
           my-app:latest \

View file

@@ -34,22 +34,35 @@ class TokenIterator:
         self.byte_iterator = iter(stream)
         self.buffer = io.BytesIO()
         self.read_pos = 0
+        self.end_of_data = False
 
     def __iter__(self):
         return self
 
     def __next__(self):
-        while True:
-            self.buffer.seek(self.read_pos)
-            line = self.buffer.readline()
-            if line and line[-1] == ord("\n"):
-                self.read_pos += len(line) + 1
-                full_line = line[:-1].decode("utf-8")
-                line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
-                return line_data["token"]["text"]
-            chunk = next(self.byte_iterator)
-            self.buffer.seek(0, io.SEEK_END)
-            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        try:
+            while True:
+                self.buffer.seek(self.read_pos)
+                line = self.buffer.readline()
+                if line and line[-1] == ord("\n"):
+                    response_obj = {"text": "", "is_finished": False}
+                    self.read_pos += len(line) + 1
+                    full_line = line[:-1].decode("utf-8")
+                    line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                    if line_data.get("generated_text", None) is not None:
+                        self.end_of_data = True
+                        response_obj["is_finished"] = True
+                    response_obj["text"] = line_data["token"]["text"]
+                    return response_obj
+                chunk = next(self.byte_iterator)
+                self.buffer.seek(0, io.SEEK_END)
+                self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        except StopIteration as e:
+            if self.end_of_data == True:
+                raise e  # Re-raise StopIteration
+            else:
+                self.end_of_data = True
+                return "data: [DONE]"
 
 
 class SagemakerConfig:
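
Note on the TokenIterator change above: instead of yielding raw token strings, the iterator now yields dicts with "text" and "is_finished" keys, flips is_finished once SageMaker sends its final "generated_text" payload, and emits a trailing "data: [DONE]" string only if the underlying event stream stops without that marker. The sketch below is not part of the PR; it drives the class with a fabricated PayloadPart stream to show the new contract. The import path comes from the diff, while the fake events are purely illustrative.

    from litellm.llms.sagemaker import TokenIterator

    # fabricated SageMaker-style events, in the shape TokenIterator reads
    fake_stream = [
        {"PayloadPart": {"Bytes": b'data:{"token": {"text": "Hello"}}\n'}},
        {
            "PayloadPart": {
                "Bytes": b'data:{"token": {"text": " world"}, "generated_text": "Hello world"}\n'
            }
        },
    ]

    for chunk in TokenIterator(fake_stream):
        print(chunk)
    # expected:
    # {'text': 'Hello', 'is_finished': False}
    # {'text': ' world', 'is_finished': True}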

View file

@@ -15,7 +15,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
 import litellm
+from ._logging import verbose_logger
 from litellm import (  # type: ignore
     client,
     exception_type,
@@ -274,14 +274,10 @@ async def acompletion(
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)  # type: ignore
-            # if kwargs.get("stream", False): # return an async generator
-            #     return _async_streaming(
-            #         response=response,
-            #         model=model,
-            #         custom_llm_provider=custom_llm_provider,
-            #         args=args,
-            #     )
-            # else:
+            if isinstance(response, CustomStreamWrapper):
+                response.set_logging_event_loop(
+                    loop=loop
+                )  # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
         return response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
@@ -1520,11 +1516,6 @@ def completion(
             if (
                 "stream" in optional_params and optional_params["stream"] == True
             ):  ## [BETA]
-                # sagemaker does not support streaming as of now so we're faking streaming:
-                # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
-                # "SageMaker is currently not supporting streaming responses."
-                # fake streaming for sagemaker
                 print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
                 from .llms.sagemaker import TokenIterator
@@ -1535,6 +1526,12 @@ def completion(
                     custom_llm_provider="sagemaker",
                     logging_obj=logging,
                 )
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=None,
+                    original_response=response,
+                )
                 return response
 
         ## RESPONSE OBJECT
@@ -3348,6 +3345,16 @@ def stream_chunk_builder(
     chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
 ):
     model_response = litellm.ModelResponse()
+    ### SORT CHUNKS BASED ON CREATED ORDER ##
+    print_verbose("Goes into checking if chunk has hiddden created at param")
+    if chunks[0]._hidden_params.get("created_at", None):
+        print_verbose("Chunks have a created at hidden param")
+        # Sort chunks based on created_at in ascending order
+        chunks = sorted(
+            chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
+        )
+        print_verbose("Chunks sorted")
     # set hidden params from chunk to model_response
     if model_response is not None and hasattr(model_response, "_hidden_params"):
         model_response._hidden_params = chunks[0].get("_hidden_params", {})
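
The stream_chunk_builder change above is what keeps streamed SageMaker cost tracking deterministic: chunks are now tagged with a created_at timestamp (see the CustomStreamWrapper hunk further down) and re-sorted before the full response is reassembled, so logging threads appending out of order cannot scramble the reconstructed text. A self-contained sketch of the same pattern; the Chunk class and reassemble() helper are illustrative stand-ins, not litellm APIs.

    class Chunk:
        def __init__(self, text, created_at):
            self._hidden_params = {"created_at": created_at}
            self.text = text

    def reassemble(chunks):
        # mirror of the new sort: ascending created_at, untimestamped chunks last
        chunks = sorted(
            chunks, key=lambda c: c._hidden_params.get("created_at", float("inf"))
        )
        return "".join(c.text for c in chunks)

    a = Chunk("Sage", created_at=1.0)
    b = Chunk("maker ", created_at=2.0)
    c = Chunk("stream", created_at=3.0)
    print(reassemble([c, a, b]))  # -> "Sagemaker stream", regardless of arrival order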

View file

@@ -718,6 +718,9 @@ async def update_database(
                 existing_spend_obj = await custom_db_client.get_data(
                     key=id, table_name="user"
                 )
+                verbose_proxy_logger.debug(
+                    f"Updating existing_spend_obj: {existing_spend_obj}"
+                )
                 if existing_spend_obj is None:
                     existing_spend = 0
                     existing_spend_obj = LiteLLM_UserTable(

View file

@@ -467,7 +467,9 @@ class PrismaClient:
                 hashed_token = token
                 if token.startswith("sk-"):
                     hashed_token = self.hash_token(token=token)
-                print_verbose("PrismaClient: find_unique")
+                verbose_proxy_logger.debug(
+                    f"PrismaClient: find_unique for token: {hashed_token}"
+                )
                 if query_type == "find_unique":
                     response = await self.db.litellm_verificationtoken.find_unique(
                         where={"token": hashed_token}
@@ -774,7 +776,6 @@ class PrismaClient:
         Batch write update queries
         """
         batcher = self.db.batch_()
-        verbose_proxy_logger.debug(f"data list for user table: {data_list}")
         for idx, user in enumerate(data_list):
             try:
                 data_json = self.jsonify_object(data=user.model_dump())

View file

@@ -556,6 +556,47 @@ async def test_async_chat_bedrock_stream():
 # asyncio.run(test_async_chat_bedrock_stream())
 
 
+## Test Sagemaker + Async
+@pytest.mark.asyncio
+async def test_async_chat_sagemaker_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+        )
+        # test streaming
+        response = await litellm.acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+            stream=True,
+        )
+        print(f"response: {response}")
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(
+                model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+                messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+                aws_region_name="my-bad-key",
+                stream=True,
+            )
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+
 # Text Completion

View file

@@ -206,12 +206,12 @@ def test_azure_completion_stream():
     # checks if the model response available in the async + stream callbacks is equal to the received response
     customHandler2 = MyCustomHandler()
     litellm.callbacks = [customHandler2]
-    litellm.set_verbose = False
+    litellm.set_verbose = True
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {
             "role": "user",
-            "content": "write 1 sentence about litellm being amazing",
+            "content": f"write 1 sentence about litellm being amazing {time.time()}",
         },
     ]
     complete_streaming_response = ""

View file

@@ -847,9 +847,13 @@ def test_sagemaker_weird_response():
             logging_obj=logging_obj,
         )
         complete_response = ""
-        for chunk in response:
-            print(chunk)
-            complete_response += chunk["choices"][0]["delta"]["content"]
+        for idx, chunk in enumerate(response):
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
         assert len(complete_response) > 0
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
@@ -872,41 +876,53 @@ async def test_sagemaker_streaming_async():
         )
         # Add any assertions here to check the response
-        print(response)
         complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        idx = 0
         async for chunk in response:
-            complete_response += chunk.choices[0].delta.content or ""
-        print(f"complete_response: {complete_response}")
-        assert len(complete_response) > 0
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+            idx += 1
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-# def test_completion_sagemaker_stream():
-#     try:
-#         response = completion(
-#             model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
-#             messages=messages,
-#             temperature=0.2,
-#             max_tokens=80,
-#             stream=True,
-#         )
-#         complete_response = ""
-#         has_finish_reason = False
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             has_finish_reason = finished
-#             if finished:
-#                 break
-#             complete_response += chunk
-#         if has_finish_reason is False:
-#             raise Exception("finish reason not set for last chunk")
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#     except InvalidRequestError as e:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_completion_sagemaker_stream():
+    try:
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            if finished:
+                break
+            complete_response += chunk
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
 
 # test_completion_sagemaker_stream()

View file

@@ -1418,7 +1418,9 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
+        verbose_logger.debug(
+            f"Async success callbacks: {litellm._async_success_callback}"
+        )
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )
@@ -1427,7 +1429,7 @@ class Logging:
         if self.stream:
             if result.choices[0].finish_reason is not None:  # if it's the last chunk
                 self.streaming_chunks.append(result)
-                # print_verbose(f"final set of received chunks: {self.streaming_chunks}")
+                # verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
                 try:
                     complete_streaming_response = litellm.stream_chunk_builder(
                         self.streaming_chunks,
@@ -1436,14 +1438,16 @@ class Logging:
                         end_time=end_time,
                     )
                 except Exception as e:
-                    print_verbose(
+                    verbose_logger.debug(
                         f"Error occurred building stream chunk: {traceback.format_exc()}"
                     )
                     complete_streaming_response = None
             else:
                 self.streaming_chunks.append(result)
         if complete_streaming_response is not None:
-            print_verbose("Async success callbacks: Got a complete streaming response")
+            verbose_logger.debug(
+                "Async success callbacks: Got a complete streaming response"
+            )
            self.model_call_details[
                 "complete_streaming_response"
             ] = complete_streaming_response
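
These Logging hunks are the cost-tracking core of the PR: once the final chunk (the one carrying a finish_reason) arrives, the accumulated chunks are rebuilt into complete_streaming_response, and spend for a streamed call is computed from that object rather than from individual deltas. A hedged sketch of the same idea from the caller's side, using the public stream_chunk_builder; the completion_cost call and its keyword are illustrative and may differ slightly by litellm version.

    import litellm

    messages = [{"role": "user", "content": "hello"}]
    chunks = []
    response = litellm.completion(
        model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
        messages=messages,
        stream=True,
    )
    for chunk in response:
        chunks.append(chunk)  # same accumulation the Logging object does internally

    # rebuild one complete response from the streamed pieces, then price it
    complete_response = litellm.stream_chunk_builder(chunks, messages=messages)
    print(litellm.completion_cost(completion_response=complete_response))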
@@ -7152,6 +7156,7 @@ class CustomStreamWrapper:
                 "model_id": (_model_info.get("id", None))
             }  # returned as x-litellm-model-id response header in proxy
         self.response_id = None
+        self.logging_loop = None
 
     def __iter__(self):
         return self
@@ -7722,6 +7727,27 @@ class CustomStreamWrapper:
             }
         return ""
 
+    def handle_sagemaker_stream(self, chunk):
+        if "data: [DONE]" in chunk:
+            text = ""
+            is_finished = True
+            finish_reason = "stop"
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        elif isinstance(chunk, dict):
+            if chunk["is_finished"] == True:
+                finish_reason = "stop"
+            else:
+                finish_reason = ""
+            return {
+                "text": chunk["text"],
+                "is_finished": chunk["is_finished"],
+                "finish_reason": finish_reason,
+            }
+
     def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
         if self.response_id is not None:
@@ -7729,6 +7755,7 @@ class CustomStreamWrapper:
         else:
             self.response_id = model_response.id
         model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
+        model_response._hidden_params["created_at"] = time.time()
         model_response.choices = [StreamingChoices()]
         model_response.choices[0].finish_reason = None
         response_obj = {}
@@ -7847,8 +7874,14 @@ class CustomStreamWrapper:
                     ]
                     self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                completion_obj["content"] = chunk
+                verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+                response_obj = self.handle_sagemaker_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+                    self.sent_last_chunk = True
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
@@ -8024,6 +8057,27 @@ class CustomStreamWrapper:
                 original_exception=e,
             )
 
+    def set_logging_event_loop(self, loop):
+        self.logging_loop = loop
+
+    async def your_async_function(self):
+        # Your asynchronous code here
+        return "Your asynchronous code is running"
+
+    def run_success_logging_in_thread(self, processed_chunk):
+        # Create an event loop for the new thread
+        ## ASYNC LOGGING
+        if self.logging_loop is not None:
+            future = asyncio.run_coroutine_threadsafe(
+                self.logging_obj.async_success_handler(processed_chunk),
+                loop=self.logging_loop,
+            )
+            result = future.result()
+        else:
+            asyncio.run(self.logging_obj.async_success_handler(processed_chunk))
+        ## SYNC LOGGING
+        self.logging_obj.success_handler(processed_chunk)
+
     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
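
This hunk is what makes spend tracking work for synchronous streaming through the proxy: the sync __next__ path logs from a background thread, but the async success handler (which persists spend) has to run on the proxy's event loop. acompletion captures that loop via set_logging_event_loop, and run_success_logging_in_thread hands the coroutine back to it with asyncio.run_coroutine_threadsafe. A standalone sketch of that hand-off; every name here is illustrative, and only the asyncio/threading calls mirror the diff.

    import asyncio, threading

    async def async_success_handler(item):
        await asyncio.sleep(0)  # stand-in for an async DB write
        print(f"logged {item!r} on the event loop")

    def log_in_thread(item, loop):
        if loop is not None:
            # schedule the coroutine on the captured loop and wait for it
            future = asyncio.run_coroutine_threadsafe(async_success_handler(item), loop=loop)
            future.result()
        else:
            # no loop captured: fall back to a private event loop in this thread
            asyncio.run(async_success_handler(item))

    async def main():
        loop = asyncio.get_running_loop()  # what set_logging_event_loop() captures
        t = threading.Thread(target=log_in_thread, args=("chunk-1", loop))
        t.start()
        await asyncio.sleep(0.1)  # give the scheduled coroutine time to run
        t.join()

    asyncio.run(main())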
@@ -8042,8 +8096,9 @@ class CustomStreamWrapper:
                     continue
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(response,)
+                    target=self.run_success_logging_in_thread, args=(response,)
                 ).start()  # log response
                 # RETURN RESULT
                 return response
         except StopIteration:
@@ -8099,13 +8154,34 @@ class CustomStreamWrapper:
                     raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
                 # example - boto3 bedrock llms
-                processed_chunk = next(self)
-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        processed_chunk,
-                    )
-                )
-                return processed_chunk
+                while True:
+                    if isinstance(self.completion_stream, str) or isinstance(
+                        self.completion_stream, bytes
+                    ):
+                        chunk = self.completion_stream
+                    else:
+                        chunk = next(self.completion_stream)
+                    if chunk is not None and chunk != b"":
+                        print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
+                        processed_chunk = self.chunk_creator(chunk=chunk)
+                        print_verbose(
+                            f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
+                        )
+                        if processed_chunk is None:
+                            continue
+                        ## LOGGING
+                        threading.Thread(
+                            target=self.logging_obj.success_handler,
+                            args=(processed_chunk,),
+                        ).start()  # log processed_chunk
+                        asyncio.create_task(
+                            self.logging_obj.async_success_handler(
+                                processed_chunk,
+                            )
+                        )
+                        # RETURN RESULT
+                        return processed_chunk
         except StopAsyncIteration:
             raise
         except StopIteration:
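
The __anext__ fallback above now drives the synchronous completion_stream (the boto3-style SageMaker/Bedrock iterators) directly inside the async iterator: it pulls raw chunks, skips empty ones, runs them through chunk_creator, and fires both sync and async success logging before returning. Stripped of the logging, the control flow is just an async iterator wrapped around a sync one; the sketch below shows that shape with purely illustrative names.

    import asyncio

    class AsyncOverSync:
        def __init__(self, sync_stream):
            self.sync_stream = iter(sync_stream)

        def __aiter__(self):
            return self

        async def __anext__(self):
            while True:
                try:
                    chunk = next(self.sync_stream)  # blocking read, as in the boto3 path
                except StopIteration:
                    raise StopAsyncIteration
                if chunk:  # drop None / empty chunks
                    return chunk

    async def main():
        async for chunk in AsyncOverSync(["He", "", "llo", None, "!"]):
            print(chunk, end="")

    asyncio.run(main())  # prints "Hello!"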

View file

@@ -11,6 +11,10 @@ model_list:
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_version: "2023-05-15"
       api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+  - model_name: sagemaker-completion-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+      input_cost_per_second: 0.000420
   - model_name: gpt-4
     litellm_params:
       model: azure/gpt-turbo
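
The new proxy model entry prices the SageMaker endpoint by time rather than by tokens, which is how the spend test further down can assert a non-zero cost for a streamed call. Roughly (litellm's internal cost function may treat request and response time separately), spend scales as call duration times input_cost_per_second; the numbers below are illustrative.

    input_cost_per_second = 0.000420   # from the config entry above
    call_duration_seconds = 6.5        # hypothetical streamed SageMaker call
    print(f"${call_duration_seconds * input_cost_per_second:.6f}")  # $0.002730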

View file

@@ -13,17 +13,21 @@ sys.path.insert(
 import litellm
 
 
-async def generate_key(session, i, budget=None, budget_duration=None):
+async def generate_key(
+    session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"]
+):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
     data = {
-        "models": ["azure-models", "gpt-4"],
+        "models": models,
         "aliases": {"mistral-7b": "gpt-3.5-turbo"},
         "duration": None,
         "max_budget": budget,
         "budget_duration": budget_duration,
     }
+    print(f"data: {data}")
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
         response_text = await response.text()
@@ -293,7 +297,7 @@ async def test_key_info_spend_values():
         rounded_response_cost = round(response_cost, 8)
         rounded_key_info_spend = round(key_info["info"]["spend"], 8)
         assert rounded_response_cost == rounded_key_info_spend
-        ## streaming
+        ## streaming - azure
         key_gen = await generate_key(session=session, i=0)
         new_key = key_gen["key"]
         prompt_tokens, completion_tokens = await chat_completion_streaming(
@@ -318,6 +322,41 @@ async def test_key_info_spend_values():
         assert rounded_response_cost == rounded_key_info_spend
 
 
+@pytest.mark.asyncio
+async def test_key_info_spend_values_sagemaker():
+    """
+    Tests the sync streaming loop to ensure spend is correctly calculated.
+    - create key
+    - make completion call
+    - assert cost is expected value
+    """
+    async with aiohttp.ClientSession() as session:
+        ## streaming - sagemaker
+        key_gen = await generate_key(session=session, i=0, models=[])
+        new_key = key_gen["key"]
+        prompt_tokens, completion_tokens = await chat_completion_streaming(
+            session=session, key=new_key, model="sagemaker-completion-model"
+        )
+        # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+        # prompt_cost, completion_cost = litellm.cost_per_token(
+        #     model="azure/gpt-35-turbo",
+        #     prompt_tokens=prompt_tokens,
+        #     completion_tokens=completion_tokens,
+        # )
+        # response_cost = prompt_cost + completion_cost
+        await asyncio.sleep(5)  # allow db log to be updated
+        key_info = await get_key_info(
+            session=session, get_key=new_key, call_key=new_key
+        )
+        # print(
+        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+        # )
+        # rounded_response_cost = round(response_cost, 8)
+        rounded_key_info_spend = round(key_info["info"]["spend"], 8)
+        assert rounded_key_info_spend > 0
+        # assert rounded_response_cost == rounded_key_info_spend
+
+
 @pytest.mark.asyncio
 async def test_key_with_budgets():
     """
@@ -337,7 +376,7 @@ async def test_key_with_budgets():
         print(f"hashed_token: {hashed_token}")
         key_info = await get_key_info(session=session, get_key=key, call_key=key)
         reset_at_init_value = key_info["info"]["budget_reset_at"]
-        await asyncio.sleep(15)
+        await asyncio.sleep(30)
         key_info = await get_key_info(session=session, get_key=key, call_key=key)
         reset_at_new_value = key_info["info"]["budget_reset_at"]
         assert reset_at_init_value != reset_at_new_value