Merge branch 'main' into litellm_embedding_caching_updates

This commit is contained in:
Krish Dholakia 2024-02-03 18:08:47 -08:00 committed by GitHub
commit 9ab59045a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
236 changed files with 24483 additions and 2014 deletions

View file

@ -74,6 +74,7 @@ class CompletionCustomHandler(
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
print(f"kwargs: {kwargs}")
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
@ -149,7 +150,14 @@ class CompletionCustomHandler(
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
assert isinstance(
response_obj,
(
litellm.ModelResponse,
litellm.EmbeddingResponse,
litellm.ImageResponse,
),
)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
@ -170,12 +178,14 @@ class CompletionCustomHandler(
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert isinstance(kwargs["response_cost"], (float, type(None)))
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print(f"kwargs: {kwargs}")
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
@ -262,6 +272,7 @@ class CompletionCustomHandler(
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
assert isinstance(kwargs["response_cost"], (float, type(None)))
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
@ -545,6 +556,46 @@ async def test_async_chat_bedrock_stream():
# asyncio.run(test_async_chat_bedrock_stream())
## Test Sagemaker + Async
@pytest.mark.asyncio
async def test_async_chat_sagemaker_stream():
    """Exercise the custom handler across async SageMaker completion calls.

    Runs one non-streaming and one streaming ``litellm.acompletion`` call,
    then a streaming call with an invalid ``aws_region_name`` that is expected
    to fail, and finally asserts that the handler recorded no assertion
    errors in any of its callbacks.
    """
    # Local import so this block does not depend on the file-level imports.
    import asyncio

    model = "sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4"
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
        )
        # test streaming
        response = await litellm.acompletion(
            model=model,
            messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
            stream=True,
        )
        print(f"response: {response}")
        async for chunk in response:
            print(f"chunk: {chunk}")
        ## test failure callback
        try:
            response = await litellm.acompletion(
                model=model,
                messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
                aws_region_name="my-bad-key",
                stream=True,
            )
            async for chunk in response:
                continue
        except Exception:
            # The call is expected to fail; only the handler's recorded
            # errors matter here. `Exception` (not a bare except) keeps
            # KeyboardInterrupt / SystemExit / task cancellation propagating.
            pass
        # Yield to the event loop while waiting: a blocking time.sleep() would
        # stall the loop, so any callback work still pending on it could not
        # run during the wait.
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # Always restore the global callback list, even when the test fails,
        # so this handler does not leak into later tests.
        litellm.callbacks = []
# Text Completion
@ -766,6 +817,54 @@ async def test_async_embedding_azure_caching():
assert len(customHandler_caching.states) == 4 # pre, post, success, success
# asyncio.run(
# test_async_embedding_azure_caching()
# )
# Image Generation
## Test OpenAI + Sync
def test_image_generation_openai():
    """Check the custom handler's callbacks for a failing image generation call.

    Calls ``litellm.image_generation`` with an invalid API key and asserts that
    the handler recorded no assertion errors and exactly three states
    (pre, post, failure). The success-path check is currently commented out.
    Rate-limit and content-policy errors are tolerated rather than failing
    the test.
    """
    try:
        customHandler_failure = CompletionCustomHandler()
        # Success path (disabled):
        # customHandler_success = CompletionCustomHandler()
        # litellm.callbacks = [customHandler_success]
        # litellm.set_verbose = True
        # response = litellm.image_generation(
        #     prompt="A cute baby sea otter", model="dall-e-3"
        # )
        # print(f"response: {response}")
        # assert len(response.data) > 0
        # print(f"customHandler_success.errors: {customHandler_success.errors}")
        # print(f"customHandler_success.states: {customHandler_success.states}")
        # assert len(customHandler_success.errors) == 0
        # assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            litellm.image_generation(
                prompt="A cute baby sea otter",
                model="dall-e-2",
                api_key="my-bad-api-key",
            )
        except Exception:
            # The bad API key is expected to make the call fail; the test only
            # inspects what the handler recorded. `Exception` (not a bare
            # except) lets KeyboardInterrupt / SystemExit propagate.
            pass
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except (litellm.RateLimitError, litellm.ContentPolicyViolationError):
        pass  # OpenAI randomly raises these errors - skip when they occur
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")
    finally:
        # Restore the global callback list so this handler does not leak into
        # later tests (matches the cleanup done by the other tests here).
        litellm.callbacks = []
# NOTE(review): this module-level call runs the test (including its API
# request) every time the file is imported, e.g. during pytest collection.
# It looks like a debugging leftover — confirm, then remove it or guard it
# with `if __name__ == "__main__":`.
test_image_generation_openai()
## Test OpenAI + Async
## Test Azure + Sync
## Test Azure + Async