forked from phoenix/litellm-mirror

fix(utils.py): fix sagemaker async logging for sync streaming
https://github.com/BerriAI/litellm/issues/1592

parent 39d5407e67
commit 09ec6d6458

10 changed files with 247 additions and 64 deletions
@@ -147,6 +147,9 @@ jobs:
           -e AZURE_API_KEY=$AZURE_API_KEY \
           -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
           -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
+          -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
+          -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
+          -e AWS_REGION_NAME=$AWS_REGION_NAME \
           --name my-app \
           -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
           my-app:latest \
@@ -34,22 +34,35 @@ class TokenIterator:
         self.byte_iterator = iter(stream)
         self.buffer = io.BytesIO()
         self.read_pos = 0
+        self.end_of_data = False

     def __iter__(self):
         return self

     def __next__(self):
-        while True:
-            self.buffer.seek(self.read_pos)
-            line = self.buffer.readline()
-            if line and line[-1] == ord("\n"):
-                self.read_pos += len(line) + 1
-                full_line = line[:-1].decode("utf-8")
-                line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
-                return line_data["token"]["text"]
-            chunk = next(self.byte_iterator)
-            self.buffer.seek(0, io.SEEK_END)
-            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        try:
+            while True:
+                self.buffer.seek(self.read_pos)
+                line = self.buffer.readline()
+                if line and line[-1] == ord("\n"):
+                    response_obj = {"text": "", "is_finished": False}
+                    self.read_pos += len(line) + 1
+                    full_line = line[:-1].decode("utf-8")
+                    line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                    if line_data.get("generated_text", None) is not None:
+                        self.end_of_data = True
+                        response_obj["is_finished"] = True
+                    response_obj["text"] = line_data["token"]["text"]
+                    return response_obj
+                chunk = next(self.byte_iterator)
+                self.buffer.seek(0, io.SEEK_END)
+                self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        except StopIteration as e:
+            if self.end_of_data == True:
+                raise e  # Re-raise StopIteration
+            else:
+                self.end_of_data = True
+                return "data: [DONE]"


 class SagemakerConfig:
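
Note on the hunk above: TokenIterator.__next__ previously yielded bare token strings; it now yields dicts of the form {"text": ..., "is_finished": ...} and emits a final "data: [DONE]" sentinel string before letting StopIteration propagate. A minimal consumption sketch, assuming `stream` is the event stream returned by boto3's invoke_endpoint_with_response_stream:

    # Hypothetical consumer of the rewritten iterator (not part of the diff).
    for token in TokenIterator(stream):
        if isinstance(token, dict):  # normal token payload
            print(token["text"], end="")
            if token["is_finished"]:  # set when "generated_text" appears
                break
        else:  # the trailing "data: [DONE]" sentinel string
            break
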
@@ -1514,11 +1514,6 @@ def completion(
             if (
                 "stream" in optional_params and optional_params["stream"] == True
             ):  ## [BETA]
-                # sagemaker does not support streaming as of now so we're faking streaming:
-                # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
-                # "SageMaker is currently not supporting streaming responses."
-
-                # fake streaming for sagemaker
                 print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
                 from .llms.sagemaker import TokenIterator
@@ -1529,6 +1524,12 @@ def completion(
                     custom_llm_provider="sagemaker",
                     logging_obj=logging,
                 )
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=None,
+                    original_response=response,
+                )
                 return response

         ## RESPONSE OBJECT
@@ -690,6 +690,9 @@ async def update_database(
                 existing_spend_obj = await custom_db_client.get_data(
                     key=id, table_name="user"
                 )
+                verbose_proxy_logger.debug(
+                    f"Updating existing_spend_obj: {existing_spend_obj}"
+                )
                 if existing_spend_obj is None:
                     existing_spend = 0
                     existing_spend_obj = LiteLLM_UserTable(
@@ -409,7 +409,9 @@ class PrismaClient:
                 hashed_token = token
                 if token.startswith("sk-"):
                     hashed_token = self.hash_token(token=token)
-                print_verbose("PrismaClient: find_unique")
+                verbose_proxy_logger.debug(
+                    f"PrismaClient: find_unique for token: {hashed_token}"
+                )
                 if query_type == "find_unique":
                     response = await self.db.litellm_verificationtoken.find_unique(
                         where={"token": hashed_token}
@@ -716,7 +718,6 @@ class PrismaClient:
         Batch write update queries
         """
         batcher = self.db.batch_()
-        verbose_proxy_logger.debug(f"data list for user table: {data_list}")
         for idx, user in enumerate(data_list):
             try:
                 data_json = self.jsonify_object(data=user.model_dump())
@@ -556,6 +556,47 @@ async def test_async_chat_bedrock_stream():

 # asyncio.run(test_async_chat_bedrock_stream())


+## Test Sagemaker + Async
+@pytest.mark.asyncio
+async def test_async_chat_sagemaker_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+        )
+        # test streaming
+        response = await litellm.acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+            stream=True,
+        )
+        print(f"response: {response}")
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(
+                model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+                messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+                aws_region_name="my-bad-key",
+                stream=True,
+            )
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+
 # Text Completion
@@ -872,41 +872,53 @@ async def test_sagemaker_streaming_async():
         )

         # Add any assertions here to check the response
+        print(response)
         complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        idx = 0
         async for chunk in response:
-            complete_response += chunk.choices[0].delta.content or ""
-        print(f"complete_response: {complete_response}")
-        assert len(complete_response) > 0
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+            idx += 1
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


-# def test_completion_sagemaker_stream():
-#     try:
-#         response = completion(
-#             model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
-#             messages=messages,
-#             temperature=0.2,
-#             max_tokens=80,
-#             stream=True,
-#         )
-#         complete_response = ""
-#         has_finish_reason = False
-#         # # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             has_finish_reason = finished
-#             if finished:
-#                 break
-#             complete_response += chunk
-#         if has_finish_reason is False:
-#             raise Exception("finish reason not set for last chunk")
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#     except InvalidRequestError as e:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_completion_sagemaker_stream():
+    try:
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            if finished:
+                break
+            complete_response += chunk
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 # test_completion_sagemaker_stream()
@@ -1417,7 +1417,9 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
+        verbose_logger.debug(
+            f"Async success callbacks: {litellm._async_success_callback}"
+        )
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )
@@ -1426,7 +1428,7 @@ class Logging:
         if self.stream:
             if result.choices[0].finish_reason is not None:  # if it's the last chunk
                 self.streaming_chunks.append(result)
-                # print_verbose(f"final set of received chunks: {self.streaming_chunks}")
+                # verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
                 try:
                     complete_streaming_response = litellm.stream_chunk_builder(
                         self.streaming_chunks,
@@ -1435,14 +1437,16 @@ class Logging:
                         end_time=end_time,
                     )
                 except Exception as e:
-                    print_verbose(
+                    verbose_logger.debug(
                         f"Error occurred building stream chunk: {traceback.format_exc()}"
                     )
                     complete_streaming_response = None
             else:
                 self.streaming_chunks.append(result)
         if complete_streaming_response is not None:
-            print_verbose("Async success callbacks: Got a complete streaming response")
+            verbose_logger.debug(
+                "Async success callbacks: Got a complete streaming response"
+            )
             self.model_call_details[
                 "complete_streaming_response"
             ] = complete_streaming_response
@@ -7682,6 +7686,27 @@ class CustomStreamWrapper:
             }
             return ""

+    def handle_sagemaker_stream(self, chunk):
+        if "data: [DONE]" in chunk:
+            text = ""
+            is_finished = True
+            finish_reason = "stop"
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        elif isinstance(chunk, dict):
+            if chunk["is_finished"] == True:
+                finish_reason = "stop"
+            else:
+                finish_reason = ""
+            return {
+                "text": chunk["text"],
+                "is_finished": chunk["is_finished"],
+                "finish_reason": finish_reason,
+            }
+
     def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
         if self.response_id is not None:
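
The new handle_sagemaker_stream helper normalizes both shapes the rewritten TokenIterator can produce — the dict payload and the terminal "data: [DONE]" string — into one {text, is_finished, finish_reason} record. A rough illustration of the mapping, using a hypothetical `wrapper` instance:

    # Both results follow directly from the branches in handle_sagemaker_stream.
    wrapper.handle_sagemaker_stream({"text": "Hello", "is_finished": False})
    # -> {"text": "Hello", "is_finished": False, "finish_reason": ""}
    wrapper.handle_sagemaker_stream("data: [DONE]")
    # -> {"text": "", "is_finished": True, "finish_reason": "stop"}
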
@@ -7807,8 +7832,14 @@ class CustomStreamWrapper:
                     ]
                     self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                completion_obj["content"] = chunk
+                verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+                response_obj = self.handle_sagemaker_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+                    self.sent_last_chunk = True
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
@@ -7984,6 +8015,19 @@ class CustomStreamWrapper:
                 original_exception=e,
             )

+    def run_success_logging_in_thread(self, processed_chunk):
+        # Create an event loop for the new thread
+        ## ASYNC LOGGING
+        # Run the asynchronous function in the new thread's event loop
+        asyncio.run(
+            self.logging_obj.async_success_handler(
+                processed_chunk,
+            )
+        )
+
+        ## SYNC LOGGING
+        self.logging_obj.success_handler(processed_chunk)
+
     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
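
run_success_logging_in_thread exists because the sync streaming path fires its logging from a plain worker thread, which has no running event loop; asyncio.run() gives the async success callback a fresh loop per call, after which the sync handler runs in the same thread. A standalone sketch of the pattern (names here are illustrative, not from the diff):

    import asyncio
    import threading

    async def async_success_handler(chunk):
        ...  # e.g. await an HTTP call to a logging backend

    def run_logging_in_thread(chunk):
        # asyncio.run creates and tears down a private event loop in this
        # thread, so the async callback works from synchronous code.
        asyncio.run(async_success_handler(chunk))

    threading.Thread(target=run_logging_in_thread, args=("chunk",)).start()
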
@@ -8002,8 +8046,9 @@ class CustomStreamWrapper:
                     continue
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(response,)
+                    target=self.run_success_logging_in_thread, args=(response,)
                 ).start()  # log response

                 # RETURN RESULT
                 return response
         except StopIteration:
@@ -8059,13 +8104,34 @@ class CustomStreamWrapper:
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
                 # example - boto3 bedrock llms
-                processed_chunk = next(self)
-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        processed_chunk,
-                    )
-                )
-                return processed_chunk
+                while True:
+                    if isinstance(self.completion_stream, str) or isinstance(
+                        self.completion_stream, bytes
+                    ):
+                        chunk = self.completion_stream
+                    else:
+                        chunk = next(self.completion_stream)
+                    if chunk is not None and chunk != b"":
+                        print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
+                        processed_chunk = self.chunk_creator(chunk=chunk)
+                        print_verbose(
+                            f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
+                        )
+                        if processed_chunk is None:
+                            continue
+                        ## LOGGING
+                        threading.Thread(
+                            target=self.logging_obj.success_handler,
+                            args=(processed_chunk,),
+                        ).start()  # log processed_chunk
+                        asyncio.create_task(
+                            self.logging_obj.async_success_handler(
+                                processed_chunk,
+                            )
+                        )
+
+                        # RETURN RESULT
+                        return processed_chunk
         except StopAsyncIteration:
             raise
         except StopIteration:
@@ -11,6 +11,10 @@ model_list:
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_version: "2023-05-15"
       api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+  - model_name: sagemaker-completion-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+      input_cost_per_second: 0.000420
   - model_name: gpt-4
     litellm_params:
       model: azure/gpt-turbo
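
The new sagemaker-completion-model entry prices the deployment per second of execution rather than per token. As a sketch of what that implies (the formula is an assumption based on the input_cost_per_second field, not code from this commit):

    # Hypothetical per-second spend calculation for the config entry above.
    input_cost_per_second = 0.000420
    call_duration_seconds = 3.2  # e.g. end_time - start_time for the request
    spend = input_cost_per_second * call_duration_seconds
    print(f"spend: {spend:.6f}")  # -> spend: 0.001344
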
@@ -13,17 +13,21 @@ sys.path.insert(
 import litellm


-async def generate_key(session, i, budget=None, budget_duration=None):
+async def generate_key(
+    session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"]
+):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
     data = {
-        "models": ["azure-models", "gpt-4"],
+        "models": models,
         "aliases": {"mistral-7b": "gpt-3.5-turbo"},
         "duration": None,
         "max_budget": budget,
         "budget_duration": budget_duration,
     }
+
+    print(f"data: {data}")
+
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
         response_text = await response.text()
@@ -293,7 +297,7 @@ async def test_key_info_spend_values():
         rounded_response_cost = round(response_cost, 8)
         rounded_key_info_spend = round(key_info["info"]["spend"], 8)
         assert rounded_response_cost == rounded_key_info_spend
-        ## streaming
+        ## streaming - azure
         key_gen = await generate_key(session=session, i=0)
         new_key = key_gen["key"]
         prompt_tokens, completion_tokens = await chat_completion_streaming(
@@ -318,6 +322,41 @@ async def test_key_info_spend_values():
         assert rounded_response_cost == rounded_key_info_spend


+@pytest.mark.asyncio
+async def test_key_info_spend_values_sagemaker():
+    """
+    Tests the sync streaming loop to ensure spend is correctly calculated.
+    - create key
+    - make completion call
+    - assert cost is expected value
+    """
+    async with aiohttp.ClientSession() as session:
+        ## streaming - sagemaker
+        key_gen = await generate_key(session=session, i=0, models=[])
+        new_key = key_gen["key"]
+        prompt_tokens, completion_tokens = await chat_completion_streaming(
+            session=session, key=new_key, model="sagemaker-completion-model"
+        )
+        # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+        # prompt_cost, completion_cost = litellm.cost_per_token(
+        #     model="azure/gpt-35-turbo",
+        #     prompt_tokens=prompt_tokens,
+        #     completion_tokens=completion_tokens,
+        # )
+        # response_cost = prompt_cost + completion_cost
+        await asyncio.sleep(5)  # allow db log to be updated
+        key_info = await get_key_info(
+            session=session, get_key=new_key, call_key=new_key
+        )
+        # print(
+        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+        # )
+        # rounded_response_cost = round(response_cost, 8)
+        rounded_key_info_spend = round(key_info["info"]["spend"], 8)
+        assert rounded_key_info_spend > 0
+        # assert rounded_response_cost == rounded_key_info_spend
+
+
 @pytest.mark.asyncio
 async def test_key_with_budgets():
     """