forked from phoenix/litellm-mirror
fix(utils.py): bug fixes
This commit is contained in:
parent 252c8415c6
commit bfa26dd5b3
2 changed files with 88 additions and 9 deletions
@@ -547,6 +547,7 @@ async def test_async_chat_bedrock_stream():
 
 # Text Completion
 
+
 ## Test OpenAI text completion + Async
 @pytest.mark.asyncio
 async def test_async_text_completion_openai_stream():
@@ -585,6 +586,7 @@ async def test_async_text_completion_openai_stream():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 
+
 # EMBEDDING
 ## Test OpenAI + Async
 @pytest.mark.asyncio
@@ -758,6 +760,7 @@ async def test_async_embedding_azure_caching():
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
     print(customHandler_caching.states)
+    print(customHandler_caching.errors)
     assert len(customHandler_caching.errors) == 0
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
 
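For readers skimming the test change: `customHandler_caching` is the callback handler the Azure embedding caching test asserts on, and the new `print(customHandler_caching.errors)` line surfaces any callback failures before the assertions run. Below is a toy, self-contained sketch of the bookkeeping those assertions rely on; the class and method names are invented for illustration and are not litellm's actual test fixture.

```python
# Toy illustration of why the caching test expects 4 states and 0 errors:
# two identical embedding calls are made, the second is served from cache,
# and both still fire a success callback -> pre, post, success, success.
class ToyCachingHandler:
    def __init__(self):
        self.states: list = []
        self.errors: list = []

    def on_event(self, state: str) -> None:
        try:
            self.states.append(state)
        except Exception as exc:  # callback failures are collected, not raised
            self.errors.append(str(exc))


handler = ToyCachingHandler()
for call in range(2):
    if call == 0:  # only the first call actually reaches the API
        handler.on_event("async_pre_api_call")
        handler.on_event("async_post_api_call")
    handler.on_event("async_success")  # the cached second call still reports success

print(handler.errors)
assert len(handler.errors) == 0
assert len(handler.states) == 4  # pre, post, success, success
```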
@@ -2026,12 +2026,17 @@ def client(original_function):
            )
            # if caching is false or cache["no-cache"]==True, don't run this
            if (
+                (
                    (kwargs.get("caching", None) is None and litellm.cache is not None)
                    or kwargs.get("caching", False) == True
                    or (
                        kwargs.get("cache", None) is not None
                        and kwargs.get("cache", {}).get("no-cache", False) != True
                    )
+                )
+                and kwargs.get("aembedding", False) != True
+                and kwargs.get("acompletion", False) != True
+                and kwargs.get("aimg_generation", False) != True
            ):  # allow users to control returning cached responses from the completion function
                # checking cache
                print_verbose(f"INSIDE CHECKING CACHE")
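The restructuring above is about operator precedence as much as the new flags: without the added outer parentheses, the trailing `and kwargs.get("aembedding", ...) != True` guards would bind only to the last `or` branch, so async calls could still drop into the synchronous cache-check path. Here is a standalone sketch of the intended gate, using a plain dict for `kwargs` and a boolean in place of the real `litellm.cache` check (both stand-ins are assumptions, not litellm's internals).

```python
# Sketch of the cache-check gate after the fix: the whole "should we read the
# cache?" expression is grouped first, then async entry points are excluded so
# the async wrapper can handle caching itself.
def should_check_cache_sync(kwargs: dict, cache_configured: bool) -> bool:
    wants_cache = (
        (kwargs.get("caching", None) is None and cache_configured)
        or kwargs.get("caching", False) is True
        or (
            kwargs.get("cache", None) is not None
            and kwargs.get("cache", {}).get("no-cache", False) is not True
        )
    )
    is_async_call = (
        kwargs.get("aembedding", False)
        or kwargs.get("acompletion", False)
        or kwargs.get("aimg_generation", False)
    )
    return wants_cache and not is_async_call


# A sync call with a global cache configured -> check the cache.
assert should_check_cache_sync({}, cache_configured=True) is True
# The same call routed through the async embedding path -> skip the sync check.
assert should_check_cache_sync({"aembedding": True}, cache_configured=True) is False
```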
@@ -2329,13 +2334,78 @@ def client(original_function):
 
                if len(non_null_list) > 0:
                    final_embedding_cached_response = EmbeddingResponse(
-                        model=kwargs.get("model"), data=[]
+                        model=kwargs.get("model"),
+                        data=[None] * len(original_kwargs_input),
                    )
 
                    for val in non_null_list:
                        idx, cr = val  # (idx, cr) tuple
                        if cr is not None:
                            final_embedding_cached_response.data[idx] = val
+
+                if len(remaining_list) == 0:
+                    # LOG SUCCESS
+                    cache_hit = True
+                    end_time = datetime.datetime.now()
+                    (
+                        model,
+                        custom_llm_provider,
+                        dynamic_api_key,
+                        api_base,
+                    ) = litellm.get_llm_provider(
+                        model=model,
+                        custom_llm_provider=kwargs.get(
+                            "custom_llm_provider", None
+                        ),
+                        api_base=kwargs.get("api_base", None),
+                        api_key=kwargs.get("api_key", None),
+                    )
+                    print_verbose(
+                        f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}"
+                    )
+                    logging_obj.update_environment_variables(
+                        model=model,
+                        user=kwargs.get("user", None),
+                        optional_params={},
+                        litellm_params={
+                            "logger_fn": kwargs.get("logger_fn", None),
+                            "acompletion": True,
+                            "metadata": kwargs.get("metadata", {}),
+                            "model_info": kwargs.get("model_info", {}),
+                            "proxy_server_request": kwargs.get(
+                                "proxy_server_request", None
+                            ),
+                            "preset_cache_key": kwargs.get(
+                                "preset_cache_key", None
+                            ),
+                            "stream_response": kwargs.get(
+                                "stream_response", {}
+                            ),
+                        },
+                        input=kwargs.get("messages", ""),
+                        api_key=kwargs.get("api_key", None),
+                        original_response=str(final_embedding_cached_response),
+                        additional_args=None,
+                        stream=kwargs.get("stream", False),
+                    )
+                    asyncio.create_task(
+                        logging_obj.async_success_handler(
+                            final_embedding_cached_response,
+                            start_time,
+                            end_time,
+                            cache_hit,
+                        )
+                    )
+                    threading.Thread(
+                        target=logging_obj.success_handler,
+                        args=(
+                            final_embedding_cached_response,
+                            start_time,
+                            end_time,
+                            cache_hit,
+                        ),
+                    ).start()
+                    return final_embedding_cached_response
            # MODEL CALL
            result = await original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
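The new block stitches per-item cache hits back into one response: `data` is preallocated to the input length, cached embeddings are written back by index, and when `remaining_list` is empty the wrapper logs the call as a cache hit and returns `final_embedding_cached_response` without ever calling the model. Below is a self-contained sketch of that assembly-and-early-return step, with a simplified stand-in for litellm's `EmbeddingResponse` and the logging machinery reduced to a print.

```python
import datetime
from dataclasses import dataclass, field


@dataclass
class FakeEmbeddingResponse:
    # Stand-in for litellm's EmbeddingResponse; only the fields used here.
    model: str
    data: list = field(default_factory=list)


def assemble_from_cache(model: str, inputs: list, cached: list):
    """cached is a list of (idx, embedding) tuples for inputs found in cache."""
    response = FakeEmbeddingResponse(model=model, data=[None] * len(inputs))
    remaining = set(range(len(inputs)))
    for idx, embedding in cached:
        if embedding is not None:
            response.data[idx] = embedding
            remaining.discard(idx)
    if not remaining:
        # Everything was served from cache: record the hit and return early,
        # mirroring the `return final_embedding_cached_response` above.
        cache_hit = True
        end_time = datetime.datetime.now()
        print(f"cache_hit={cache_hit} at {end_time}")
        return response, []
    return response, sorted(remaining)


resp, remaining = assemble_from_cache(
    "text-embedding-ada-002", ["a", "b"], [(0, [0.1]), (1, [0.2])]
)
assert remaining == [] and resp.data == [[0.1], [0.2]]
```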
@@ -2371,6 +2441,7 @@ def client(original_function):
                    embedding_kwargs = copy.deepcopy(kwargs)
                    for idx, i in enumerate(kwargs["input"]):
                        embedding_response = result.data[idx]
+                        embedding_kwargs["input"] = i
                        asyncio.create_task(
                            litellm.cache._async_add_cache(
                                embedding_response, *args, **embedding_kwargs
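The one-line fix here, `embedding_kwargs["input"] = i`, makes each embedding get cached under its own single input rather than under the whole batch, so a later request containing any subset of those inputs can be served item-by-item from cache. A rough sketch of that write path follows; `async_add_cache` and the module-level `CACHE` dict are invented stand-ins for litellm's internal `litellm.cache._async_add_cache`.

```python
import asyncio
import copy

CACHE: dict = {}  # invented stand-in for litellm's cache backend


async def async_add_cache(value, **kwargs):
    # Stand-in for litellm.cache._async_add_cache: key each embedding by its
    # single input string so per-item lookups can hit later.
    CACHE[("embedding", kwargs["model"], kwargs["input"])] = value


async def cache_each_item(kwargs: dict, result_data: list) -> None:
    # Mirrors the loop in the diff: deep-copy kwargs once, then overwrite
    # "input" with the single item before scheduling each cache write.
    embedding_kwargs = copy.deepcopy(kwargs)
    tasks = []
    for idx, single_input in enumerate(kwargs["input"]):
        embedding_kwargs["input"] = single_input
        # **embedding_kwargs is unpacked now, so each task sees its own input.
        tasks.append(
            asyncio.create_task(async_add_cache(result_data[idx], **embedding_kwargs))
        )
    await asyncio.gather(*tasks)


asyncio.run(cache_each_item({"model": "ada", "input": ["a", "b"]}, [[0.1], [0.2]]))
assert ("embedding", "ada", "a") in CACHE and ("embedding", "ada", "b") in CACHE
```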
@@ -5971,7 +6042,12 @@ def exception_type(
                        message=f"BedrockException - {original_exception.message}",
                        llm_provider="bedrock",
                        model=model,
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=500,
+                            request=httpx.Request(
+                                method="POST", url="https://api.openai.com/v1/"
+                            ),
+                        ),
                    )
                elif original_exception.status_code == 401:
                    exception_mapping_worked = True
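The last change swaps `original_exception.response` for a synthetic `httpx.Response`, which protects against Bedrock errors that carry no HTTP response object: attaching a placeholder response with a dummy request keeps downstream code that reads `.response` or `.response.request` from raising a secondary error. Here is a small sketch of the same idea, assuming an illustrative exception class rather than litellm's actual exception types.

```python
import httpx


class ProviderError(Exception):
    # Illustrative stand-in for a litellm-style exception that expects an httpx.Response.
    def __init__(self, message: str, llm_provider: str, model: str, response: httpx.Response):
        super().__init__(message)
        self.llm_provider = llm_provider
        self.model = model
        self.response = response


def wrap_bedrock_error(original_message: str, model: str) -> ProviderError:
    # The original exception may have no usable .response, so build a
    # placeholder httpx.Response with an attached request to keep downstream
    # code that inspects exc.response / exc.response.request working.
    placeholder = httpx.Response(
        status_code=500,
        request=httpx.Request(method="POST", url="https://api.openai.com/v1/"),
    )
    return ProviderError(
        message=f"BedrockException - {original_message}",
        llm_provider="bedrock",
        model=model,
        response=placeholder,
    )


err = wrap_bedrock_error("throttled", model="bedrock/anthropic.claude-v2")
assert err.response.status_code == 500 and err.response.request.method == "POST"
```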