diff --git a/litellm/caching.py b/litellm/caching.py
index 6bf53ea451..f996a58735 100644
--- a/litellm/caching.py
+++ b/litellm/caching.py
@@ -531,6 +531,9 @@ class RedisSemanticCache(BaseCache):
             return None
         pass
 
+    async def _index_info(self):
+        return await self.index.ainfo()
+
 
 class S3Cache(BaseCache):
     def __init__(
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index 3c3e793dfb..3031868ec7 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -255,6 +255,7 @@ class LangFuseLogger:
                 if key in [
                     "user_api_key",
                     "user_api_key_user_id",
+                    "semantic-similarity",
                 ]:
                     tags.append(f"{key}:{value}")
             if "cache_hit" in kwargs:
diff --git a/litellm/main.py b/litellm/main.py
index 384dadc32d..b18221607f 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -10,7 +10,6 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any, Literal, Union
 from functools import partial
-
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 326544f41e..a8144e9d48 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -78,7 +78,9 @@ litellm_settings:
     type: "redis-semantic"
     similarity_threshold: 0.8
     redis_semantic_cache_embedding_model: azure-embedding-model
-  # cache: True
+  upperbound_key_generate_params:
+    max_budget: 100
+    duration: "30d"
 
 # setting callback class
 # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 6f442f1ae3..3b8b5a3b32 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1759,7 +1759,33 @@ async def async_data_generator(response, user_api_key_dict):
         done_message = "[DONE]"
         yield f"data: {done_message}\n\n"
     except Exception as e:
-        yield f"data: {str(e)}\n\n"
+        traceback.print_exc()
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if user_debug:
+            traceback.print_exc()
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+
+        raise ProxyException(
+            message=getattr(e, "message", error_msg),
+            type=getattr(e, "type", "None"),
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", 500),
+        )
 
 
 def select_data_generator(response, user_api_key_dict):
@@ -1767,7 +1793,7 @@
         # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
         if hasattr(
             response, "custom_llm_provider"
-        ) and response.custom_llm_provider in ["sagemaker", "together_ai"]:
+        ) and response.custom_llm_provider in ["sagemaker"]:
             return data_generator(
                 response=response,
             )
@@ -2256,7 +2282,6 @@ async def chat_completion(
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict
             )
-
             return StreamingResponse(
                 selected_data_generator,
                 media_type="text/event-stream",
@@ -4103,19 +4128,29 @@ async def health_readiness():
 
     cache_type = None
     if litellm.cache is not None:
+        from litellm.caching import RedisSemanticCache
+
         cache_type = litellm.cache.type
-    if prisma_client is not None:  # if db passed in, check if it's connected
-        if prisma_client.db.is_connected() == True:
-            response_object = {"db": "connected"}
+        if isinstance(litellm.cache.cache, RedisSemanticCache):
+            # ping the cache
+            try:
+                index_info = await litellm.cache.cache._index_info()
+            except Exception as e:
+                index_info = "index does not exist - error: " + str(e)
+            cache_type = {"type": cache_type, "index_info": index_info}
 
-            return {
-                "status": "healthy",
-                "db": "connected",
-                "cache": cache_type,
-                "litellm_version": version,
-                "success_callbacks": litellm.success_callback,
-            }
+    if prisma_client is not None:  # if db passed in, check if it's connected
+        await prisma_client.health_check()  # test the db connection
+        response_object = {"db": "connected"}
+
+        return {
+            "status": "healthy",
+            "db": "connected",
+            "cache": cache_type,
+            "litellm_version": version,
+            "success_callbacks": litellm.success_callback,
+        }
     else:
         return {
             "status": "healthy",
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 62cbc6b4be..20b619b730 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -472,8 +472,6 @@ class PrismaClient:
         reset_at: Optional[datetime] = None,
     ):
         try:
-            print_verbose("PrismaClient: get_data")
-
             response: Any = None
             if token is not None or (table_name is not None and table_name == "key"):
                 # check if plain text or hash
@@ -896,6 +894,21 @@ class PrismaClient:
             )
             raise e
 
+    async def health_check(self):
+        """
+        Health check for the Prisma client - verifies the database connection is alive
+        """
+        sql_query = """
+            SELECT 1
+            FROM "LiteLLM_VerificationToken"
+            LIMIT 1
+            """
+
+        # Execute a trivial raw query; if it succeeds,
+        # the database connection is healthy
+        response = await self.db.query_raw(sql_query)
+        return response
+
 
 class DBClient:
     """
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index b075e48190..80a4372a57 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -41,7 +41,7 @@ def test_completion_custom_provider_model_name():
             messages=messages,
             logger_fn=logger_fn,
         )
-        # Add any assertions here to check the, response
+        # Add any assertions here to check the response
        print(response)
         print(response["choices"][0]["finish_reason"])
     except litellm.Timeout as e:
diff --git a/litellm/utils.py b/litellm/utils.py
index c25572c03c..b37c68d655 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -169,6 +169,8 @@ def map_finish_reason(
         return "stop"
     elif finish_reason == "SAFETY":  # vertex ai
         return "content_filter"
+    elif finish_reason == "STOP":  # vertex ai
+        return "stop"
     return finish_reason
 
 
@@ -1305,7 +1307,7 @@ class Logging:
                     )
                 if callback == "langfuse":
                     global langFuseLogger
-                    verbose_logger.debug("reaches langfuse for logging!")
+                    verbose_logger.debug("reaches langfuse for success logging!")
                     kwargs = {}
                     for k, v in self.model_call_details.items():
                         if (
@@ -6716,7 +6718,13 @@ def exception_type(
                         message=f"VertexAIException - {error_str}",
                         model=model,
                         llm_provider="vertex_ai",
-                        response=original_exception.response,
+                        response=httpx.Response(
+                            status_code=429,
+                            request=httpx.Request(
+                                method="POST",
+                                url=" https://cloud.google.com/vertex-ai/",
+                            ),
+                        ),
                     )
                 elif (
                     "429 Quota exceeded" in error_str
@@ -8351,13 +8359,20 @@ class CustomStreamWrapper:
                 completion_obj["content"] = chunk.text
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
-                    # print(chunk)
-                    if hasattr(chunk, "text"):
-                        # vertexAI chunks return
-                        # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python
-                        # This Python code says "Hi" 100 times.
-                        # Create])
-                        completion_obj["content"] = chunk.text
+                    if hasattr(chunk, "candidates") == True:
+                        try:
+                            completion_obj["content"] = chunk.text
+                            if hasattr(chunk.candidates[0], "finish_reason"):
+                                model_response.choices[
+                                    0
+                                ].finish_reason = map_finish_reason(
+                                    chunk.candidates[0].finish_reason.name
+                                )
+                        except:
+                            if chunk.candidates[0].finish_reason.name == "SAFETY":
+                                raise Exception(
+                                    f"The response was blocked by VertexAI. {str(chunk)}"
+                                )
                     else:
                         completion_obj["content"] = str(chunk)
                 except StopIteration as e:
@@ -8646,7 +8661,6 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "ollama_chat"
                 or self.custom_llm_provider == "vertex_ai"
             ):
-                print_verbose(f"INSIDE ASYNC STREAMING!!!")
                 print_verbose(
                     f"value of async completion stream: {self.completion_stream}"
                 )
diff --git a/requirements.txt b/requirements.txt
index 3ace5872ad..55c5f14568 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,8 +8,7 @@ pyyaml>=6.0.1 # server dep
 uvicorn==0.22.0 # server dep
 gunicorn==21.2.0 # server dep
 boto3==1.28.58 # aws bedrock/sagemaker calls
-redis==4.6.0 # caching
-redisvl==0.0.7 # semantic caching
+redis==5.0.0 # caching
 numpy==1.24.3 # semantic caching
 prisma==0.11.0 # for db
 mangum==0.17.0 # for aws lambda functions
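Below is a minimal sketch (not part of the patch itself) showing how the reworked readiness check could be exercised against a running proxy. The base URL, port, and the `check_readiness` helper name are assumptions; adjust them to your deployment. The expected response shape follows the handler above: `db` is "connected" only when the new `PrismaClient.health_check()` probe succeeds, and with a `redis-semantic` cache the `cache` field becomes a dict that also carries the Redis index info.

```python
# Sketch only (not part of the patch). Assumes a litellm proxy is running
# locally; the host/port below are placeholders for your deployment.
import asyncio

import httpx

PROXY_BASE_URL = "http://localhost:4000"  # assumption - adjust as needed


async def check_readiness() -> None:
    async with httpx.AsyncClient() as client:
        resp = await client.get(f"{PROXY_BASE_URL}/health/readiness")
        resp.raise_for_status()
        payload = resp.json()

    # "db" is "connected" only if PrismaClient.health_check() - the
    # SELECT 1 FROM "LiteLLM_VerificationToken" LIMIT 1 probe - succeeded.
    print("db:", payload.get("db"))

    # With a redis-semantic cache, "cache" is a dict carrying the Redis
    # index info (or an error string if the index does not exist yet);
    # otherwise it is just the configured cache type (or None).
    cache = payload.get("cache")
    if isinstance(cache, dict):
        print("cache type:", cache.get("type"))
        print("index info:", cache.get("index_info"))
    else:
        print("cache:", cache)


if __name__ == "__main__":
    asyncio.run(check_readiness())
```

Per the else branch in the handler, a proxy without a database configured should still report a healthy status; the cache field simply stays a plain type string (or None) when no redis-semantic cache is set up.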