Merge branch 'main' into litellm_admin_ui_view_all_keys

Krish Dholakia 2024-02-06 14:34:57 -08:00 committed by GitHub
commit e36566a212
9 changed files with 96 additions and 30 deletions


@@ -531,6 +531,9 @@ class RedisSemanticCache(BaseCache):
            return None
        pass

+    async def _index_info(self):
+        return await self.index.ainfo()
+

class S3Cache(BaseCache):
    def __init__(
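For reference, the new `_index_info` helper simply awaits redisvl's async index-info call on the semantic-cache index. A minimal, illustrative way to call it (assuming `litellm.cache` has already been configured with the `redis-semantic` backend, so that `litellm.cache.cache` is a `RedisSemanticCache` instance) is sketched below:

    import asyncio
    import litellm

    async def show_semantic_cache_index_info():
        # Assumption: litellm.cache was set up with type "redis-semantic" beforehand,
        # so litellm.cache.cache exposes the new _index_info() coroutine.
        info = await litellm.cache.cache._index_info()
        print(info)  # raw index metadata as returned by redisvl's ainfo()

    asyncio.run(show_semantic_cache_index_info())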


@@ -255,6 +255,7 @@ class LangFuseLogger:
                if key in [
                    "user_api_key",
                    "user_api_key_user_id",
+                    "semantic-similarity",
                ]:
                    tags.append(f"{key}:{value}")
            if "cache_hit" in kwargs:


@@ -10,7 +10,6 @@
import os, openai, sys, json, inspect, uuid, datetime, threading
from typing import Any, Literal, Union
from functools import partial
-
import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx


@@ -78,7 +78,9 @@ litellm_settings:
    type: "redis-semantic"
    similarity_threshold: 0.8
    redis_semantic_cache_embedding_model: azure-embedding-model
-  # cache: True
+  upperbound_key_generate_params:
+    max_budget: 100
+    duration: "30d"

  # setting callback class
  # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
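The new `upperbound_key_generate_params` block puts a ceiling on what callers may request from the proxy's `/key/generate` endpoint (here: at most a $100 budget and a 30-day duration). As a rough illustration only — the base URL and master key below are placeholders, and whether an over-limit request is rejected or clamped depends on the proxy's enforcement logic:

    import httpx

    # Hypothetical call against a locally running proxy.
    resp = httpx.post(
        "http://localhost:4000/key/generate",
        headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
        json={"max_budget": 500, "duration": "90d"},  # exceeds the configured upper bounds
    )
    print(resp.status_code, resp.json())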


@@ -1759,7 +1759,33 @@ async def async_data_generator(response, user_api_key_dict):
            done_message = "[DONE]"
            yield f"data: {done_message}\n\n"
    except Exception as e:
-        yield f"data: {str(e)}\n\n"
+        traceback.print_exc()
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e
+        )
+        verbose_proxy_logger.debug(
+            f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+        )
+        router_model_names = (
+            [m["model_name"] for m in llm_model_list]
+            if llm_model_list is not None
+            else []
+        )
+        if user_debug:
+            traceback.print_exc()
+
+        if isinstance(e, HTTPException):
+            raise e
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}\n\n{error_traceback}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )


def select_data_generator(response, user_api_key_dict):
@@ -1767,7 +1793,7 @@ def select_data_generator(response, user_api_key_dict):
        # since boto3 - sagemaker does not support async calls, we should use a sync data_generator
        if hasattr(
            response, "custom_llm_provider"
-        ) and response.custom_llm_provider in ["sagemaker", "together_ai"]:
+        ) and response.custom_llm_provider in ["sagemaker"]:
            return data_generator(
                response=response,
            )
@@ -2256,7 +2282,6 @@ async def chat_completion(
            selected_data_generator = select_data_generator(
                response=response, user_api_key_dict=user_api_key_dict
            )
-
            return StreamingResponse(
                selected_data_generator,
                media_type="text/event-stream",
@@ -4103,19 +4128,29 @@ async def health_readiness():
    cache_type = None
    if litellm.cache is not None:
+        from litellm.caching import RedisSemanticCache
+
        cache_type = litellm.cache.type
-    if prisma_client is not None:  # if db passed in, check if it's connected
-        if prisma_client.db.is_connected() == True:
-            response_object = {"db": "connected"}
-            return {
-                "status": "healthy",
-                "db": "connected",
-                "cache": cache_type,
-                "litellm_version": version,
-                "success_callbacks": litellm.success_callback,
-            }
+
+        if isinstance(litellm.cache.cache, RedisSemanticCache):
+            # ping the cache
+            try:
+                index_info = await litellm.cache.cache._index_info()
+            except Exception as e:
+                index_info = "index does not exist - error: " + str(e)
+            cache_type = {"type": cache_type, "index_info": index_info}
+
+    if prisma_client is not None:  # if db passed in, check if it's connected
+        await prisma_client.health_check()  # test the db connection
+        response_object = {"db": "connected"}
+
+        return {
+            "status": "healthy",
+            "db": "connected",
+            "cache": cache_type,
+            "litellm_version": version,
+            "success_callbacks": litellm.success_callback,
+        }
    else:
        return {
            "status": "healthy",


@@ -472,8 +472,6 @@ class PrismaClient:
        reset_at: Optional[datetime] = None,
    ):
        try:
-            print_verbose("PrismaClient: get_data")
-
            response: Any = None
            if token is not None or (table_name is not None and table_name == "key"):
                # check if plain text or hash
@@ -896,6 +894,21 @@ class PrismaClient:
            )
            raise e

+    async def health_check(self):
+        """
+        Health check endpoint for the prisma client
+        """
+        sql_query = """
+            SELECT 1
+            FROM "LiteLLM_VerificationToken"
+            LIMIT 1
+            """
+
+        # Execute the raw query
+        # The asterisk before `user_id_list` unpacks the list into separate arguments
+        response = await self.db.query_raw(sql_query)
+        return response
+

class DBClient:
    """


@@ -41,7 +41,7 @@ def test_completion_custom_provider_model_name():
            messages=messages,
            logger_fn=logger_fn,
        )
-        # Add any assertions here to check the, response
+        # Add any assertions here to check the,response
        print(response)
        print(response["choices"][0]["finish_reason"])
    except litellm.Timeout as e:


@@ -169,6 +169,8 @@ def map_finish_reason(
        return "stop"
    elif finish_reason == "SAFETY":  # vertex ai
        return "content_filter"
+    elif finish_reason == "STOP":  # vertex ai
+        return "stop"
    return finish_reason
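With the new branch, Vertex AI's `STOP` finish reason is normalized like the other provider-specific values, while unknown strings still fall through unchanged. A tiny illustrative check (assuming `map_finish_reason` is importable from `litellm.utils`):

    from litellm.utils import map_finish_reason

    assert map_finish_reason("STOP") == "stop"              # vertex ai
    assert map_finish_reason("SAFETY") == "content_filter"  # vertex ai
    assert map_finish_reason("unknown") == "unknown"        # unmapped values pass through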
@@ -1305,7 +1307,7 @@ class Logging:
                    )
                    if callback == "langfuse":
                        global langFuseLogger
-                        verbose_logger.debug("reaches langfuse for logging!")
+                        verbose_logger.debug("reaches langfuse for success logging!")
                        kwargs = {}
                        for k, v in self.model_call_details.items():
                            if (
@@ -6716,7 +6718,13 @@ def exception_type(
                    message=f"VertexAIException - {error_str}",
                    model=model,
                    llm_provider="vertex_ai",
-                    response=original_exception.response,
+                    response=httpx.Response(
+                        status_code=429,
+                        request=httpx.Request(
+                            method="POST",
+                            url=" https://cloud.google.com/vertex-ai/",
+                        ),
+                    ),
                )
            elif (
                "429 Quota exceeded" in error_str
@@ -8351,13 +8359,20 @@ class CustomStreamWrapper:
                    completion_obj["content"] = chunk.text
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                try:
-                    # print(chunk)
-                    if hasattr(chunk, "text"):
-                        # vertexAI chunks return
-                        # MultiCandidateTextGenerationResponse(text=' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', _prediction_response=Prediction(predictions=[{'candidates': [{'content': ' ```python\n# This Python code says "Hi" 100 times.\n\n# Create', 'author': '1'}], 'citationMetadata': [{'citations': None}], 'safetyAttributes': [{'blocked': False, 'scores': None, 'categories': None}]}], deployed_model_id='', model_version_id=None, model_resource_name=None, explanations=None), is_blocked=False, safety_attributes={}, candidates=[ ```python
-                        # This Python code says "Hi" 100 times.
-                        # Create])
-                        completion_obj["content"] = chunk.text
+                    if hasattr(chunk, "candidates") == True:
+                        try:
+                            completion_obj["content"] = chunk.text
+                            if hasattr(chunk.candidates[0], "finish_reason"):
+                                model_response.choices[
+                                    0
+                                ].finish_reason = map_finish_reason(
+                                    chunk.candidates[0].finish_reason.name
+                                )
+                        except:
+                            if chunk.candidates[0].finish_reason.name == "SAFETY":
+                                raise Exception(
+                                    f"The response was blocked by VertexAI. {str(chunk)}"
+                                )
                    else:
                        completion_obj["content"] = str(chunk)
                except StopIteration as e:
@@ -8646,7 +8661,6 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "ollama_chat"
                or self.custom_llm_provider == "vertex_ai"
            ):
-                print_verbose(f"INSIDE ASYNC STREAMING!!!")
                print_verbose(
                    f"value of async completion stream: {self.completion_stream}"
                )


@@ -8,8 +8,7 @@ pyyaml>=6.0.1 # server dep
uvicorn==0.22.0 # server dep
gunicorn==21.2.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls
-redis==4.6.0 # caching
-redisvl==0.0.7 # semantic caching
+redis==5.0.0 # caching
numpy==1.24.3 # semantic caching
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions