Merge branch 'main' into litellm_admin_ui_view_all_keys

commit e36566a212
Krish Dholakia, 2024-02-06 14:34:57 -08:00, committed by GitHub
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 96 additions and 30 deletions


@@ -531,6 +531,9 @@ class RedisSemanticCache(BaseCache):
                 return None
         pass

+    async def _index_info(self):
+        return await self.index.ainfo()
+

 class S3Cache(BaseCache):
     def __init__(
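
The new _index_info helper simply surfaces the underlying redisvl index metadata. A minimal usage sketch, assuming a deployment with semantic caching enabled (the litellm.cache.cache attribute path mirrors the health-check change later in this commit; the wrapper function below is illustrative, not from the diff):

import litellm

async def show_semantic_index_info():
    # litellm.cache.cache is the RedisSemanticCache instance when semantic caching is enabled
    try:
        info = await litellm.cache.cache._index_info()
    except Exception as e:
        info = f"index does not exist - error: {e}"
    print(info)

# run with: asyncio.run(show_semantic_index_info())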


@@ -255,6 +255,7 @@ class LangFuseLogger:
                 if key in [
                     "user_api_key",
                     "user_api_key_user_id",
+                    "semantic-similarity",
                 ]:
                     tags.append(f"{key}:{value}")
             if "cache_hit" in kwargs:


@@ -10,7 +10,6 @@
 import os, openai, sys, json, inspect, uuid, datetime, threading
 from typing import Any, Literal, Union
 from functools import partial
 import dotenv, traceback, random, asyncio, time, contextvars
 from copy import deepcopy
 import httpx


@@ -78,7 +78,9 @@ litellm_settings:
     type: "redis-semantic"
     similarity_threshold: 0.8
     redis_semantic_cache_embedding_model: azure-embedding-model
-  # cache: True
+  upperbound_key_generate_params:
+    max_budget: 100
+    duration: "30d"
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
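
The new upperbound_key_generate_params block puts a ceiling on what /key/generate will accept. A hedged usage sketch; the endpoint path, port, and master-key header follow common LiteLLM proxy conventions and are assumptions here, not part of this diff:

import requests

resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},  # proxy master key (placeholder)
    json={"max_budget": 500, "duration": "90d"},  # exceeds the bounds configured above
)
# Expect the proxy to reject (or clamp) the request per upperbound_key_generate_params;
# the exact behavior depends on the proxy version.
print(resp.status_code, resp.text)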


@@ -1759,7 +1759,33 @@ async def async_data_generator(response, user_api_key_dict):
         done_message = "[DONE]"
         yield f"data: {done_message}\n\n"
     except Exception as e:
-        yield f"data: {str(e)}\n\n"
+        traceback.print_exc()
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if user_debug:
+            traceback.print_exc()
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )


 def select_data_generator(response, user_api_key_dict):
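
The streaming error path above now logs, runs the failure hook, and re-raises as a structured ProxyException instead of yielding the raw error string. A minimal standalone sketch of that normalization pattern (this ProxyException class is a simplified stand-in for illustration, not litellm's own definition):

import traceback
from fastapi import HTTPException

class ProxyException(Exception):  # simplified stand-in
    def __init__(self, message, type, param, code):
        super().__init__(message)
        self.message, self.type, self.param, self.code = message, type, param, code

def normalize_error(e: Exception) -> Exception:
    # Pass FastAPI HTTP errors through; wrap everything else with a traceback.
    if isinstance(e, HTTPException):
        return e
    error_msg = f"{e}\n\n{traceback.format_exc()}"
    return ProxyException(
        message=getattr(e, "message", error_msg),
        type=getattr(e, "type", "None"),
        param=getattr(e, "param", "None"),
        code=getattr(e, "status_code", 500),
    )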
@@ -1767,7 +1793,7 @@ def select_data_generator(response, user_api_key_dict):
         # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
         if hasattr(
             response, "custom_llm_provider"
-        ) and response.custom_llm_provider in ["sagemaker", "together_ai"]:
+        ) and response.custom_llm_provider in ["sagemaker"]:
             return data_generator(
                 response=response,
             )
@@ -2256,7 +2282,6 @@ async def chat_completion(
             selected_data_generator = select_data_generator(
                 response=response, user_api_key_dict=user_api_key_dict
             )

             return StreamingResponse(
                 selected_data_generator,
                 media_type="text/event-stream",
@@ -4103,10 +4128,20 @@ async def health_readiness():
     cache_type = None
     if litellm.cache is not None:
+        from litellm.caching import RedisSemanticCache
+
         cache_type = litellm.cache.type
+        if isinstance(litellm.cache.cache, RedisSemanticCache):
+            # ping the cache
+            try:
+                index_info = await litellm.cache.cache._index_info()
+            except Exception as e:
+                index_info = "index does not exist - error: " + str(e)
+            cache_type = {"type": cache_type, "index_info": index_info}
+
     if prisma_client is not None:  # if db passed in, check if it's connected
-        if prisma_client.db.is_connected() == True:
+        await prisma_client.health_check()  # test the db connection
         response_object = {"db": "connected"}
         return {
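
With these changes, /health/readiness both pings the semantic-cache index and exercises the database. A hedged sketch of querying it (the URL and port are typical local-proxy defaults and are assumed; the exact JSON shape beyond the cache field is not shown in this hunk):

import requests

resp = requests.get("http://localhost:4000/health/readiness")
print(resp.json())
# When redis-semantic caching is enabled, the cache field is now a dict like
# {"type": "redis-semantic", "index_info": {...}}, or an error string if the index is missing.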


@@ -472,8 +472,6 @@ class PrismaClient:
         reset_at: Optional[datetime] = None,
     ):
         try:
-            print_verbose("PrismaClient: get_data")
             response: Any = None
             if token is not None or (table_name is not None and table_name == "key"):
                 # check if plain text or hash
@@ -896,6 +894,21 @@ class PrismaClient:
             )
             raise e

+    async def health_check(self):
+        """
+        Health check endpoint for the prisma client
+        """
+        sql_query = """
+            SELECT 1
+            FROM "LiteLLM_VerificationToken"
+            LIMIT 1
+            """
+
+        # Execute the raw query
+        # The asterisk before `user_id_list` unpacks the list into separate arguments
+        response = await self.db.query_raw(sql_query)
+        return response
+

 class DBClient:
     """


@@ -169,6 +169,8 @@ def map_finish_reason(
         return "stop"
     elif finish_reason == "SAFETY":  # vertex ai
         return "content_filter"
+    elif finish_reason == "STOP":  # vertex ai
+        return "stop"
     return finish_reason
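
Vertex AI reports a normal completion as STOP, which this branch now folds into OpenAI's "stop" value. A trimmed, self-contained paraphrase covering only the branches visible in this hunk:

def map_finish_reason(finish_reason: str) -> str:
    if finish_reason == "SAFETY":  # vertex ai
        return "content_filter"
    elif finish_reason == "STOP":  # vertex ai
        return "stop"
    return finish_reason

assert map_finish_reason("STOP") == "stop"
assert map_finish_reason("length") == "length"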
@@ -1305,7 +1307,7 @@ class Logging:
                     )
                     if callback == "langfuse":
                         global langFuseLogger
-                        verbose_logger.debug("reaches langfuse for logging!")
+                        verbose_logger.debug("reaches langfuse for success logging!")
                         kwargs = {}
                         for k, v in self.model_call_details.items():
                             if (
@@ -6716,7 +6718,13 @@ def exception_type(
                     message=f"VertexAIException - {error_str}",
                     model=model,
                     llm_provider="vertex_ai",
-                    response=original_exception.response,
+                    response=httpx.Response(
+                        status_code=429,
+                        request=httpx.Request(
+                            method="POST",
+                            url=" https://cloud.google.com/vertex-ai/",
+                        ),
+                    ),
                 )
             elif (
                 "429 Quota exceeded" in error_str
@@ -8351,13 +8359,20 @@ class CustomStreamWrapper:
                 completion_obj["content"] = chunk.text
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
-                    # print(chunk)
-                    if hasattr(chunk, "text"):
-                        # vertexAI chunks return
-                        # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python
-                        # This Python code says "Hi" 100 times.
-                        # Create])
-                        completion_obj["content"] = chunk.text
+                    if hasattr(chunk, "candidates") == True:
+                        try:
+                            completion_obj["content"] = chunk.text
+                            if hasattr(chunk.candidates[0], "finish_reason"):
+                                model_response.choices[
+                                    0
+                                ].finish_reason = map_finish_reason(
+                                    chunk.candidates[0].finish_reason.name
+                                )
+                        except:
+                            if chunk.candidates[0].finish_reason.name == "SAFETY":
+                                raise Exception(
+                                    f"The response was blocked by VertexAI. {str(chunk)}"
+                                )
                     else:
                         completion_obj["content"] = str(chunk)
                 except StopIteration as e:
@@ -8646,7 +8661,6 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "ollama_chat"
                 or self.custom_llm_provider == "vertex_ai"
             ):
-                print_verbose(f"INSIDE ASYNC STREAMING!!!")
                 print_verbose(
                     f"value of async completion stream: {self.completion_stream}"
                 )


@@ -8,8 +8,7 @@ pyyaml>=6.0.1 # server dep
 uvicorn==0.22.0 # server dep
 gunicorn==21.2.0 # server dep
 boto3==1.28.58 # aws bedrock/sagemaker calls
-redis==4.6.0 # caching
-redisvl==0.0.7 # semantic caching
+redis==5.0.0 # caching
 numpy==1.24.3 # semantic caching
 prisma==0.11.0 # for db
 mangum==0.17.0 # for aws lambda functions