forked from phoenix/litellm-mirror
Merge pull request #2124 from BerriAI/litellm_streaming_caching_logging
fix(utils.py): support streaming cached response logging
Commit: bfdacf6d6b
5 changed files with 140 additions and 38 deletions
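Taken together, the changes below make a cache hit on a streamed call behave like a live stream: the cached response is wrapped in `CustomStreamWrapper` with `custom_llm_provider="cached_response"`, so success callbacks and logging still fire. A rough usage sketch of the behavior being exercised (not part of the diff; it assumes valid provider credentials, uses litellm's default in-memory cache instead of the Redis cache the tests use, and `CountingLogger` is a hypothetical helper):

```python
import asyncio

import litellm
from litellm import Cache
from litellm.integrations.custom_logger import CustomLogger


class CountingLogger(CustomLogger):  # hypothetical helper, not part of the PR
    def __init__(self):
        super().__init__()
        self.successes = 0

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Count every success callback, including ones for cached, streamed calls.
        self.successes += 1


async def main():
    litellm.cache = Cache()  # local in-memory cache; the PR's test uses type="redis"
    counting_logger = CountingLogger()
    litellm.callbacks = [counting_logger]
    for attempt in range(2):  # the second call should be served from cache
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",  # example model; any configured provider works
            messages=[{"role": "user", "content": "hi"}],
            caching=True,
            stream=True,
        )
        async for chunk in response:
            print(f"attempt {attempt}: {chunk}")
    await asyncio.sleep(1)  # success callbacks run asynchronously
    print(f"success events logged: {counting_logger.successes}")


if __name__ == "__main__":
    asyncio.run(main())
```

Before this change, the second (cached, streamed) call could complete without the success callback firing, which is what the new test further down guards against.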
@@ -124,7 +124,9 @@ class RedisCache(BaseCache):
             self.redis_client.set(name=key, value=str(value), ex=ttl)
         except Exception as e:
             # NON blocking - notify users Redis is throwing an exception
-            print_verbose("LiteLLM Caching: set() - Got exception from REDIS : ", e)
+            print_verbose(
+                f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}"
+            )

     async def async_set_cache(self, key, value, **kwargs):
         _redis_client = self.init_async_client()
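The change above also fixes the error message itself: `print_verbose` in litellm takes a single message string, so passing the exception as a second positional argument would not format it as intended; the exception is now folded into one f-string. A minimal sketch of the pattern (this `print_verbose` is a stand-in, not the library's implementation):

```python
# Stand-in for litellm's verbose logger: a single message argument.
def print_verbose(print_statement: str) -> None:
    print(print_statement)


try:
    raise ConnectionError("Redis unavailable")
except Exception as e:
    # Old style: print_verbose("message : ", e) passes two positional arguments.
    # New style: format the exception into the message before logging it.
    print_verbose(f"LiteLLM Caching: set() - Got exception from REDIS : {str(e)}")
```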
@@ -1986,6 +1986,8 @@ def test_completion_gemini():
         response = completion(model=model_name, messages=messages)
         # Add any assertions here to check the response
         print(response)
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -2015,6 +2017,8 @@ def test_completion_palm():
         response = completion(model=model_name, messages=messages)
         # Add any assertions here to check the response
         print(response)
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -2037,6 +2041,8 @@ def test_completion_palm_stream():
         # Add any assertions here to check the response
         for chunk in response:
             print(chunk)
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -2,7 +2,7 @@
 ## This test asserts the type of data passed into each method of the custom callback handler
 import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
-import pytest
+import pytest, uuid
 from pydantic import BaseModel

 sys.path.insert(0, os.path.abspath("../.."))

@@ -795,6 +795,53 @@ async def test_async_completion_azure_caching():
     assert len(customHandler_caching.states) == 4  # pre, post, success, success


+@pytest.mark.asyncio
+async def test_async_completion_azure_caching_streaming():
+    import copy
+
+    litellm.set_verbose = True
+    customHandler_caching = CompletionCustomHandler()
+    litellm.cache = Cache(
+        type="redis",
+        host=os.environ["REDIS_HOST"],
+        port=os.environ["REDIS_PORT"],
+        password=os.environ["REDIS_PASSWORD"],
+    )
+    litellm.callbacks = [customHandler_caching]
+    unique_time = uuid.uuid4()
+    response1 = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
+        caching=True,
+        stream=True,
+    )
+    async for chunk in response1:
+        print(f"chunk in response1: {chunk}")
+    await asyncio.sleep(1)
+    initial_customhandler_caching_states = len(customHandler_caching.states)
+    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
+    response2 = await litellm.acompletion(
+        model="azure/chatgpt-v-2",
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
+        caching=True,
+        stream=True,
+    )
+    async for chunk in response2:
+        print(f"chunk in response2: {chunk}")
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+    print(
+        f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
+    )
+    assert len(customHandler_caching.errors) == 0
+    assert (
+        len(customHandler_caching.states) > initial_customhandler_caching_states
+    )  # pre, post, streaming .., success, success
+
+
 @pytest.mark.asyncio
 async def test_async_embedding_azure_caching():
     print("Testing custom callback input - Azure Caching")
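The new `test_async_completion_azure_caching_streaming` sends the same streamed request twice so the second call is served from Redis, then asserts that the handler recorded more callback states after the cache hit, i.e. callbacks still fire for a cached stream. `CompletionCustomHandler` is defined earlier in this test module; a minimal stand-in that records one entry per callback event could look like this (a sketch, not the actual handler):

```python
from litellm.integrations.custom_logger import CustomLogger


class StateRecordingHandler(CustomLogger):
    """Hypothetical analogue of CompletionCustomHandler: one entry per event."""

    def __init__(self):
        super().__init__()
        self.states = []
        self.errors = []

    def log_pre_api_call(self, model, messages, kwargs):
        self.states.append("pre_api_call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        self.states.append("post_api_call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("sync_success")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_success")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.errors.append("async_failure")
```

Comparing `len(states)` before and after the cached call is what catches a cache hit that silently skips the callback pipeline.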
@@ -392,6 +392,8 @@ def test_completion_palm_stream():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -425,6 +427,8 @@ def test_completion_gemini_stream():
         if complete_response.strip() == "":
             raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")


@@ -461,6 +465,8 @@ async def test_acompletion_gemini_stream():
         print(f"completion_response: {complete_response}")
         if complete_response.strip() == "":
             raise Exception("Empty response received")
+    except litellm.APIError as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

litellm/utils.py (113 changes)
@@ -1411,7 +1411,7 @@ class Logging:
                         print_verbose(
                             f"success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
                         )
-                        return
+                        pass
                     else:
                         print_verbose(
                             "success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
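Changing `return` to `pass` here (and in the async handler in the next hunk) matters because the cache step runs inside the loop over success callbacks: returning early when a streamed call has no `complete_streaming_response` yet would skip every callback registered after the cache handler. A minimal sketch of that control-flow difference (plain Python, assuming the loop structure the surrounding code suggests):

```python
def run_callbacks(callbacks, has_complete_stream: bool):
    # Simplified stand-in for the success-handler loop.
    fired = []
    for cb in callbacks:
        if cb == "cache" and not has_complete_stream:
            pass  # with "return" here, "custom_logger" below would never fire
        else:
            fired.append(cb)
    return fired


print(run_callbacks(["cache", "custom_logger"], has_complete_stream=False))
# -> ['custom_logger']
```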
@@ -1616,7 +1616,7 @@ class Logging:
                         print_verbose(
                             f"async success_callback: reaches cache for logging, there is no complete_streaming_response. Kwargs={kwargs}\n\n"
                         )
-                        return
+                        pass
                     else:
                         print_verbose(
                             "async success_callback: reaches cache for logging, there is a complete_streaming_response. Adding to cache"
@@ -1625,8 +1625,10 @@ class Logging:
                         # only add to cache once we have a complete streaming response
                         litellm.cache.add_cache(result, **kwargs)
                 if isinstance(callback, CustomLogger):  # custom logger class
-                    print_verbose(f"Async success callbacks: {callback}")
-                    if self.stream:
+                    print_verbose(
+                        f"Async success callbacks: {callback}; self.stream: {self.stream}; complete_streaming_response: {self.model_call_details.get('complete_streaming_response', None)}"
+                    )
+                    if self.stream == True:
                         if "complete_streaming_response" in self.model_call_details:
                             await callback.async_log_success_event(
                                 kwargs=self.model_call_details,
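The extra logging above also makes the contract explicit: for a streamed call, a custom logger's async success hook only runs once `complete_streaming_response` is present in `model_call_details`. A sketch of a logger that relies on that (hypothetical class; it follows the `CustomLogger` callback signature used in the hunk above):

```python
from litellm.integrations.custom_logger import CustomLogger


class StreamAwareLogger(CustomLogger):
    """Hypothetical example: read the assembled stream from the callback kwargs."""

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        complete = kwargs.get("complete_streaming_response")
        if complete is not None:
            # Only present for streamed calls, and only after the last chunk.
            print(f"streamed response finished: {complete}")
        else:
            print(f"non-streaming response: {response_obj}")
```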
@@ -2328,6 +2330,13 @@ def client(original_function):
                             model_response_object=ModelResponse(),
                             stream=kwargs.get("stream", False),
                         )
+                        if kwargs.get("stream", False) == True:
+                            cached_result = CustomStreamWrapper(
+                                completion_stream=cached_result,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
                     elif call_type == CallTypes.embedding.value and isinstance(
                         cached_result, dict
                     ):
@@ -2624,28 +2633,6 @@ def client(original_function):
                     cached_result, list
                 ):
                     print_verbose(f"Cache Hit!")
-                    call_type = original_function.__name__
-                    if call_type == CallTypes.acompletion.value and isinstance(
-                        cached_result, dict
-                    ):
-                        if kwargs.get("stream", False) == True:
-                            cached_result = convert_to_streaming_response_async(
-                                response_object=cached_result,
-                            )
-                        else:
-                            cached_result = convert_to_model_response_object(
-                                response_object=cached_result,
-                                model_response_object=ModelResponse(),
-                            )
-                    elif call_type == CallTypes.aembedding.value and isinstance(
-                        cached_result, dict
-                    ):
-                        cached_result = convert_to_model_response_object(
-                            response_object=cached_result,
-                            model_response_object=EmbeddingResponse(),
-                            response_type="embedding",
-                        )
-                    # LOG SUCCESS
                     cache_hit = True
                     end_time = datetime.datetime.now()
                     (
@@ -2685,15 +2672,44 @@ def client(original_function):
                         additional_args=None,
                         stream=kwargs.get("stream", False),
                     )
-                    asyncio.create_task(
-                        logging_obj.async_success_handler(
-                            cached_result, start_time, end_time, cache_hit
+                    call_type = original_function.__name__
+                    if call_type == CallTypes.acompletion.value and isinstance(
+                        cached_result, dict
+                    ):
+                        if kwargs.get("stream", False) == True:
+                            cached_result = convert_to_streaming_response_async(
+                                response_object=cached_result,
+                            )
+                            cached_result = CustomStreamWrapper(
+                                completion_stream=cached_result,
+                                model=model,
+                                custom_llm_provider="cached_response",
+                                logging_obj=logging_obj,
+                            )
+                        else:
+                            cached_result = convert_to_model_response_object(
+                                response_object=cached_result,
+                                model_response_object=ModelResponse(),
+                            )
+                    elif call_type == CallTypes.aembedding.value and isinstance(
+                        cached_result, dict
+                    ):
+                        cached_result = convert_to_model_response_object(
+                            response_object=cached_result,
+                            model_response_object=EmbeddingResponse(),
+                            response_type="embedding",
                         )
-                    )
-                    threading.Thread(
-                        target=logging_obj.success_handler,
-                        args=(cached_result, start_time, end_time, cache_hit),
-                    ).start()
+                    if kwargs.get("stream", False) == False:
+                        # LOG SUCCESS
+                        asyncio.create_task(
+                            logging_obj.async_success_handler(
+                                cached_result, start_time, end_time, cache_hit
+                            )
+                        )
+                        threading.Thread(
+                            target=logging_obj.success_handler,
+                            args=(cached_result, start_time, end_time, cache_hit),
+                        ).start()
                     return cached_result
                 elif (
                     call_type == CallTypes.aembedding.value
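On an async cache hit with `stream=True`, the stored response is first turned back into an async chunk stream (`convert_to_streaming_response_async`) and then wrapped in `CustomStreamWrapper` with `custom_llm_provider="cached_response"`, so the caller iterates it exactly like a live stream and the normal logging path runs. The general replay idea, stripped of litellm internals (a generic sketch, not the library's code):

```python
import asyncio
from typing import AsyncIterator


async def replay_as_stream(cached_text: str, chunk_size: int = 16) -> AsyncIterator[str]:
    # Re-emit a fully assembled response in small pieces, mimicking a provider stream.
    for i in range(0, len(cached_text), chunk_size):
        yield cached_text[i : i + chunk_size]


async def main():
    async for chunk in replay_as_stream("Hello from the cache! " * 3):
        print(repr(chunk))


asyncio.run(main())
```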
@@ -4296,7 +4312,9 @@ def get_optional_params(
                     parameters=tool["function"].get("parameters", {}),
                 )
                 gtool_func_declarations.append(gtool_func_declaration)
-            optional_params["tools"] = [generative_models.Tool(function_declarations=gtool_func_declarations)]
+            optional_params["tools"] = [
+                generative_models.Tool(function_declarations=gtool_func_declarations)
+            ]
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
@@ -6795,7 +6813,7 @@ def exception_type(
                         llm_provider="vertex_ai",
                         request=original_exception.request,
                     )
-        elif custom_llm_provider == "palm":
+        elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
             if "503 Getting metadata" in error_str:
                 # auth errors look like this
                 # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
@@ -6814,6 +6832,15 @@ def exception_type(
                     llm_provider="palm",
                     response=original_exception.response,
                 )
+            if "500 An internal error has occurred." in error_str:
+                exception_mapping_worked = True
+                raise APIError(
+                    status_code=original_exception.status_code,
+                    message=f"PalmException - {original_exception.message}",
+                    llm_provider="palm",
+                    model=model,
+                    request=original_exception.request,
+                )
             if hasattr(original_exception, "status_code"):
                 if original_exception.status_code == 400:
                     exception_mapping_worked = True
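With this mapping, a transient Palm/Gemini "500 An internal error has occurred." now surfaces as `litellm.APIError` instead of an unmapped provider exception, which is exactly what the updated tests catch. A usage sketch (assumes Gemini credentials are configured; the model name is just an example):

```python
import litellm


def complete_with_tolerance(model: str, prompt: str):
    try:
        return litellm.completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
        )
    except litellm.APIError as e:
        # e.g. "PalmException - 500 An internal error has occurred."
        print(f"transient provider error, consider retrying: {e}")
        return None


complete_with_tolerance("gemini/gemini-pro", "Hey, how's it going?")
```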
@@ -8524,6 +8551,19 @@ class CustomStreamWrapper:
                     ]
             elif self.custom_llm_provider == "text-completion-openai":
                 response_obj = self.handle_openai_text_completion_chunk(chunk)
+                completion_obj["content"] = response_obj["text"]
+                print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj["is_finished"]:
+                    model_response.choices[0].finish_reason = response_obj[
+                        "finish_reason"
+                    ]
+            elif self.custom_llm_provider == "cached_response":
+                response_obj = {
+                    "text": chunk.choices[0].delta.content,
+                    "is_finished": True,
+                    "finish_reason": chunk.choices[0].finish_reason,
+                }
+
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if response_obj["is_finished"]:
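Inside the chunk handler, a chunk replayed from cache is already a complete delta, so it is normalized to the same `{"text", "is_finished", "finish_reason"}` shape used for the other providers, with `is_finished` set to `True`. A sketch of that normalization, using plain dataclasses as stand-ins for litellm's chunk objects:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Delta:
    content: Optional[str]


@dataclass
class Choice:
    delta: Delta
    finish_reason: Optional[str]


@dataclass
class Chunk:
    choices: List[Choice]


def normalize_cached_chunk(chunk: Chunk) -> dict:
    # A cached replay arrives as a single, complete chunk, so it is treated as finished.
    return {
        "text": chunk.choices[0].delta.content,
        "is_finished": True,
        "finish_reason": chunk.choices[0].finish_reason,
    }


print(normalize_cached_chunk(Chunk(choices=[Choice(delta=Delta(content="hi"), finish_reason="stop")])))
```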
@@ -8732,6 +8772,7 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "sagemaker"
                 or self.custom_llm_provider == "gemini"
+                or self.custom_llm_provider == "cached_response"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 async for chunk in self.completion_stream: