fix(utils.py): bug fixes
parent 252c8415c6
commit bfa26dd5b3
2 changed files with 88 additions and 9 deletions
@@ -545,8 +545,9 @@ async def test_async_chat_bedrock_stream():
# asyncio.run(test_async_chat_bedrock_stream())

# Text Completion

# Text Completion


## Test OpenAI text completion + Async
@pytest.mark.asyncio
async def test_async_text_completion_openai_stream():
@@ -585,6 +586,7 @@ async def test_async_text_completion_openai_stream():
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")


# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
@@ -758,6 +760,7 @@ async def test_async_embedding_azure_caching():
    )
    await asyncio.sleep(1)  # success callbacks are done in parallel
    print(customHandler_caching.states)
    print(customHandler_caching.errors)
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success
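Note: the assertion above expects four callback states for two identical embedding calls, which is consistent with the second call being served from the cache and firing only a success callback (no pre/post API call). Below is a toy, self-contained model of that behaviour; ToyHandler and ToyClient are illustrative stand-ins, not litellm classes.

import asyncio


class ToyHandler:
    """Stand-in for the custom callback handler used in the test."""

    def __init__(self):
        self.states, self.errors = [], []


class ToyClient:
    """Toy embedding client: a repeated input is served from cache."""

    def __init__(self, handler):
        self.handler = handler
        self.cache = {}

    async def embedding(self, text):
        if text in self.cache:
            # cache hit: only a success callback fires
            self.handler.states.append("async_success")
            return self.cache[text]
        self.handler.states.append("pre_api_call")
        await asyncio.sleep(0)  # stand-in for the real network call
        self.handler.states.append("post_api_call")
        self.cache[text] = [0.0, 0.1]
        self.handler.states.append("async_success")
        return self.cache[text]


async def main():
    handler = ToyHandler()
    client = ToyClient(handler)
    await client.embedding("hello world")
    await client.embedding("hello world")
    await asyncio.sleep(0.1)  # success callbacks may run in parallel
    assert len(handler.errors) == 0
    assert len(handler.states) == 4  # pre, post, success, success


asyncio.run(main())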
@@ -2026,12 +2026,17 @@ def client(original_function):
            )
            # if caching is false or cache["no-cache"]==True, don't run this
            if (
                (kwargs.get("caching", None) is None and litellm.cache is not None)
                or kwargs.get("caching", False) == True
                or (
                    kwargs.get("cache", None) is not None
                    and kwargs.get("cache", {}).get("no-cache", False) != True
                (
                    (kwargs.get("caching", None) is None and litellm.cache is not None)
                    or kwargs.get("caching", False) == True
                    or (
                        kwargs.get("cache", None) is not None
                        and kwargs.get("cache", {}).get("no-cache", False) != True
                    )
                )
                and kwargs.get("aembedding", False) != True
                and kwargs.get("acompletion", False) != True
                and kwargs.get("aimg_generation", False) != True
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")
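Note: this hunk wraps the existing cache checks into one group and adds aembedding / acompletion / aimg_generation guards, so the synchronous cache lookup is skipped when the call is routed through an async entry point. A minimal sketch of that guard logic follows; should_check_cache_sync is an illustrative name, not part of litellm.

def should_check_cache_sync(kwargs: dict, cache_configured: bool) -> bool:
    # caching is requested explicitly, or a global cache is configured,
    # or a per-call cache dict is passed without no-cache=True
    caching_requested = (
        (kwargs.get("caching") is None and cache_configured)
        or kwargs.get("caching") is True
        or (
            kwargs.get("cache") is not None
            and kwargs.get("cache", {}).get("no-cache", False) is not True
        )
    )
    # async entry points handle caching on their own path
    is_async_call = (
        kwargs.get("aembedding", False)
        or kwargs.get("acompletion", False)
        or kwargs.get("aimg_generation", False)
    )
    return caching_requested and not bool(is_async_call)


# e.g. an async embedding call skips the sync cache path entirely
assert should_check_cache_sync({"caching": True}, cache_configured=False) is True
assert should_check_cache_sync({"caching": True, "aembedding": True}, cache_configured=False) is False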
@@ -2329,13 +2334,78 @@ def client(original_function):
            if len(non_null_list) > 0:
                final_embedding_cached_response = EmbeddingResponse(
                    model=kwargs.get("model"), data=[]
                    model=kwargs.get("model"),
                    data=[None] * len(original_kwargs_input),
                )

                for val in non_null_list:
                    idx, cr = val  # (idx, cr) tuple
                    if cr is not None:
                        final_embedding_cached_response.data[idx] = val

            if len(remaining_list) == 0:
                # LOG SUCCESS
                cache_hit = True
                end_time = datetime.datetime.now()
                (
                    model,
                    custom_llm_provider,
                    dynamic_api_key,
                    api_base,
                ) = litellm.get_llm_provider(
                    model=model,
                    custom_llm_provider=kwargs.get(
                        "custom_llm_provider", None
                    ),
                    api_base=kwargs.get("api_base", None),
                    api_key=kwargs.get("api_key", None),
                )
                print_verbose(
                    f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
                )
                logging_obj.update_environment_variables(
                    model=model,
                    user=kwargs.get("user", None),
                    optional_params={},
                    litellm_params={
                        "logger_fn": kwargs.get("logger_fn", None),
                        "acompletion": True,
                        "metadata": kwargs.get("metadata", {}),
                        "model_info": kwargs.get("model_info", {}),
                        "proxy_server_request": kwargs.get(
                            "proxy_server_request", None
                        ),
                        "preset_cache_key": kwargs.get(
                            "preset_cache_key", None
                        ),
                        "stream_response": kwargs.get(
                            "stream_response", {}
                        ),
                    },
                    input=kwargs.get("messages", ""),
                    api_key=kwargs.get("api_key", None),
                    original_response=str(final_embedding_cached_response),
                    additional_args=None,
                    stream=kwargs.get("stream", False),
                )
                asyncio.create_task(
                    logging_obj.async_success_handler(
                        final_embedding_cached_response,
                        start_time,
                        end_time,
                        cache_hit,
                    )
                )
                threading.Thread(
                    target=logging_obj.success_handler,
                    args=(
                        final_embedding_cached_response,
                        start_time,
                        end_time,
                        cache_hit,
                    ),
                ).start()
                return final_embedding_cached_response
        # MODEL CALL
        result = await original_function(*args, **kwargs)
        end_time = datetime.datetime.now()
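Note: the new block pre-sizes the EmbeddingResponse data list to the original input length and writes each cached result back at its original index, so ordering is preserved when only part of the batch is cached; if nothing remains to fetch, it logs success and returns early. A small self-contained sketch of that index-preserving merge follows (merge_cached_items is an illustrative helper, not litellm code; it stores the cached item itself at each position, and positions left as None correspond to the remaining, uncached inputs).

from typing import Any, List, Optional, Tuple


def merge_cached_items(
    inputs: List[str], cached: List[Tuple[int, Optional[Any]]]
) -> List[Optional[Any]]:
    # pre-size the output so every cached item lands at its original position
    data: List[Optional[Any]] = [None] * len(inputs)
    for idx, item in cached:
        if item is not None:
            data[idx] = item
    return data


# items 0 and 2 were cached; item 1 still has to be fetched from the model
print(merge_cached_items(["a", "b", "c"], [(0, [0.1]), (2, [0.3])]))
# -> [[0.1], None, [0.3]]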
@@ -2371,6 +2441,7 @@ def client(original_function):
            embedding_kwargs = copy.deepcopy(kwargs)
            for idx, i in enumerate(kwargs["input"]):
                embedding_response = result.data[idx]
                embedding_kwargs["input"] = i
                asyncio.create_task(
                    litellm.cache._async_add_cache(
                        embedding_response, *args, **embedding_kwargs
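Note: after the model call, each returned embedding is cached under its own input string by rewriting the input kwarg per item, so a later batch can be served partially from cache. A rough, simplified sketch of that per-item write pattern follows; cache.async_add_cache is a hypothetical stand-in for the internal cache method, and the sketch copies kwargs per iteration for clarity.

import asyncio
import copy


async def cache_each_embedding(cache, result_data, kwargs) -> None:
    # one cache entry per input string in the original batch
    for idx, single_input in enumerate(kwargs["input"]):
        item_kwargs = copy.deepcopy(kwargs)
        item_kwargs["input"] = single_input
        # fire-and-forget: writing to the cache should not delay the response
        asyncio.create_task(
            cache.async_add_cache(result_data[idx], **item_kwargs)
        )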
@@ -5971,7 +6042,12 @@ def exception_type(
                message=f"BedrockException - {original_exception.message}",
                llm_provider="bedrock",
                model=model,
                response=original_exception.response,
                response=httpx.Response(
                    status_code=500,
                    request=httpx.Request(
                        method="POST", url="https://api.openai.com/v1/"
                    ),
                ),
            )
        elif original_exception.status_code == 401:
            exception_mapping_worked = True
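Note: the Bedrock branch now attaches a synthetic httpx.Response instead of passing original_exception.response through, presumably because the SDK exception does not always carry an HTTP response object. A minimal sketch showing that httpx supports building such a placeholder directly; the URL here is only a dummy value.

import httpx

# construct a standalone response/request pair, as in the mapping above
placeholder = httpx.Response(
    status_code=500,
    request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
)
print(placeholder.status_code, placeholder.request.method)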