fix(utils.py): bug fixes

This commit is contained in:
Krrish Dholakia 2024-01-11 18:14:22 +05:30
parent 252c8415c6
commit bfa26dd5b3
2 changed files with 88 additions and 9 deletions

View file

@ -545,8 +545,9 @@ async def test_async_chat_bedrock_stream():
# asyncio.run(test_async_chat_bedrock_stream())
# Text Completion
# Text Completion
## Test OpenAI text completion + Async
@pytest.mark.asyncio
async def test_async_text_completion_openai_stream():
@ -585,6 +586,7 @@ async def test_async_text_completion_openai_stream():
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
@ -758,6 +760,7 @@ async def test_async_embedding_azure_caching():
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(customHandler_caching.states)
print(customHandler_caching.errors)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success

View file

@ -2026,12 +2026,17 @@ def client(original_function):
)
# if caching is false or cache["no-cache"]==True, don't run this
if (
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) == True
or (
kwargs.get("cache", None) is not None
and kwargs.get("cache", {}).get("no-cache", False) != True
(
(kwargs.get("caching", None) is None and litellm.cache is not None)
or kwargs.get("caching", False) == True
or (
kwargs.get("cache", None) is not None
and kwargs.get("cache", {}).get("no-cache", False) != True
)
)
and kwargs.get("aembedding", False) != True
and kwargs.get("acompletion", False) != True
and kwargs.get("aimg_generation", False) != True
): # allow users to control returning cached responses from the completion function
# checking cache
print_verbose(f"INSIDE CHECKING CACHE")
@ -2329,13 +2334,78 @@ def client(original_function):
if len(non_null_list) > 0:
final_embedding_cached_response = EmbeddingResponse(
model=kwargs.get("model"), data=[]
model=kwargs.get("model"),
data=[None] * len(original_kwargs_input),
)
for val in non_null_list:
idx, cr = val # (idx, cr) tuple
if cr is not None:
final_embedding_cached_response.data[idx] = val
if len(remaining_list) == 0:
# LOG SUCCESS
cache_hit = True
end_time = datetime.datetime.now()
(
model,
custom_llm_provider,
dynamic_api_key,
api_base,
) = litellm.get_llm_provider(
model=model,
custom_llm_provider=kwargs.get(
"custom_llm_provider", None
),
api_base=kwargs.get("api_base", None),
api_key=kwargs.get("api_key", None),
)
print_verbose(
f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
)
logging_obj.update_environment_variables(
model=model,
user=kwargs.get("user", None),
optional_params={},
litellm_params={
"logger_fn": kwargs.get("logger_fn", None),
"acompletion": True,
"metadata": kwargs.get("metadata", {}),
"model_info": kwargs.get("model_info", {}),
"proxy_server_request": kwargs.get(
"proxy_server_request", None
),
"preset_cache_key": kwargs.get(
"preset_cache_key", None
),
"stream_response": kwargs.get(
"stream_response", {}
),
},
input=kwargs.get("messages", ""),
api_key=kwargs.get("api_key", None),
original_response=str(final_embedding_cached_response),
additional_args=None,
stream=kwargs.get("stream", False),
)
asyncio.create_task(
logging_obj.async_success_handler(
final_embedding_cached_response,
start_time,
end_time,
cache_hit,
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(
final_embedding_cached_response,
start_time,
end_time,
cache_hit,
),
).start()
return final_embedding_cached_response
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
@ -2371,6 +2441,7 @@ def client(original_function):
embedding_kwargs = copy.deepcopy(kwargs)
for idx, i in enumerate(kwargs["input"]):
embedding_response = result.data[idx]
embedding_kwargs["input"] = i
asyncio.create_task(
litellm.cache._async_add_cache(
embedding_response, *args, **embedding_kwargs
@ -5971,7 +6042,12 @@ def exception_type(
message=f"BedrockException - {original_exception.message}",
llm_provider="bedrock",
model=model,
response=original_exception.response,
response=httpx.Response(
status_code=500,
request=httpx.Request(
method="POST", url="https://api.openai.com/v1/"
),
),
)
elif original_exception.status_code == 401:
exception_mapping_worked = True