forked from phoenix/litellm-mirror
Merge pull request #2124 from BerriAI/litellm_streaming_caching_logging
fix(utils.py): support streaming cached response logging
This commit is contained in:
commit
bfdacf6d6b
5 changed files with 140 additions and 38 deletions
|
@ -124,7 +124,9 @@ class RedisCache(BaseCache):
|
|||
self.redis_client.set(name=key, value=str(value), ex=ttl)
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Redis is throwing an exception
|
||||
print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
|
||||
print_verbose(
|
||||
f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
|
||||
)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
_redis_client = self.init_async_client()
|
||||
|
|
|
@ -1986,6 +1986,8 @@ def test_completion_gemini():
|
|||
response = completion(model=model_name, messages=messages)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -2015,6 +2017,8 @@ def test_completion_palm():
|
|||
response = completion(model=model_name, messages=messages)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -2037,6 +2041,8 @@ def test_completion_palm_stream():
|
|||
# Add any assertions here to check the response
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
## This test asserts the type of data passed into each method of the custom callback handler
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
import pytest, uuid
|
||||
from pydantic import BaseModel
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../.."))
|
||||
|
@ -795,6 +795,53 @@ async def test_async_completion_azure_caching():
|
|||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_completion_azure_caching_streaming():
|
||||
import copy
|
||||
|
||||
litellm.set_verbose = True
|
||||
customHandler_caching = CompletionCustomHandler()
|
||||
litellm.cache = Cache(
|
||||
type="redis",
|
||||
host=os.environ["REDIS_HOST"],
|
||||
port=os.environ["REDIS_PORT"],
|
||||
password=os.environ["REDIS_PASSWORD"],
|
||||
)
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
unique_time = uuid.uuid4()
|
||||
response1 = await litellm.acompletion(
|
||||
model="azure/chatgpt-v-2",
|
||||
messages=[
|
||||
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
|
||||
],
|
||||
caching=True,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response1:
|
||||
print(f"chunk in response1: {chunk}")
|
||||
await asyncio.sleep(1)
|
||||
initial_customhandler_caching_states = len(customHandler_caching.states)
|
||||
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
|
||||
response2 = await litellm.acompletion(
|
||||
model="azure/chatgpt-v-2",
|
||||
messages=[
|
||||
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
|
||||
],
|
||||
caching=True,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response2:
|
||||
print(f"chunk in response2: {chunk}")
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(
|
||||
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
|
||||
)
|
||||
assert len(customHandler_caching.errors) == 0
|
||||
assert (
|
||||
len(customHandler_caching.states) > initial_customhandler_caching_states
|
||||
) # pre, post, streaming .., success, success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure_caching():
|
||||
print("Testing custom callback input - Azure Caching")
|
||||
|
|
|
@ -392,6 +392,8 @@ def test_completion_palm_stream():
|
|||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}")
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -425,6 +427,8 @@ def test_completion_gemini_stream():
|
|||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
print(f"completion_response: {complete_response}")
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
@ -461,6 +465,8 @@ async def test_acompletion_gemini_stream():
|
|||
print(f"completion_response: {complete_response}")
|
||||
if complete_response.strip() == "":
|
||||
raise Exception("Empty response received")
|
||||
except litellm.APIError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
|
113
litellm/utils.py
113
litellm/utils.py
|
@ -1411,7 +1411,7 @@ class Logging:
|
|||
print_verbose(
|
||||
f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
|
||||
)
|
||||
return
|
||||
pass
|
||||
else:
|
||||
print_verbose(
|
||||
"success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
|
||||
|
@ -1616,7 +1616,7 @@ class Logging:
|
|||
print_verbose(
|
||||
f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
|
||||
)
|
||||
return
|
||||
pass
|
||||
else:
|
||||
print_verbose(
|
||||
"async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
|
||||
|
@ -1625,8 +1625,10 @@ class Logging:
|
|||
# only add to cache once we have a complete streaming response
|
||||
litellm.cache.add_cache(result, **kwargs)
|
||||
if isinstance(callback, CustomLogger): # custom logger class
|
||||
print_verbose(f"Async success callbacks: {callback}")
|
||||
if self.stream:
|
||||
print_verbose(
|
||||
f"Async success callbacks: {callback}; self.stream: {self.stream}; complete_streaming_response: {self.model_call_details.get('complete_streaming_response', None)}"
|
||||
)
|
||||
if self.stream == True:
|
||||
if "complete_streaming_response" in self.model_call_details:
|
||||
await callback.async_log_success_event(
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -2328,6 +2330,13 @@ def client(original_function):
|
|||
model_response_object=ModelResponse(),
|
||||
stream=kwargs.get("stream", False),
|
||||
)
|
||||
if kwargs.get("stream", False) == True:
|
||||
cached_result = CustomStreamWrapper(
|
||||
completion_stream=cached_result,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
elif call_type == CallTypes.embedding.value and isinstance(
|
||||
cached_result, dict
|
||||
):
|
||||
|
@ -2624,28 +2633,6 @@ def client(original_function):
|
|||
cached_result, list
|
||||
):
|
||||
print_verbose(f"Cache Hit!")
|
||||
call_type = original_function.__name__
|
||||
if call_type == CallTypes.acompletion.value and isinstance(
|
||||
cached_result, dict
|
||||
):
|
||||
if kwargs.get("stream", False) == True:
|
||||
cached_result = convert_to_streaming_response_async(
|
||||
response_object=cached_result,
|
||||
)
|
||||
else:
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=ModelResponse(),
|
||||
)
|
||||
elif call_type == CallTypes.aembedding.value and isinstance(
|
||||
cached_result, dict
|
||||
):
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=EmbeddingResponse(),
|
||||
response_type="embedding",
|
||||
)
|
||||
# LOG SUCCESS
|
||||
cache_hit = True
|
||||
end_time = datetime.datetime.now()
|
||||
(
|
||||
|
@ -2685,15 +2672,44 @@ def client(original_function):
|
|||
additional_args=None,
|
||||
stream=kwargs.get("stream", False),
|
||||
)
|
||||
asyncio.create_task(
|
||||
logging_obj.async_success_handler(
|
||||
cached_result, start_time, end_time, cache_hit
|
||||
call_type = original_function.__name__
|
||||
if call_type == CallTypes.acompletion.value and isinstance(
|
||||
cached_result, dict
|
||||
):
|
||||
if kwargs.get("stream", False) == True:
|
||||
cached_result = convert_to_streaming_response_async(
|
||||
response_object=cached_result,
|
||||
)
|
||||
cached_result = CustomStreamWrapper(
|
||||
completion_stream=cached_result,
|
||||
model=model,
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
else:
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=ModelResponse(),
|
||||
)
|
||||
elif call_type == CallTypes.aembedding.value and isinstance(
|
||||
cached_result, dict
|
||||
):
|
||||
cached_result = convert_to_model_response_object(
|
||||
response_object=cached_result,
|
||||
model_response_object=EmbeddingResponse(),
|
||||
response_type="embedding",
|
||||
)
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(cached_result, start_time, end_time, cache_hit),
|
||||
).start()
|
||||
if kwargs.get("stream", False) == False:
|
||||
# LOG SUCCESS
|
||||
asyncio.create_task(
|
||||
logging_obj.async_success_handler(
|
||||
cached_result, start_time, end_time, cache_hit
|
||||
)
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(cached_result, start_time, end_time, cache_hit),
|
||||
).start()
|
||||
return cached_result
|
||||
elif (
|
||||
call_type == CallTypes.aembedding.value
|
||||
|
@ -4296,7 +4312,9 @@ def get_optional_params(
|
|||
parameters=tool["function"].get("parameters", {}),
|
||||
)
|
||||
gtool_func_declarations.append(gtool_func_declaration)
|
||||
optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
|
||||
optional_params["tools"] = [
|
||||
generative_models.Tool(function_declarations=gtool_func_declarations)
|
||||
]
|
||||
elif custom_llm_provider == "sagemaker":
|
||||
## check if unsupported param passed in
|
||||
supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
|
||||
|
@ -6795,7 +6813,7 @@ def exception_type(
|
|||
llm_provider="vertex_ai",
|
||||
request=original_exception.request,
|
||||
)
|
||||
elif custom_llm_provider == "palm":
|
||||
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
|
||||
if "503 Getting metadata" in error_str:
|
||||
# auth errors look like this
|
||||
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
|
||||
|
@ -6814,6 +6832,15 @@ def exception_type(
|
|||
llm_provider="palm",
|
||||
response=original_exception.response,
|
||||
)
|
||||
if "500 An internal error has occurred." in error_str:
|
||||
exception_mapping_worked = True
|
||||
raise APIError(
|
||||
status_code=original_exception.status_code,
|
||||
message=f"PalmException - {original_exception.message}",
|
||||
llm_provider="palm",
|
||||
model=model,
|
||||
request=original_exception.request,
|
||||
)
|
||||
if hasattr(original_exception, "status_code"):
|
||||
if original_exception.status_code == 400:
|
||||
exception_mapping_worked = True
|
||||
|
@ -8524,6 +8551,19 @@ class CustomStreamWrapper:
|
|||
]
|
||||
elif self.custom_llm_provider == "text-completion-openai":
|
||||
response_obj = self.handle_openai_text_completion_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if response_obj["is_finished"]:
|
||||
model_response.choices[0].finish_reason = response_obj[
|
||||
"finish_reason"
|
||||
]
|
||||
elif self.custom_llm_provider == "cached_response":
|
||||
response_obj = {
|
||||
"text": chunk.choices[0].delta.content,
|
||||
"is_finished": True,
|
||||
"finish_reason": chunk.choices[0].finish_reason,
|
||||
}
|
||||
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if response_obj["is_finished"]:
|
||||
|
@ -8732,6 +8772,7 @@ class CustomStreamWrapper:
|
|||
or self.custom_llm_provider == "vertex_ai"
|
||||
or self.custom_llm_provider == "sagemaker"
|
||||
or self.custom_llm_provider == "gemini"
|
||||
or self.custom_llm_provider == "cached_response"
|
||||
or self.custom_llm_provider in litellm.openai_compatible_endpoints
|
||||
):
|
||||
async for chunk in self.completion_stream:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue