Merge pull request #1618 from BerriAI/litellm_sagemaker_cost_tracking_fixes

fix(utils.py): fix sagemaker cost tracking for streaming
Krish Dholakia, 2024-01-25 19:01:57 -08:00 (committed by GitHub)
commit 612f74a426
11 changed files with 282 additions and 79 deletions


@@ -147,6 +147,9 @@ jobs:
-e AZURE_API_KEY=$AZURE_API_KEY \
-e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
-e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
-e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
-e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
-e AWS_REGION_NAME=$AWS_REGION_NAME \
--name my-app \
-v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
my-app:latest \
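The added AWS variables are what the SageMaker tests in the proxy container need. A hedged sketch of how a boto3 SageMaker runtime client could be built from them (illustration only, not litellm's actual code):

import os
import boto3

# Build a SageMaker runtime client from the environment variables passed to the container.
client = boto3.client(
    "sagemaker-runtime",
    region_name=os.environ.get("AWS_REGION_NAME"),
    aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
)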


@@ -34,22 +34,35 @@ class TokenIterator:
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
self.end_of_data = False
def __iter__(self):
return self
def __next__(self):
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
return line_data["token"]["text"]
chunk = next(self.byte_iterator)
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
try:
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
response_obj = {"text": "", "is_finished": False}
self.read_pos += len(line) + 1
full_line = line[:-1].decode("utf-8")
line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
if line_data.get("generated_text", None) is not None:
self.end_of_data = True
response_obj["is_finished"] = True
response_obj["text"] = line_data["token"]["text"]
return response_obj
chunk = next(self.byte_iterator)
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
except StopIteration as e:
if self.end_of_data == True:
raise e # Re-raise StopIteration
else:
self.end_of_data = True
return "data: [DONE]"
class SagemakerConfig:
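A minimal sketch of how the reworked TokenIterator above behaves when wrapped around a SageMaker response stream. The fake_stream payload is invented for illustration; only the yielded dict shape and the trailing "data: [DONE]" sentinel come from the diff:

from litellm.llms.sagemaker import TokenIterator

# Hypothetical event stream shaped like boto3's invoke_endpoint_with_response_stream output.
fake_stream = iter(
    [
        {"PayloadPart": {"Bytes": b'data:{"token": {"text": "Hello"}}\n'}},
        {"PayloadPart": {"Bytes": b'data:{"token": {"text": " world"}}\n'}},
    ]
)

for item in TokenIterator(fake_stream):
    print(item)

# Expected output (approximately):
#   {'text': 'Hello', 'is_finished': False}
#   {'text': ' world', 'is_finished': False}
#   data: [DONE]   <- returned once the byte iterator is exhausted; the next call re-raises StopIteration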


@@ -15,7 +15,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm
from ._logging import verbose_logger
from litellm import ( # type: ignore
client,
exception_type,
@@ -274,14 +274,10 @@ async def acompletion(
else:
# Call the synchronous function using run_in_executor
response = await loop.run_in_executor(None, func_with_context) # type: ignore
# if kwargs.get("stream", False): # return an async generator
# return _async_streaming(
# response=response,
# model=model,
# custom_llm_provider=custom_llm_provider,
# args=args,
# )
# else:
if isinstance(response, CustomStreamWrapper):
response.set_logging_event_loop(
loop=loop
) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls)
return response
except Exception as e:
custom_llm_provider = custom_llm_provider or "openai"
@@ -1520,11 +1516,6 @@ def completion(
if (
"stream" in optional_params and optional_params["stream"] == True
): ## [BETA]
# sagemaker does not support streaming as of now so we're faking streaming:
# https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
# "SageMaker is currently not supporting streaming responses."
# fake streaming for sagemaker
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
from .llms.sagemaker import TokenIterator
@@ -1535,6 +1526,12 @@ def completion(
custom_llm_provider="sagemaker",
logging_obj=logging,
)
## LOGGING
logging.post_call(
input=messages,
api_key=None,
original_response=response,
)
return response
## RESPONSE OBJECT
@@ -3348,6 +3345,16 @@ def stream_chunk_builder(
chunks: list, messages: Optional[list] = None, start_time=None, end_time=None
):
model_response = litellm.ModelResponse()
### SORT CHUNKS BASED ON CREATED ORDER ##
print_verbose("Goes into checking if chunk has hiddden created at param")
if chunks[0]._hidden_params.get("created_at", None):
print_verbose("Chunks have a created at hidden param")
# Sort chunks based on created_at in ascending order
chunks = sorted(
chunks, key=lambda x: x._hidden_params.get("created_at", float("inf"))
)
print_verbose("Chunks sorted")
# set hidden params from chunk to model_response
if model_response is not None and hasattr(model_response, "_hidden_params"):
model_response._hidden_params = chunks[0].get("_hidden_params", {})
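The created_at sort above presumably exists so that chunks collected on the threaded sync-streaming logging path (see the utils.py changes below) can be reassembled in the right order before spend is computed. A self-contained illustration of the same sort, with FakeChunk standing in for litellm.ModelResponse:

import time

class FakeChunk:
    """Stand-in for a streamed ModelResponse carrying the new created_at hidden param."""

    def __init__(self, text, created_at):
        self.text = text
        self._hidden_params = {"created_at": created_at}

now = time.time()
# Deliberately out of order, as they might be when gathered by a background thread.
chunks = [FakeChunk("world", now + 0.2), FakeChunk("hello ", now + 0.1)]

if chunks[0]._hidden_params.get("created_at", None):
    # Same key as the diff above: ascending created_at, missing timestamps sort last.
    chunks = sorted(chunks, key=lambda x: x._hidden_params.get("created_at", float("inf")))

print("".join(c.text for c in chunks))  # -> hello world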


@@ -718,6 +718,9 @@ async def update_database(
existing_spend_obj = await custom_db_client.get_data(
key=id, table_name="user"
)
verbose_proxy_logger.debug(
f"Updating existing_spend_obj: {existing_spend_obj}"
)
if existing_spend_obj is None:
existing_spend = 0
existing_spend_obj = LiteLLM_UserTable(


@@ -467,7 +467,9 @@ class PrismaClient:
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
print_verbose("PrismaClient: find_unique")
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
if query_type == "find_unique":
response = await self.db.litellm_verificationtoken.find_unique(
where={"token": hashed_token}
@@ -774,7 +776,6 @@ class PrismaClient:
Batch write update queries
"""
batcher = self.db.batch_()
verbose_proxy_logger.debug(f"data list for user table: {data_list}")
for idx, user in enumerate(data_list):
try:
data_json = self.jsonify_object(data=user.model_dump())


@@ -556,6 +556,47 @@ async def test_async_chat_bedrock_stream():
# asyncio.run(test_async_chat_bedrock_stream())
## Test Sagemaker + Async
@pytest.mark.asyncio
async def test_async_chat_sagemaker_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
)
# test streaming
response = await litellm.acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
stream=True,
)
print(f"response: {response}")
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
aws_region_name="my-bad-key",
stream=True,
)
async for chunk in response:
continue
except:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# Text Completion


@@ -206,12 +206,12 @@ def test_azure_completion_stream():
# checks if the model response available in the async + stream callbacks is equal to the received response
customHandler2 = MyCustomHandler()
litellm.callbacks = [customHandler2]
litellm.set_verbose = False
litellm.set_verbose = True
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": "write 1 sentence about litellm being amazing",
"content": f"write 1 sentence about litellm being amazing {time.time()}",
},
]
complete_streaming_response = ""


@@ -847,9 +847,13 @@ def test_sagemaker_weird_response():
logging_obj=logging_obj,
)
complete_response = ""
for chunk in response:
print(chunk)
complete_response += chunk["choices"][0]["delta"]["content"]
for idx, chunk in enumerate(response):
# print
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
complete_response += chunk
if finished:
break
assert len(complete_response) > 0
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
@@ -872,41 +876,53 @@ async def test_sagemaker_streaming_async():
)
# Add any assertions here to check the response
print(response)
complete_response = ""
has_finish_reason = False
# Add any assertions here to check the response
idx = 0
async for chunk in response:
complete_response += chunk.choices[0].delta.content or ""
print(f"complete_response: {complete_response}")
assert len(complete_response) > 0
# print
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
complete_response += chunk
if finished:
break
idx += 1
if has_finish_reason is False:
raise Exception("finish reason not set for last chunk")
if complete_response.strip() == "":
raise Exception("Empty response received")
print(f"completion_response: {complete_response}")
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# def test_completion_sagemaker_stream():
# try:
# response = completion(
# model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
# messages=messages,
# temperature=0.2,
# max_tokens=80,
# stream=True,
# )
# complete_response = ""
# has_finish_reason = False
# # Add any assertions here to check the response
# for idx, chunk in enumerate(response):
# chunk, finished = streaming_format_tests(idx, chunk)
# has_finish_reason = finished
# if finished:
# break
# complete_response += chunk
# if has_finish_reason is False:
# raise Exception("finish reason not set for last chunk")
# if complete_response.strip() == "":
# raise Exception("Empty response received")
# except InvalidRequestError as e:
# pass
# except Exception as e:
# pytest.fail(f"Error occurred: {e}")
def test_completion_sagemaker_stream():
try:
response = completion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=messages,
temperature=0.2,
max_tokens=80,
stream=True,
)
complete_response = ""
has_finish_reason = False
# Add any assertions here to check the response
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
if has_finish_reason is False:
raise Exception("finish reason not set for last chunk")
if complete_response.strip() == "":
raise Exception("Empty response received")
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_completion_sagemaker_stream()


@@ -1418,7 +1418,9 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
verbose_logger.debug(
f"Async success callbacks: {litellm._async_success_callback}"
)
start_time, end_time, result = self._success_handler_helper_fn(
start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
)
@@ -1427,7 +1429,7 @@
if self.stream:
if result.choices[0].finish_reason is not None: # if it's the last chunk
self.streaming_chunks.append(result)
# print_verbose(f"final set of received chunks: {self.streaming_chunks}")
# verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
try:
complete_streaming_response = litellm.stream_chunk_builder(
self.streaming_chunks,
@@ -1436,14 +1438,16 @@
end_time=end_time,
)
except Exception as e:
print_verbose(
verbose_logger.debug(
f"Error occurred building stream chunk: {traceback.format_exc()}"
)
complete_streaming_response = None
else:
self.streaming_chunks.append(result)
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
verbose_logger.debug(
"Async success callbacks: Got a complete streaming response"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
@@ -7152,6 +7156,7 @@
"model_id": (_model_info.get("id", None))
} # returned as x-litellm-model-id response header in proxy
self.response_id = None
self.logging_loop = None
def __iter__(self):
return self
@@ -7722,6 +7727,27 @@
}
return ""
def handle_sagemaker_stream(self, chunk):
if "data: [DONE]" in chunk:
text = ""
is_finished = True
finish_reason = "stop"
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
elif isinstance(chunk, dict):
if chunk["is_finished"] == True:
finish_reason = "stop"
else:
finish_reason = ""
return {
"text": chunk["text"],
"is_finished": chunk["is_finished"],
"finish_reason": finish_reason,
}
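A rough sketch of the mapping handle_sagemaker_stream performs on the two chunk shapes TokenIterator now emits. parse_sagemaker_chunk is a hypothetical standalone copy, shown only to illustrate the input/output contract:

def parse_sagemaker_chunk(chunk):
    # Sentinel string returned by TokenIterator once the underlying byte stream is exhausted.
    if isinstance(chunk, str) and "data: [DONE]" in chunk:
        return {"text": "", "is_finished": True, "finish_reason": "stop"}
    # Regular chunk: a dict produced by TokenIterator.__next__.
    elif isinstance(chunk, dict):
        finish_reason = "stop" if chunk["is_finished"] else ""
        return {
            "text": chunk["text"],
            "is_finished": chunk["is_finished"],
            "finish_reason": finish_reason,
        }

print(parse_sagemaker_chunk({"text": "Hello", "is_finished": False}))
# -> {'text': 'Hello', 'is_finished': False, 'finish_reason': ''}
print(parse_sagemaker_chunk("data: [DONE]"))
# -> {'text': '', 'is_finished': True, 'finish_reason': 'stop'}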
def chunk_creator(self, chunk):
model_response = ModelResponse(stream=True, model=self.model)
if self.response_id is not None:
@@ -7729,6 +7755,7 @@
else:
self.response_id = model_response.id
model_response._hidden_params["custom_llm_provider"] = self.custom_llm_provider
model_response._hidden_params["created_at"] = time.time()
model_response.choices = [StreamingChoices()]
model_response.choices[0].finish_reason = None
response_obj = {}
@@ -7847,8 +7874,14 @@
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "sagemaker":
print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
completion_obj["content"] = chunk
verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
response_obj = self.handle_sagemaker_stream(chunk)
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
model_response.choices[0].finish_reason = response_obj[
"finish_reason"
]
self.sent_last_chunk = True
elif self.custom_llm_provider == "petals":
if len(self.completion_stream) == 0:
if self.sent_last_chunk:
@@ -8024,6 +8057,27 @@
original_exception=e,
)
def set_logging_event_loop(self, loop):
self.logging_loop = loop
async def your_async_function(self):
# Your asynchronous code here
return "Your asynchronous code is running"
def run_success_logging_in_thread(self, processed_chunk):
# Create an event loop for the new thread
## ASYNC LOGGING
if self.logging_loop is not None:
future = asyncio.run_coroutine_threadsafe(
self.logging_obj.async_success_handler(processed_chunk),
loop=self.logging_loop,
)
result = future.result()
else:
asyncio.run(self.logging_obj.async_success_handler(processed_chunk))
## SYNC LOGGING
self.logging_obj.success_handler(processed_chunk)
## needs to handle the empty string case (even starting chunk can be an empty string)
def __next__(self):
try:
@@ -8042,8 +8096,9 @@
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler, args=(response,)
target=self.run_success_logging_in_thread, args=(response,)
).start() # log response
# RETURN RESULT
return response
except StopIteration:
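The run_success_logging_in_thread path above hinges on asyncio.run_coroutine_threadsafe: the thread iterating the sync stream cannot await the async success callback itself, so it schedules the coroutine onto the event loop captured earlier via set_logging_event_loop. A minimal standalone sketch of that pattern, with invented names (async_success_handler, log_from_thread) that are not litellm APIs:

import asyncio
import threading

async def async_success_handler(chunk):
    # Placeholder for Logging.async_success_handler.
    print(f"async logging: {chunk}")

def log_from_thread(loop, chunk):
    if loop is not None:
        # Schedule the coroutine on the captured loop and wait for it to complete.
        future = asyncio.run_coroutine_threadsafe(async_success_handler(chunk), loop)
        future.result()
    else:
        # No loop was captured (plain sync caller): run the coroutine standalone.
        asyncio.run(async_success_handler(chunk))

async def main():
    loop = asyncio.get_running_loop()  # what set_logging_event_loop would capture
    t = threading.Thread(target=log_from_thread, args=(loop, {"text": "hi"}))
    t.start()
    await asyncio.sleep(0.1)  # give the scheduled coroutine a chance to run on this loop
    t.join()

asyncio.run(main())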
@@ -8099,13 +8154,34 @@
raise StopAsyncIteration
else: # temporary patch for non-aiohttp async calls
# example - boto3 bedrock llms
processed_chunk = next(self)
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
return processed_chunk
while True:
if isinstance(self.completion_stream, str) or isinstance(
self.completion_stream, bytes
):
chunk = self.completion_stream
else:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b"":
print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
processed_chunk = self.chunk_creator(chunk=chunk)
print_verbose(
f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
)
if processed_chunk is None:
continue
## LOGGING
threading.Thread(
target=self.logging_obj.success_handler,
args=(processed_chunk,),
).start() # log processed_chunk
asyncio.create_task(
self.logging_obj.async_success_handler(
processed_chunk,
)
)
# RETURN RESULT
return processed_chunk
except StopAsyncIteration:
raise
except StopIteration:


@@ -11,6 +11,10 @@ model_list:
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
api_version: "2023-05-15"
api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
- model_name: sagemaker-completion-model
litellm_params:
model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
input_cost_per_second: 0.000420
- model_name: gpt-4
litellm_params:
model: azure/gpt-turbo
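For reference, a back-of-the-envelope example of how the per-second price on the sagemaker-completion-model entry above would translate into spend, assuming SageMaker calls are billed on call duration rather than tokens (the 0.000420 rate comes from the config; the 7.5-second duration is invented):

# Assumed behavior: spend = call duration in seconds * input_cost_per_second.
input_cost_per_second = 0.000420  # from the proxy config above
call_duration_seconds = 7.5       # hypothetical streaming call duration
print(f"approx spend: ${input_cost_per_second * call_duration_seconds:.6f}")
# -> approx spend: $0.003150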


@@ -13,17 +13,21 @@ sys.path.insert(
import litellm
async def generate_key(session, i, budget=None, budget_duration=None):
async def generate_key(
session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"]
):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"models": ["azure-models", "gpt-4"],
"models": models,
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
"budget_duration": budget_duration,
}
print(f"data: {data}")
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
@@ -293,7 +297,7 @@ async def test_key_info_spend_values():
rounded_response_cost = round(response_cost, 8)
rounded_key_info_spend = round(key_info["info"]["spend"], 8)
assert rounded_response_cost == rounded_key_info_spend
## streaming
## streaming - azure
key_gen = await generate_key(session=session, i=0)
new_key = key_gen["key"]
prompt_tokens, completion_tokens = await chat_completion_streaming(
@@ -318,6 +322,41 @@ async def test_key_info_spend_values():
assert rounded_response_cost == rounded_key_info_spend
@pytest.mark.asyncio
async def test_key_info_spend_values_sagemaker():
"""
Tests the sync streaming loop to ensure spend is correctly calculated.
- create key
- make completion call
- assert cost is expected value
"""
async with aiohttp.ClientSession() as session:
## streaming - sagemaker
key_gen = await generate_key(session=session, i=0, models=[])
new_key = key_gen["key"]
prompt_tokens, completion_tokens = await chat_completion_streaming(
session=session, key=new_key, model="sagemaker-completion-model"
)
# print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
# prompt_cost, completion_cost = litellm.cost_per_token(
# model="azure/gpt-35-turbo",
# prompt_tokens=prompt_tokens,
# completion_tokens=completion_tokens,
# )
# response_cost = prompt_cost + completion_cost
await asyncio.sleep(5) # allow db log to be updated
key_info = await get_key_info(
session=session, get_key=new_key, call_key=new_key
)
# print(
# f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
# )
# rounded_response_cost = round(response_cost, 8)
rounded_key_info_spend = round(key_info["info"]["spend"], 8)
assert rounded_key_info_spend > 0
# assert rounded_response_cost == rounded_key_info_spend
@pytest.mark.asyncio
async def test_key_with_budgets():
"""
@@ -337,7 +376,7 @@ async def test_key_with_budgets():
print(f"hashed_token: {hashed_token}")
key_info = await get_key_info(session=session, get_key=key, call_key=key)
reset_at_init_value = key_info["info"]["budget_reset_at"]
await asyncio.sleep(15)
await asyncio.sleep(30)
key_info = await get_key_info(session=session, get_key=key, call_key=key)
reset_at_new_value = key_info["info"]["budget_reset_at"]
assert reset_at_init_value != reset_at_new_value