fix(utils.py): fix sagemaker async logging for sync streaming

https://github.com/BerriAI/litellm/issues/1592
Krrish Dholakia 2024-01-25 12:49:45 -08:00
parent 39d5407e67
commit 09ec6d6458
10 changed files with 247 additions and 64 deletions


@@ -147,6 +147,9 @@ jobs:
       -e AZURE_API_KEY=$AZURE_API_KEY \
       -e AZURE_FRANCE_API_KEY=$AZURE_FRANCE_API_KEY \
       -e AZURE_EUROPE_API_KEY=$AZURE_EUROPE_API_KEY \
+      -e AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID \
+      -e AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY \
+      -e AWS_REGION_NAME=$AWS_REGION_NAME \
       --name my-app \
       -v $(pwd)/proxy_server_config.yaml:/app/config.yaml \
       my-app:latest \


@@ -34,22 +34,35 @@ class TokenIterator:
         self.byte_iterator = iter(stream)
         self.buffer = io.BytesIO()
         self.read_pos = 0
+        self.end_of_data = False

     def __iter__(self):
         return self

     def __next__(self):
-        while True:
-            self.buffer.seek(self.read_pos)
-            line = self.buffer.readline()
-            if line and line[-1] == ord("\n"):
-                self.read_pos += len(line) + 1
-                full_line = line[:-1].decode("utf-8")
-                line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
-                return line_data["token"]["text"]
-            chunk = next(self.byte_iterator)
-            self.buffer.seek(0, io.SEEK_END)
-            self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        try:
+            while True:
+                self.buffer.seek(self.read_pos)
+                line = self.buffer.readline()
+                if line and line[-1] == ord("\n"):
+                    response_obj = {"text": "", "is_finished": False}
+                    self.read_pos += len(line) + 1
+                    full_line = line[:-1].decode("utf-8")
+                    line_data = json.loads(full_line.lstrip("data:").rstrip("/n"))
+                    if line_data.get("generated_text", None) is not None:
+                        self.end_of_data = True
+                        response_obj["is_finished"] = True
+                    response_obj["text"] = line_data["token"]["text"]
+                    return response_obj
+                chunk = next(self.byte_iterator)
+                self.buffer.seek(0, io.SEEK_END)
+                self.buffer.write(chunk["PayloadPart"]["Bytes"])
+        except StopIteration as e:
+            if self.end_of_data == True:
+                raise e  # Re-raise StopIteration
+            else:
+                self.end_of_data = True
+                return "data: [DONE]"


 class SagemakerConfig:
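For context on the new chunk contract (this sketch is not part of the commit): the rewritten TokenIterator yields dicts with "text" and "is_finished" keys parsed from the endpoint's PayloadPart bytes, then a final "data: [DONE]" string once the byte iterator is exhausted. The fabricated events below only illustrate that parsing; real SageMaker payloads come from invoke_endpoint_with_response_stream.

import json

# Fabricated PayloadPart events shaped like the ones returned by
# invoke_endpoint_with_response_stream; the token text is made up.
fake_event_stream = [
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": "Hello"}}\n'}},
    {"PayloadPart": {"Bytes": b'data:{"token": {"text": " world"}, "generated_text": "Hello world"}\n'}},
]

for event in fake_event_stream:
    line = event["PayloadPart"]["Bytes"].decode("utf-8").strip()
    line_data = json.loads(line.lstrip("data:"))
    chunk = {
        "text": line_data["token"]["text"],
        # "generated_text" is only present on the final event, so it marks the end
        "is_finished": line_data.get("generated_text", None) is not None,
    }
    print(chunk)

print("data: [DONE]")  # sentinel the iterator returns once the byte stream is exhausted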


@@ -1514,11 +1514,6 @@ def completion(
             if (
                 "stream" in optional_params and optional_params["stream"] == True
             ):  ## [BETA]
-                # sagemaker does not support streaming as of now so we're faking streaming:
-                # https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
-                # "SageMaker is currently not supporting streaming responses."
-
-                # fake streaming for sagemaker
                 print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
                 from .llms.sagemaker import TokenIterator

@@ -1529,6 +1524,12 @@ def completion(
                     custom_llm_provider="sagemaker",
                     logging_obj=logging,
                 )
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=None,
+                    original_response=response,
+                )
                 return response

             ## RESPONSE OBJECT
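The post_call logging added above runs as soon as the CustomStreamWrapper is returned, so a plain synchronous streaming request exercises this path end to end. A usage sketch (the endpoint name is the benchmarking endpoint used in this commit's tests; substitute your own deployment):

import litellm

# Sync streaming against a SageMaker endpoint; requires AWS credentials in the env.
response = litellm.completion(
    model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")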


@@ -690,6 +690,9 @@ async def update_database(
                 existing_spend_obj = await custom_db_client.get_data(
                     key=id, table_name="user"
                 )
+                verbose_proxy_logger.debug(
+                    f"Updating existing_spend_obj: {existing_spend_obj}"
+                )
                 if existing_spend_obj is None:
                     existing_spend = 0
                     existing_spend_obj = LiteLLM_UserTable(


@@ -409,7 +409,9 @@ class PrismaClient:
                 hashed_token = token
                 if token.startswith("sk-"):
                     hashed_token = self.hash_token(token=token)
-                print_verbose("PrismaClient: find_unique")
+                verbose_proxy_logger.debug(
+                    f"PrismaClient: find_unique for token: {hashed_token}"
+                )
                 if query_type == "find_unique":
                     response = await self.db.litellm_verificationtoken.find_unique(
                         where={"token": hashed_token}
@@ -716,7 +718,6 @@ class PrismaClient:
             Batch write update queries
             """
             batcher = self.db.batch_()
-            verbose_proxy_logger.debug(f"data list for user table: {data_list}")
             for idx, user in enumerate(data_list):
                 try:
                     data_json = self.jsonify_object(data=user.model_dump())


@@ -556,6 +556,47 @@ async def test_async_chat_bedrock_stream():


 # asyncio.run(test_async_chat_bedrock_stream())


+## Test Sagemaker + Async
+@pytest.mark.asyncio
+async def test_async_chat_sagemaker_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+        )
+        # test streaming
+        response = await litellm.acompletion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+            stream=True,
+        )
+        print(f"response: {response}")
+        async for chunk in response:
+            print(f"chunk: {chunk}")
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(
+                model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+                messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
+                aws_region_name="my-bad-key",
+                stream=True,
+            )
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+
 # Text Completion


@@ -872,41 +872,53 @@ async def test_sagemaker_streaming_async():
         )
         # Add any assertions here to check the response
-        print(response)
         complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        idx = 0
         async for chunk in response:
-            complete_response += chunk.choices[0].delta.content or ""
-        print(f"complete_response: {complete_response}")
-        assert len(complete_response) > 0
+            # print
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            complete_response += chunk
+            if finished:
+                break
+            idx += 1
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")


-# def test_completion_sagemaker_stream():
-#     try:
-#         response = completion(
-#             model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b",
-#             messages=messages,
-#             temperature=0.2,
-#             max_tokens=80,
-#             stream=True,
-#         )
-#         complete_response = ""
-#         has_finish_reason = False
-#         # Add any assertions here to check the response
-#         for idx, chunk in enumerate(response):
-#             chunk, finished = streaming_format_tests(idx, chunk)
-#             has_finish_reason = finished
-#             if finished:
-#                 break
-#             complete_response += chunk
-#         if has_finish_reason is False:
-#             raise Exception("finish reason not set for last chunk")
-#         if complete_response.strip() == "":
-#             raise Exception("Empty response received")
-#     except InvalidRequestError as e:
-#         pass
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_completion_sagemaker_stream():
+    try:
+        response = completion(
+            model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
+            messages=messages,
+            temperature=0.2,
+            max_tokens=80,
+            stream=True,
+        )
+        complete_response = ""
+        has_finish_reason = False
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            chunk, finished = streaming_format_tests(idx, chunk)
+            has_finish_reason = finished
+            if finished:
+                break
+            complete_response += chunk
+        if has_finish_reason is False:
+            raise Exception("finish reason not set for last chunk")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")


 # test_completion_sagemaker_stream()


@@ -1417,7 +1417,9 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
+        verbose_logger.debug(
+            f"Async success callbacks: {litellm._async_success_callback}"
+        )
         start_time, end_time, result = self._success_handler_helper_fn(
             start_time=start_time, end_time=end_time, result=result, cache_hit=cache_hit
         )
@@ -1426,7 +1428,7 @@ class Logging:
         if self.stream:
             if result.choices[0].finish_reason is not None:  # if it's the last chunk
                 self.streaming_chunks.append(result)
-                # print_verbose(f"final set of received chunks: {self.streaming_chunks}")
+                # verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")
                 try:
                     complete_streaming_response = litellm.stream_chunk_builder(
                         self.streaming_chunks,
@@ -1435,14 +1437,16 @@
                         end_time=end_time,
                     )
                 except Exception as e:
-                    print_verbose(
+                    verbose_logger.debug(
                         f"Error occurred building stream chunk: {traceback.format_exc()}"
                     )
                     complete_streaming_response = None
             else:
                 self.streaming_chunks.append(result)
         if complete_streaming_response is not None:
-            print_verbose("Async success callbacks: Got a complete streaming response")
+            verbose_logger.debug(
+                "Async success callbacks: Got a complete streaming response"
+            )
             self.model_call_details[
                 "complete_streaming_response"
             ] = complete_streaming_response
@@ -7682,6 +7686,27 @@ class CustomStreamWrapper:
             }
             return ""

+    def handle_sagemaker_stream(self, chunk):
+        if "data: [DONE]" in chunk:
+            text = ""
+            is_finished = True
+            finish_reason = "stop"
+            return {
+                "text": text,
+                "is_finished": is_finished,
+                "finish_reason": finish_reason,
+            }
+        elif isinstance(chunk, dict):
+            if chunk["is_finished"] == True:
+                finish_reason = "stop"
+            else:
+                finish_reason = ""
+            return {
+                "text": chunk["text"],
+                "is_finished": chunk["is_finished"],
+                "finish_reason": finish_reason,
+            }
+
     def chunk_creator(self, chunk):
         model_response = ModelResponse(stream=True, model=self.model)
         if self.response_id is not None:
@@ -7807,8 +7832,14 @@
                         ]
                         self.sent_last_chunk = True
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                completion_obj["content"] = chunk
+                verbose_logger.debug(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
+                response_obj = self.handle_sagemaker_stream(chunk)
+                completion_obj["content"] = response_obj["text"]
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+                    self.sent_last_chunk = True
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.sent_last_chunk:
@@ -7984,6 +8015,19 @@
                     original_exception=e,
                 )

+    def run_success_logging_in_thread(self, processed_chunk):
+        # Create an event loop for the new thread
+        ## ASYNC LOGGING
+        # Run the asynchronous function in the new thread's event loop
+        asyncio.run(
+            self.logging_obj.async_success_handler(
+                processed_chunk,
+            )
+        )
+        ## SYNC LOGGING
+        self.logging_obj.success_handler(processed_chunk)
+
     ## needs to handle the empty string case (even starting chunk can be an empty string)
     def __next__(self):
         try:
@@ -8002,8 +8046,9 @@
                     continue
                 ## LOGGING
                 threading.Thread(
-                    target=self.logging_obj.success_handler, args=(response,)
+                    target=self.run_success_logging_in_thread, args=(response,)
                 ).start()  # log response
+
                 # RETURN RESULT
                 return response
         except StopIteration:
@@ -8059,13 +8104,34 @@
                 raise StopAsyncIteration
             else:  # temporary patch for non-aiohttp async calls
                 # example - boto3 bedrock llms
-                processed_chunk = next(self)
-                asyncio.create_task(
-                    self.logging_obj.async_success_handler(
-                        processed_chunk,
-                    )
-                )
-                return processed_chunk
+                while True:
+                    if isinstance(self.completion_stream, str) or isinstance(
+                        self.completion_stream, bytes
+                    ):
+                        chunk = self.completion_stream
+                    else:
+                        chunk = next(self.completion_stream)
+                    if chunk is not None and chunk != b"":
+                        print_verbose(f"PROCESSED CHUNK PRE CHUNK CREATOR: {chunk}")
+                        processed_chunk = self.chunk_creator(chunk=chunk)
+                        print_verbose(
+                            f"PROCESSED CHUNK POST CHUNK CREATOR: {processed_chunk}"
+                        )
+                        if processed_chunk is None:
+                            continue
+                        ## LOGGING
+                        threading.Thread(
+                            target=self.logging_obj.success_handler,
+                            args=(processed_chunk,),
+                        ).start()  # log processed_chunk
+                        asyncio.create_task(
+                            self.logging_obj.async_success_handler(
+                                processed_chunk,
+                            )
+                        )
+                        # RETURN RESULT
+                        return processed_chunk
         except StopAsyncIteration:
             raise
         except StopIteration:
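The heart of the fix for sync streaming is run_success_logging_in_thread: the success callbacks are pushed onto a worker thread, and asyncio.run gives that thread its own event loop so the async handler can run without an event loop in the caller. A self-contained sketch of that pattern, with placeholder handlers standing in for litellm's logging object:

import asyncio
import threading

async def async_success_handler(chunk):
    # placeholder for an async logging callback
    await asyncio.sleep(0)
    print(f"async logged: {chunk}")

def sync_success_handler(chunk):
    # placeholder for a sync logging callback
    print(f"sync logged: {chunk}")

def run_success_logging_in_thread(chunk):
    # asyncio.run() creates a fresh event loop in this worker thread,
    # so the async callback can be awaited even though the caller is sync code
    asyncio.run(async_success_handler(chunk))
    sync_success_handler(chunk)

t = threading.Thread(target=run_success_logging_in_thread, args=("hello",))
t.start()
t.join()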


@@ -11,6 +11,10 @@ model_list:
       api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
       api_version: "2023-05-15"
       api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+  - model_name: sagemaker-completion-model
+    litellm_params:
+      model: sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4
+      input_cost_per_second: 0.000420
   - model_name: gpt-4
     litellm_params:
       model: azure/gpt-turbo
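The new proxy model entry prices the SageMaker endpoint by duration via input_cost_per_second rather than per token. Illustrative arithmetic only (the request duration below is hypothetical):

# input_cost_per_second comes from the config entry above; the duration is made up.
cost_per_second = 0.000420
duration_seconds = 5.0
print(f"approx spend: ${cost_per_second * duration_seconds:.6f}")  # approx $0.002100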


@@ -13,17 +13,21 @@
 import litellm


-async def generate_key(session, i, budget=None, budget_duration=None):
+async def generate_key(
+    session, i, budget=None, budget_duration=None, models=["azure-models", "gpt-4"]
+):
     url = "http://0.0.0.0:4000/key/generate"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
     data = {
-        "models": ["azure-models", "gpt-4"],
+        "models": models,
         "aliases": {"mistral-7b": "gpt-3.5-turbo"},
         "duration": None,
        "max_budget": budget,
         "budget_duration": budget_duration,
     }
+    print(f"data: {data}")
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
         response_text = await response.text()
@@ -293,7 +297,7 @@ async def test_key_info_spend_values():
         rounded_response_cost = round(response_cost, 8)
         rounded_key_info_spend = round(key_info["info"]["spend"], 8)
         assert rounded_response_cost == rounded_key_info_spend
-        ## streaming
+        ## streaming - azure
         key_gen = await generate_key(session=session, i=0)
         new_key = key_gen["key"]
         prompt_tokens, completion_tokens = await chat_completion_streaming(
@@ -318,6 +322,41 @@ async def test_key_info_spend_values():
         assert rounded_response_cost == rounded_key_info_spend


+@pytest.mark.asyncio
+async def test_key_info_spend_values_sagemaker():
+    """
+    Tests the sync streaming loop to ensure spend is correctly calculated.
+    - create key
+    - make completion call
+    - assert cost is expected value
+    """
+    async with aiohttp.ClientSession() as session:
+        ## streaming - sagemaker
+        key_gen = await generate_key(session=session, i=0, models=[])
+        new_key = key_gen["key"]
+        prompt_tokens, completion_tokens = await chat_completion_streaming(
+            session=session, key=new_key, model="sagemaker-completion-model"
+        )
+        # print(f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}")
+        # prompt_cost, completion_cost = litellm.cost_per_token(
+        #     model="azure/gpt-35-turbo",
+        #     prompt_tokens=prompt_tokens,
+        #     completion_tokens=completion_tokens,
+        # )
+        # response_cost = prompt_cost + completion_cost
+        await asyncio.sleep(5)  # allow db log to be updated
+        key_info = await get_key_info(
+            session=session, get_key=new_key, call_key=new_key
+        )
+        # print(
+        #     f"response_cost: {response_cost}; key_info spend: {key_info['info']['spend']}"
+        # )
+        # rounded_response_cost = round(response_cost, 8)
+        rounded_key_info_spend = round(key_info["info"]["spend"], 8)
+        assert rounded_key_info_spend > 0
+        # assert rounded_response_cost == rounded_key_info_spend
+
+
 @pytest.mark.asyncio
 async def test_key_with_budgets():
     """