fix(utils.py): bug fixes

Krrish Dholakia 2024-01-11 18:14:22 +05:30
parent 252c8415c6
commit bfa26dd5b3
2 changed files with 88 additions and 9 deletions


@@ -547,6 +547,7 @@ async def test_async_chat_bedrock_stream():
 # Text Completion
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
@@ -585,6 +586,7 @@ async def test_async_text_completion_openai_stream():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 # EMBEDDING
 ## Test OpenAI + Async
 @pytest.mark.asyncio
@@ -758,6 +760,7 @@ async def test_async_embedding_azure_caching():
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
     print(customHandler_caching.states)
+    print(customHandler_caching.errors)
     assert len(customHandler_caching.errors) == 0
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
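
The customHandler_caching object asserted on above is a callback handler defined in the test suite. A purely illustrative sketch of the shape those assertions rely on (a states list recording each callback phase, an errors list that should stay empty) follows; the class and method names here are assumptions, not the actual fixture:

class TrackingHandler:
    """Hypothetical stand-in for the test's callback handler."""

    def __init__(self):
        self.states = []  # one entry per callback phase that fired
        self.errors = []  # populated only if a callback itself fails

    def log_event(self, phase):
        self.states.append(phase)


handler = TrackingHandler()
# One live embedding call (pre, post, success) plus one cache hit (success) = 4 states.
for phase in ["async_pre_api_call", "async_post_api_call", "async_success", "async_success"]:
    handler.log_event(phase)
assert len(handler.errors) == 0
assert len(handler.states) == 4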


@@ -2026,12 +2026,17 @@ def client(original_function):
             )
             # if caching is false or cache["no-cache"]==True, don't run this
             if (
-                (kwargs.get("caching", None) is None and litellm.cache is not None)
-                or kwargs.get("caching", False) == True
-                or (
-                    kwargs.get("cache", None) is not None
-                    and kwargs.get("cache", {}).get("no-cache", False) != True
-                )
+                (
+                    (kwargs.get("caching", None) is None and litellm.cache is not None)
+                    or kwargs.get("caching", False) == True
+                    or (
+                        kwargs.get("cache", None) is not None
+                        and kwargs.get("cache", {}).get("no-cache", False) != True
+                    )
+                )
+                and kwargs.get("aembedding", False) != True
+                and kwargs.get("acompletion", False) != True
+                and kwargs.get("aimg_generation", False) != True
             ):  # allow users to control returning cached responses from the completion function
                 # checking cache
                 print_verbose(f"INSIDE CHECKING CACHE")
@@ -2329,13 +2334,78 @@ def client(original_function):
                     if len(non_null_list) > 0:
                         final_embedding_cached_response = EmbeddingResponse(
-                            model=kwargs.get("model"), data=[]
+                            model=kwargs.get("model"),
+                            data=[None] * len(original_kwargs_input),
                         )
                         for val in non_null_list:
                             idx, cr = val  # (idx, cr) tuple
                             if cr is not None:
                                 final_embedding_cached_response.data[idx] = val
+                    if len(remaining_list) == 0:
+                        # LOG SUCCESS
+                        cache_hit = True
+                        end_time = datetime.datetime.now()
+                        (
+                            model,
+                            custom_llm_provider,
+                            dynamic_api_key,
+                            api_base,
+                        ) = litellm.get_llm_provider(
+                            model=model,
+                            custom_llm_provider=kwargs.get(
+                                "custom_llm_provider", None
+                            ),
+                            api_base=kwargs.get("api_base", None),
+                            api_key=kwargs.get("api_key", None),
+                        )
+                        print_verbose(
+                            f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                        )
+                        logging_obj.update_environment_variables(
+                            model=model,
+                            user=kwargs.get("user", None),
+                            optional_params={},
+                            litellm_params={
+                                "logger_fn": kwargs.get("logger_fn", None),
+                                "acompletion": True,
+                                "metadata": kwargs.get("metadata", {}),
+                                "model_info": kwargs.get("model_info", {}),
+                                "proxy_server_request": kwargs.get(
+                                    "proxy_server_request", None
+                                ),
+                                "preset_cache_key": kwargs.get(
+                                    "preset_cache_key", None
+                                ),
+                                "stream_response": kwargs.get(
+                                    "stream_response", {}
+                                ),
+                            },
+                            input=kwargs.get("messages", ""),
+                            api_key=kwargs.get("api_key", None),
+                            original_response=str(final_embedding_cached_response),
+                            additional_args=None,
+                            stream=kwargs.get("stream", False),
+                        )
+                        asyncio.create_task(
+                            logging_obj.async_success_handler(
+                                final_embedding_cached_response,
+                                start_time,
+                                end_time,
+                                cache_hit,
+                            )
+                        )
+                        threading.Thread(
+                            target=logging_obj.success_handler,
+                            args=(
+                                final_embedding_cached_response,
+                                start_time,
+                                end_time,
+                                cache_hit,
+                            ),
+                        ).start()
+                        return final_embedding_cached_response
                 # MODEL CALL
                 result = await original_function(*args, **kwargs)
                 end_time = datetime.datetime.now()
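
The block added above handles cache hits for a batched embedding call: the response's data list is pre-sized to one slot per input (data=[None] * len(original_kwargs_input)), each cached result is written back at its original index, and when remaining_list is empty (every input was served from cache) the wrapper logs the call as a success with cache_hit=True and returns without hitting the model. A small, hypothetical sketch of that merge-by-index idea (function and variable names are mine, not litellm's):

from typing import Any, List, Optional, Tuple


def merge_cached_embeddings(
    inputs: List[str], cache_hits: List[Tuple[int, Optional[Any]]]
) -> Tuple[List[Optional[Any]], List[int]]:
    # Pre-size the output so each cached item can land in its original slot.
    data: List[Optional[Any]] = [None] * len(inputs)
    for idx, cached in cache_hits:
        if cached is not None:
            data[idx] = cached
    # Indices still set to None need a real embedding call.
    remaining = [i for i, item in enumerate(data) if item is None]
    return data, remaining


data, remaining = merge_cached_embeddings(
    ["a", "b", "c"], [(0, {"embedding": [0.1]}), (2, {"embedding": [0.3]})]
)
assert remaining == [1]  # only "b" still has to be embedded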
@@ -2371,6 +2441,7 @@ def client(original_function):
                     embedding_kwargs = copy.deepcopy(kwargs)
                     for idx, i in enumerate(kwargs["input"]):
                         embedding_response = result.data[idx]
+                        embedding_kwargs["input"] = i
                         asyncio.create_task(
                             litellm.cache._async_add_cache(
                                 embedding_response, *args, **embedding_kwargs
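
The single added line in this hunk (embedding_kwargs["input"] = i) caches each embedding result under its own input rather than under the whole batch, so later calls can be answered from cache item by item. A hedged sketch of that per-item write pattern (cache_set is a stand-in for litellm.cache._async_add_cache, not the real signature):

import asyncio
import copy


async def cache_each_embedding(result_data, kwargs, cache_set):
    # Write one cache entry per input so partial batches can hit the cache later.
    tasks = []
    for idx, single_input in enumerate(kwargs["input"]):
        item_kwargs = copy.deepcopy(kwargs)
        item_kwargs["input"] = single_input  # key the entry on this input, not the batch
        tasks.append(asyncio.create_task(cache_set(result_data[idx], **item_kwargs)))
    await asyncio.gather(*tasks)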
@@ -5971,7 +6042,12 @@ def exception_type(
                     message=f"BedrockException - {original_exception.message}",
                     llm_provider="bedrock",
                     model=model,
-                    response=original_exception.response,
+                    response=httpx.Response(
+                        status_code=500,
+                        request=httpx.Request(
+                            method="POST", url="https://api.openai.com/v1/"
+                        ),
+                    ),
                 )
             elif original_exception.status_code == 401:
                 exception_mapping_worked = True
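
The change at -5971 stops forwarding original_exception.response, which a Bedrock error may not carry in a usable form, and instead attaches a synthesized httpx.Response so downstream exception handling always has a response object to work with. A minimal sketch of building such a placeholder response, mirroring the values in the hunk above:

import httpx


def synthetic_error_response(status_code: int = 500) -> httpx.Response:
    # Placeholder response for exceptions that do not carry one themselves.
    request = httpx.Request(method="POST", url="https://api.openai.com/v1/")
    return httpx.Response(status_code=status_code, request=request)


resp = synthetic_error_response()
assert resp.status_code == 500 and resp.request.method == "POST"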