fix(main.py): pass user_id + encoding_format for logging + to openai/azure

Krrish Dholakia committed 2023-12-12 15:44:04 -08:00
commit 8b07a6c046 (parent 35fa176c97)
10 changed files with 82 additions and 26 deletions
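For context, a hedged usage sketch (not part of this diff) of the call pattern the fix targets: `user` and `encoding_format` are optional OpenAI embedding params that should now be forwarded to the provider and surfaced in the logging callbacks. Parameter names follow the OpenAI embeddings API; the response access at the end assumes litellm's EmbeddingResponse shape.

    import litellm

    # Sketch only: after this fix, `user` and `encoding_format` should reach
    # both the provider request payload and the logging callbacks.
    response = litellm.embedding(
        model="text-embedding-ada-002",
        input=["hello world"],
        user="user-1234",          # end-user identifier, useful for per-user logs
        encoding_format="float",   # "float" or "base64", per the OpenAI API
    )
    print(len(response.data[0]["embedding"]))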


@@ -26,9 +26,12 @@ class MyCustomHandler(CustomLogger):
         self.stream_collected_response = None  # type: ignore
         self.sync_stream_collected_response = None  # type: ignore
         self.user = None  # type: ignore
+        self.data_sent_to_api: dict = {}

     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
+        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         print(f"Post-API Call")
@@ -49,6 +52,7 @@ class MyCustomHandler(CustomLogger):
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async success")
+        print(f"received kwargs user: {kwargs['user']}")
         self.async_success = True
         if kwargs.get("model") == "text-embedding-ada-002":
             self.async_success_embedding = True
@@ -57,6 +61,7 @@ class MyCustomHandler(CustomLogger):
         if kwargs.get("stream") == True:
             self.stream_collected_response = response_obj
         self.async_completion_kwargs = kwargs
+        self.user = kwargs.get("user", None)

     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async Failure")
@@ -95,14 +100,6 @@ def test_async_chat_openai_stream():
                 print(complete_streaming_response)
         asyncio.run(call_gpt())
         complete_streaming_response = complete_streaming_response.strip("'")
-        print(f"complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content']}")
-        print(f"type of complete_streaming_response_in_callback: {type(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
-        print(f"hidden char complete_streaming_response_in_callback: {repr(tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'])}")
-        print(f"encoding complete_streaming_response_in_callback: {tmp_function.complete_streaming_response_in_callback['choices'][0]['message']['content'].encode('utf-8')}")
-        print(f"complete_streaming_response: {complete_streaming_response}")
-        print(f"type(complete_streaming_response): {type(complete_streaming_response)}")
-        print(f"hidden char complete_streaming_response): {repr(complete_streaming_response)}")
-        print(f"encoding complete_streaming_response): {repr(complete_streaming_response).encode('utf-8')}")
         response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
         response2 = complete_streaming_response
         assert [ord(c) for c in response1] == [ord(c) for c in response2]
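The code-point comparison that replaces the deleted debug prints keeps their diagnostic value: when the assertion fails, pytest prints the two lists of code points, which exposes invisible characters that look identical in normal output. A standalone illustration (plain stdlib, not from the diff):

    # Visually identical strings can differ by invisible characters; comparing
    # code-point lists makes the difference explicit in a failure message.
    s1 = "hello"
    s2 = "hello\u200b"  # trailing zero-width space
    print(s1, s2)                # look the same when printed
    print([ord(c) for c in s1])  # [104, 101, 108, 108, 111]
    print([ord(c) for c in s2])  # [104, 101, 108, 108, 111, 8203] — hidden char
    assert s1 != s2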
@@ -110,7 +107,7 @@ def test_async_chat_openai_stream():
     except Exception as e:
         print(e)
         pytest.fail(f"An error occurred - {str(e)}")
-test_async_chat_openai_stream()
+# test_async_chat_openai_stream()

 def test_completion_azure_stream_moderation_failure():
     try:
@@ -290,9 +287,29 @@ async def test_async_custom_handler_embedding():
         assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10  # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
     except Exception as e:
         pytest.fail(f"An exception occurred - {str(e)}")
-asyncio.run(test_async_custom_handler_embedding())
-from litellm import Cache
+# asyncio.run(test_async_custom_handler_embedding())
+@pytest.mark.asyncio
+async def test_async_custom_handler_embedding_optional_param():
+    """
+    Tests whether the OpenAI optional params for embedding - user + encoding_format -
+    are logged
+    """
+    customHandler_optional_params = MyCustomHandler()
+    litellm.callbacks = [customHandler_optional_params]
+    response = await litellm.aembedding(
+        model="text-embedding-ada-002",
+        input=["hello world"],
+        user="John",
+    )
+    await asyncio.sleep(1)  # success callback is async
+    assert customHandler_optional_params.user == "John"
+    assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
+
+# asyncio.run(test_async_custom_handler_embedding_optional_param())

 def test_redis_cache_completion_stream():
+    from litellm import Cache
     # Important Test - This tests if we can add to the streaming cache when custom callbacks are set
     import random
     try:
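One gap worth noting: the new test's docstring mentions both `user` and `encoding_format`, but only `user` is asserted. A hedged sketch of the analogous check inside the same async test (not in this diff; `encoding_format` values follow the OpenAI embeddings API):

    # Hypothetical follow-up assertion, placed in the test body above.
    response = await litellm.aembedding(
        model="text-embedding-ada-002",
        input=["hello world"],
        encoding_format="float",  # "float" or "base64"
    )
    await asyncio.sleep(1)  # success callback is async
    assert customHandler_optional_params.data_sent_to_api["encoding_format"] == "float"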
@@ -325,4 +342,4 @@ def test_redis_cache_completion_stream():
         print(e)
         litellm.success_callback = []
         raise e
-test_redis_cache_completion_stream()
+# test_redis_cache_completion_stream()