Merge branch 'main' into litellm_region_based_routing

Krish Dholakia 2024-05-08 22:19:51 -07:00 committed by GitHub
commit 8ad979cdfe
85 changed files with 793 additions and 448 deletions

litellm/proxy/proxy_server.py

@@ -30,7 +30,7 @@ sys.path.insert(
 try:
     import fastapi
     import backoff
-    import yaml
+    import yaml  # type: ignore
     import orjson
     import logging
     from apscheduler.schedulers.asyncio import AsyncIOScheduler
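The new "# type: ignore" silences mypy's missing-stubs error for PyYAML, which ships without type information. A hedged alternative, assuming mypy is the type checker in use (the commit itself does not do this), is to install the stub package once and keep the import clean:

    # What the commit does: suppress the missing-stub error inline.
    import yaml  # type: ignore

    # Alternative: install stubs (pip install types-PyYAML) and import normally.
    import yaml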
@@ -3731,6 +3731,7 @@ async def chat_completion(
                 "x-litellm-model-id": model_id,
                 "x-litellm-cache-key": cache_key,
                 "x-litellm-model-api-base": api_base,
+                "x-litellm-version": version,
             }
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3746,6 +3747,7 @@ async def chat_completion(
         fastapi_response.headers["x-litellm-model-id"] = model_id
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
         fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version

         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
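Together with the streaming hunk above, chat_completion now reports the proxy version and routing metadata on every response. A minimal sketch of reading the new headers from a non-streaming call; the proxy address and virtual key are assumptions, not part of the diff:

    import httpx

    resp = httpx.post(
        "http://localhost:4000/chat/completions",     # assumed proxy address
        headers={"Authorization": "Bearer sk-1234"},  # assumed virtual key
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
        },
    )
    print(resp.headers.get("x-litellm-version"))         # added by this commit
    print(resp.headers.get("x-litellm-model-api-base"))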
@@ -3902,14 +3904,10 @@ async def completion(
             },
         )
-        if hasattr(response, "_hidden_params"):
-            model_id = response._hidden_params.get("model_id", None) or ""
-            original_response = (
-                response._hidden_params.get("original_response", None) or ""
-            )
-        else:
-            model_id = ""
-            original_response = ""
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""

         verbose_proxy_logger.debug("final response: %s", response)
         if (
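The rewrite collapses the old hasattr branch into one defensive expression. The trailing "or {}" matters because _hidden_params can exist on the response but be None; getattr alone would then return None and the later .get() calls would crash. A self-contained sketch of the behavior:

    class _Resp:
        _hidden_params = None  # attribute present, but set to None

    hidden_params = getattr(_Resp(), "_hidden_params", {}) or {}
    assert hidden_params == {}  # no AttributeError, no None.get(...) crash
    model_id = hidden_params.get("model_id", None) or ""  # "" fallback, as in the diff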
@@ -3917,6 +3915,9 @@ async def completion(
         ):  # use generate_responses to stream responses
             custom_headers = {
                 "x-litellm-model-id": model_id,
+                "x-litellm-cache-key": cache_key,
+                "x-litellm-model-api-base": api_base,
+                "x-litellm-version": version,
             }
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3931,6 +3932,10 @@ async def completion(
         )
         fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
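Since custom_headers is attached to the streaming response object before any chunk is generated, the x-litellm-* values travel as ordinary HTTP headers and can be read before the body is consumed. A sketch with the same assumed proxy address and key as above:

    import httpx

    with httpx.stream(
        "POST",
        "http://localhost:4000/completions",          # assumed proxy address
        headers={"Authorization": "Bearer sk-1234"},  # assumed virtual key
        json={"model": "gpt-3.5-turbo-instruct", "prompt": "hi", "stream": True},
    ) as resp:
        print(resp.headers.get("x-litellm-model-api-base"))  # readable up front
        for _chunk in resp.iter_text():
            pass  # consume the stream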
@@ -3970,6 +3975,7 @@ async def completion(
 )  # azure compatible endpoint
 async def embeddings(
     request: Request,
+    fastapi_response: Response,
     model: Optional[str] = None,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
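Each endpoint below gains a fastapi_response: Response parameter for the same reason: FastAPI injects a mutable Response object, and any headers set on it are merged into the response actually sent, even when the handler returns a plain value. A minimal, self-contained sketch of the mechanism (route and header value are illustrative, not from the diff):

    from fastapi import FastAPI, Response

    app = FastAPI()

    @app.get("/ping")
    async def ping(fastapi_response: Response):
        # Headers set on the injected object survive into the final response,
        # even though the handler returns a plain dict.
        fastapi_response.headers["x-litellm-version"] = "1.0.0"  # illustrative
        return {"ok": True}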
@@ -4116,6 +4122,17 @@ async def embeddings(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
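The same header block recurs verbatim in image_generation, audio_transcriptions, and moderations below, in addition to the two completion endpoints. A hypothetical refactor, not part of this commit, would hoist it into one helper:

    from fastapi import Response

    def set_litellm_response_headers(
        fastapi_response: Response, response, version: str
    ) -> None:
        # Copy routing metadata from a litellm response onto the outgoing headers.
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        fastapi_response.headers["x-litellm-model-id"] = hidden_params.get("model_id") or ""
        fastapi_response.headers["x-litellm-cache-key"] = hidden_params.get("cache_key") or ""
        fastapi_response.headers["x-litellm-model-api-base"] = hidden_params.get("api_base") or ""
        fastapi_response.headers["x-litellm-version"] = version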
@@ -4154,6 +4171,7 @@ async def embeddings(
 )
 async def image_generation(
     request: Request,
+    fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     global proxy_logging_obj
@@ -4273,6 +4291,17 @@ async def image_generation(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -4309,6 +4338,7 @@ async def image_generation(
 )
 async def audio_transcriptions(
     request: Request,
+    fastapi_response: Response,
     file: UploadFile = File(...),
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
@@ -4453,6 +4483,18 @@ async def audio_transcriptions(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -4492,6 +4534,7 @@ async def audio_transcriptions(
 )
 async def moderations(
     request: Request,
+    fastapi_response: Response,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
@@ -4616,6 +4659,17 @@ async def moderations(
         ### ALERTING ###
         data["litellm_status"] = "success"  # used for alerting

+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers["x-litellm-model-id"] = model_id
+        fastapi_response.headers["x-litellm-cache-key"] = cache_key
+        fastapi_response.headers["x-litellm-model-api-base"] = api_base
+        fastapi_response.headers["x-litellm-version"] = version
+
         return response
     except Exception as e:
         data["litellm_status"] = "fail"  # used for alerting
@@ -5821,35 +5875,38 @@ async def global_spend_end_users(data: Optional[GlobalEndUsersSpend] = None):
     if prisma_client is None:
         raise HTTPException(status_code=500, detail={"error": "No db connected"})

-    if data is None:
-        sql_query = f"""SELECT * FROM "Last30dTopEndUsersSpend";"""
+    """
+    Gets the top 100 end-users for a given api key
+    """
+    startTime = None
+    endTime = None
+    selected_api_key = None
+    if data is not None:
+        startTime = data.startTime
+        endTime = data.endTime
+        selected_api_key = data.api_key

-        response = await prisma_client.db.query_raw(query=sql_query)
-    else:
-        """
-        Gets the top 100 end-users for a given api key
-        """
-        current_date = datetime.now()
-        past_date = current_date - timedelta(days=30)
-        response = await prisma_client.db.litellm_spendlogs.group_by(  # type: ignore
-            by=["end_user"],
-            where={
-                "AND": [{"startTime": {"gte": past_date}}, {"api_key": data.api_key}]  # type: ignore
-            },
-            sum={"spend": True},
-            order={"_sum": {"spend": "desc"}},  # type: ignore
-            take=100,
-            count=True,
-        )
-        if response is not None and isinstance(response, list):
-            new_response = []
-            for r in response:
-                new_r = r
-                new_r["total_spend"] = r["_sum"]["spend"]
-                new_r["total_count"] = r["_count"]["_all"]
-                new_r.pop("_sum")
-                new_r.pop("_count")
-                new_response.append(new_r)
+    startTime = startTime or datetime.now() - timedelta(days=30)
+    endTime = endTime or datetime.now()

+    sql_query = """
+    SELECT end_user, COUNT(*) AS total_count, SUM(spend) AS total_spend
+    FROM "LiteLLM_SpendLogs"
+    WHERE "startTime" >= $1::timestamp
+      AND "startTime" < $2::timestamp
+      AND (
+        CASE
+          WHEN $3::TEXT IS NULL THEN TRUE
+          ELSE api_key = $3
+        END
+      )
+    GROUP BY end_user
+    ORDER BY total_spend DESC
+    LIMIT 100
+    """
+    response = await prisma_client.db.query_raw(
+        sql_query, startTime, endTime, selected_api_key
+    )

     return response
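The CASE WHEN $3::TEXT IS NULL THEN TRUE pattern makes the api_key filter optional without string-building the SQL: passing NULL as the third parameter matches every row, while a concrete key restricts the aggregation to that key. A sketch of both calls through the same query_raw signature used above (the key value is a placeholder):

    # No filter: top 100 end-users across all API keys in the window.
    response = await prisma_client.db.query_raw(
        sql_query, startTime, endTime, None
    )

    # Filtered: top 100 end-users for one key (placeholder value).
    response = await prisma_client.db.query_raw(
        sql_query, startTime, endTime, "sk-example-key"
    )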