LiteLLM Minor Fixes & Improvements (09/19/2024) (#5793)

* fix(model_prices_and_context_window.json): add cost tracking for more vertex llama3.1 models

8b and 70b models

* fix(proxy/utils.py): handle data being none on pre-call hooks
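
A rough sketch of the pattern this fix introduces (simplified here; the real `ProxyLogging.pre_call_hook` also takes a `UserAPIKeyAuth` argument and runs the registered callbacks):

    from typing import Optional, overload

    @overload
    async def pre_call_hook(data: None, call_type: str) -> None: ...
    @overload
    async def pre_call_hook(data: dict, call_type: str) -> dict: ...

    async def pre_call_hook(data: Optional[dict], call_type: str) -> Optional[dict]:
        # Pass-through endpoints can reach the proxy without a parsed body;
        # short-circuit instead of handing None to the callback chain.
        if data is None:
            return None
        # ... run the registered async_pre_call_hook callbacks here ...
        return data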

* fix(proxy/): create views on initial proxy startup

Fixes the base case, where a user starts the proxy for the first time.

 Fixes https://github.com/BerriAI/litellm/issues/5756
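
The startup path boils down to: only auto-create the usage views when `LiteLLM_SpendLogs` looks empty, i.e. on a fresh install. A simplified sketch of that check (the real logic lives in `should_create_missing_views` / `create_missing_views` below; the wrapper name here is illustrative):

    async def maybe_create_views(db) -> None:
        # pg_class.reltuples is Postgres' row-count estimate; 0 or -1 means the
        # table has no rows / was never analyzed -> treat it as first startup.
        result = await db.query_raw(
            """
            SELECT reltuples::BIGINT
            FROM pg_class
            WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
            """
        )
        if result and result[0].get("reltuples") in (0, -1):
            await create_missing_views(db=db)  # creates the views used by the UI 'Usage' tab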

* build(config.yml): fix vertex version for test

* feat(ui/): support enabling/disabling slack alerting

Allows an admin to turn Slack alerting on or off through the UI.
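
Under the hood the new `slack_alerting` toggle is just a boolean view over the proxy's `alerting` list; a sketch of the mapping applied when the UI saves the setting (the helper name is illustrative, not from the codebase):

    def apply_slack_toggle(general_settings: dict, slack_alerting: bool) -> dict:
        # The DB-stored value overrides whatever 'alerting' was set in the yaml config.
        general_settings["alerting"] = ["slack"] if slack_alerting else []
        return general_settings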

* feat(rerank/main.py): support langfuse logging
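
Rerank calls can now be traced like completions. A minimal usage sketch (assumes `COHERE_API_KEY` plus the Langfuse keys are set in the environment):

    import litellm

    litellm.success_callback = ["langfuse"]  # log rerank calls to Langfuse

    response = litellm.rerank(
        model="cohere/rerank-english-v3.0",
        query="What is the capital of France?",
        documents=["Paris is the capital of France.", "Berlin is in Germany."],
        top_n=1,
    )
    print(response.results)  # RerankResponse.results is what gets logged as output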

* fix(proxy/utils.py): fix linting errors

* fix(langfuse.py): log clean metadata

* test(tests): replace deprecated openai model
Krish Dholakia, 2024-09-20 08:19:52 -07:00 (committed by GitHub)
parent 696fc387d2 · commit 3933fba41f
22 changed files with 645 additions and 94 deletions

View file

@@ -558,6 +558,7 @@ jobs:
             pip install "anyio==3.7.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
+            pip install "google-cloud-aiplatform==1.59.0"
       - run:
           name: Build Docker image
          command: docker build -t my-app:latest -f Dockerfile.database .

View file

@@ -51,22 +51,25 @@ async def check_view_exists():
         print("LiteLLM_VerificationTokenView Created!")  # noqa

-    sql_query = """
-    CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
+    try:
+        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
+        print("MonthlyGlobalSpend Exists!")  # noqa
+    except Exception as e:
+        sql_query = """
+        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
         SELECT
-        DATE_TRUNC('day', "startTime") AS date,
+        DATE("startTime") AS date,
         SUM("spend") AS spend
         FROM
         "LiteLLM_SpendLogs"
         WHERE
-        "startTime" >= CURRENT_DATE - INTERVAL '30 days'
+        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
         GROUP BY
-        DATE_TRUNC('day', "startTime");
+        DATE("startTime");
         """
-        # Execute the queries
         await db.execute_raw(query=sql_query)
         print("MonthlyGlobalSpend Created!")  # noqa

     try:
         await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")

View file

@@ -10,6 +10,7 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
+from litellm.secret_managers.main import str_to_bool


 class LangFuseLogger:

@@ -66,6 +67,11 @@ class LangFuseLogger:
         project_id = None

         if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
+            upstream_langfuse_debug = (
+                str_to_bool(self.upstream_langfuse_debug)
+                if self.upstream_langfuse_debug is not None
+                else None
+            )
             self.upstream_langfuse_secret_key = os.getenv(
                 "UPSTREAM_LANGFUSE_SECRET_KEY"
             )

@@ -80,7 +86,11 @@ class LangFuseLogger:
                 secret_key=self.upstream_langfuse_secret_key,
                 host=self.upstream_langfuse_host,
                 release=self.upstream_langfuse_release,
-                debug=self.upstream_langfuse_debug,
+                debug=(
+                    upstream_langfuse_debug
+                    if upstream_langfuse_debug is not None
+                    else False
+                ),
             )
         else:
             self.upstream_langfuse = None

@@ -175,6 +185,7 @@ class LangFuseLogger:
                 pass
             # end of processing langfuse ########################
+
             if (
                 level == "ERROR"
                 and status_message is not None

@@ -208,6 +219,11 @@ class LangFuseLogger:
             ):
                 input = prompt
                 output = response_obj["text"]
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.RerankResponse
+            ):
+                input = prompt
+                output = response_obj.results
             elif (
                 kwargs.get("call_type") is not None
                 and kwargs.get("call_type") == "pass_through_endpoint"

@@ -283,14 +299,14 @@ class LangFuseLogger:
         input,
         response_obj,
     ):
-        from langfuse.model import CreateGeneration, CreateTrace
+        from langfuse.model import CreateGeneration, CreateTrace  # type: ignore

         verbose_logger.warning(
             "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
         )

-        trace = self.Langfuse.trace(
-            CreateTrace(
+        trace = self.Langfuse.trace(  # type: ignore
+            CreateTrace(  # type: ignore
                 name=metadata.get("generation_name", "litellm-completion"),
                 input=input,
                 output=output,

@@ -336,6 +352,7 @@ class LangFuseLogger:
     try:
         tags = []
         try:
+            optional_params.pop("metadata")
             metadata = copy.deepcopy(
                 metadata
             )  # Avoid modifying the original metadata

@@ -361,7 +378,7 @@ class LangFuseLogger:
             langfuse.version.__version__
         ) >= Version("2.7.3")

-        print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
+        print_verbose("Langfuse Layer Logging - logging to langfuse v2 ")

         if supports_tags:
             metadata_tags = metadata.pop("tags", [])

@@ -519,11 +536,11 @@ class LangFuseLogger:
             if key.lower() not in ["authorization", "cookie", "referer"]:
                 clean_headers[key] = value

-        clean_metadata["request"] = {
-            "method": method,
-            "url": url,
-            "headers": clean_headers,
-        }
+        # clean_metadata["request"] = {
+        #     "method": method,
+        #     "url": url,
+        #     "headers": clean_headers,
+        # }

         trace = self.Langfuse.trace(**trace_params)

         # Log provider specific information as a span

@@ -531,13 +548,19 @@ class LangFuseLogger:
         generation_id = None
         usage = None
-        if response_obj is not None and response_obj.get("id", None) is not None:
-            generation_id = litellm.utils.get_logging_id(start_time, response_obj)
-            usage = {
-                "prompt_tokens": response_obj.usage.prompt_tokens,
-                "completion_tokens": response_obj.usage.completion_tokens,
-                "total_cost": cost if supports_costs else None,
-            }
+        if response_obj is not None:
+            if response_obj.get("id", None) is not None:
+                generation_id = litellm.utils.get_logging_id(
+                    start_time, response_obj
+                )
+            _usage_obj = getattr(response_obj, "usage", None)
+
+            if _usage_obj:
+                usage = {
+                    "prompt_tokens": _usage_obj.prompt_tokens,
+                    "completion_tokens": _usage_obj.completion_tokens,
+                    "total_cost": cost if supports_costs else None,
+                }
         generation_name = clean_metadata.pop("generation_name", None)
         if generation_name is None:
             # if `generation_name` is None, use sensible default values

View file

@ -66,9 +66,6 @@ class CohereRerank(BaseLLM):
request_data_dict = request_data.dict(exclude_none=True) request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
## LOGGING ## LOGGING
litellm_logging_obj.pre_call( litellm_logging_obj.pre_call(
input=request_data_dict, input=request_data_dict,
@ -79,6 +76,10 @@ class CohereRerank(BaseLLM):
"headers": headers, "headers": headers,
}, },
) )
if _is_async:
return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
client = _get_httpx_client() client = _get_httpx_client()
response = client.post( response = client.post(
api_base, api_base,

View file

@@ -175,7 +175,7 @@ class VertexAIPartnerModels(BaseLLM):
                 client=client,
                 timeout=timeout,
                 encoding=encoding,
-                custom_llm_provider="vertex_ai_beta",
+                custom_llm_provider="vertex_ai",
             )

         except Exception as e:
except Exception as e: except Exception as e:

View file

@@ -2350,6 +2350,26 @@
         "mode": "chat",
         "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
     },
+    "vertex_ai/meta/llama3-70b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
+    "vertex_ai/meta/llama3-8b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
     "vertex_ai/mistral-large@latest": {
         "max_tokens": 8191,
         "max_input_tokens": 128000,

View file

@@ -19,11 +19,11 @@ model_list:
   - model_name: o1-preview
     litellm_params:
       model: o1-preview
+  - model_name: rerank-english-v3.0
+    litellm_params:
+      model: cohere/rerank-english-v3.0
+      api_key: os.environ/COHERE_API_KEY

 litellm_settings:
-  cache: true
-  # cache_params:
-  #  type: "redis"
-  #  service_name: "mymaster"
-  #  sentinel_nodes:
-  #   - ["localhost", 26379]
+  success_callback: ["langfuse"]

View file

@@ -0,0 +1,232 @@ (new file; every line added)
from typing import TYPE_CHECKING, Any

from litellm import verbose_logger

if TYPE_CHECKING:
    from prisma import Prisma

    _db = Prisma
else:
    _db = Any


async def create_missing_views(db: _db):
    """
    --------------------------------------------------
    NOTE: Copy of `litellm/db_scripts/create_views.py`.
    --------------------------------------------------

    Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend exists in the user's db.

    LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth

    MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month

    If the view doesn't exist, one will be created.
    """
    try:
        # Try to select one row from the view
        await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
        print("LiteLLM_VerificationTokenView Exists!")  # noqa
    except Exception as e:
        # If an error occurs, the view does not exist, so create it
        await db.execute_raw(
            """
            CREATE VIEW "LiteLLM_VerificationTokenView" AS
            SELECT
                v.*,
                t.spend AS team_spend,
                t.max_budget AS team_max_budget,
                t.tpm_limit AS team_tpm_limit,
                t.rpm_limit AS team_rpm_limit
            FROM "LiteLLM_VerificationToken" v
            LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
            """
        )
        print("LiteLLM_VerificationTokenView Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
        print("MonthlyGlobalSpend Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
        SELECT
        DATE("startTime") AS date,
        SUM("spend") AS spend
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        DATE("startTime");
        """
        await db.execute_raw(query=sql_query)
        print("MonthlyGlobalSpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
        print("Last30dKeysBySpend Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
        SELECT
        L."api_key",
        V."key_alias",
        V."key_name",
        SUM(L."spend") AS total_spend
        FROM
        "LiteLLM_SpendLogs" L
        LEFT JOIN
        "LiteLLM_VerificationToken" V
        ON
        L."api_key" = V."token"
        WHERE
        L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        L."api_key", V."key_alias", V."key_name"
        ORDER BY
        total_spend DESC;
        """
        await db.execute_raw(query=sql_query)
        print("Last30dKeysBySpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
        print("Last30dModelsBySpend Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
        SELECT
        "model",
        SUM("spend") AS total_spend
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        AND "model" != ''
        GROUP BY
        "model"
        ORDER BY
        total_spend DESC;
        """
        await db.execute_raw(query=sql_query)
        print("Last30dModelsBySpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
        print("MonthlyGlobalSpendPerKey Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
        SELECT
        DATE("startTime") AS date,
        SUM("spend") AS spend,
        api_key as api_key
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        DATE("startTime"),
        api_key;
        """
        await db.execute_raw(query=sql_query)
        print("MonthlyGlobalSpendPerKey Created!")  # noqa

    try:
        await db.query_raw(
            """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
        )
        print("MonthlyGlobalSpendPerUserPerKey Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
        SELECT
        DATE("startTime") AS date,
        SUM("spend") AS spend,
        api_key as api_key,
        "user" as "user"
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        DATE("startTime"),
        "user",
        api_key;
        """
        await db.execute_raw(query=sql_query)
        print("MonthlyGlobalSpendPerUserPerKey Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM DailyTagSpend LIMIT 1""")
        print("DailyTagSpend Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE OR REPLACE VIEW DailyTagSpend AS
        SELECT
            jsonb_array_elements_text(request_tags) AS individual_request_tag,
            DATE(s."startTime") AS spend_date,
            COUNT(*) AS log_count,
            SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs" s
        GROUP BY individual_request_tag, DATE(s."startTime");
        """
        await db.execute_raw(query=sql_query)
        print("DailyTagSpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
        print("Last30dTopEndUsersSpend Exists!")  # noqa
    except Exception as e:
        sql_query = """
        CREATE VIEW "Last30dTopEndUsersSpend" AS
        SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs"
        WHERE end_user <> '' AND end_user <> user
        AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
        GROUP BY end_user
        ORDER BY total_spend DESC
        LIMIT 100;
        """
        await db.execute_raw(query=sql_query)
        print("Last30dTopEndUsersSpend Created!")  # noqa

    return


async def should_create_missing_views(db: _db) -> bool:
    """
    Run only on first time startup.

    If SpendLogs table already has values, then don't create views on startup.
    """

    sql_query = """
    SELECT reltuples::BIGINT
    FROM pg_class
    WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
    """

    result = await db.query_raw(query=sql_query)

    verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
    if (
        result
        and isinstance(result, list)
        and len(result) > 0
        and isinstance(result[0], dict)
        and "reltuples" in result[0]
        and result[0]["reltuples"]
        and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1)
    ):
        verbose_logger.debug("Should create views")
        return True

    return False

View file

@@ -34,6 +34,7 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.secret_managers.main import get_secret_str

 from .streaming_handler import chunk_processor
 from .success_handler import PassThroughEndpointLogging

@@ -72,11 +73,11 @@ async def set_env_variables_in_header(custom_headers: dict):
             if isinstance(
                 _langfuse_public_key, str
             ) and _langfuse_public_key.startswith("os.environ/"):
-                _langfuse_public_key = litellm.get_secret(_langfuse_public_key)
+                _langfuse_public_key = get_secret_str(_langfuse_public_key)
             if isinstance(
                 _langfuse_secret_key, str
             ) and _langfuse_secret_key.startswith("os.environ/"):
-                _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key)
+                _langfuse_secret_key = get_secret_str(_langfuse_secret_key)
             headers["Authorization"] = "Basic " + b64encode(
                 f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
             ).decode("ascii")

@@ -95,9 +96,10 @@ async def set_env_variables_in_header(custom_headers: dict):
                     "pass through endpoint - getting secret for variable name: %s",
                     _variable_name,
                 )
-                _secret_value = litellm.get_secret(_variable_name)
-                new_value = value.replace(_variable_name, _secret_value)
-                headers[key] = new_value
+                _secret_value = get_secret_str(_variable_name)
+                if _secret_value is not None:
+                    new_value = value.replace(_variable_name, _secret_value)
+                    headers[key] = new_value
     return headers

@@ -349,7 +351,7 @@ async def pass_through_request(
     ### CALL HOOKS ### - modify incoming data / reject request before calling the model
     _parsed_body = await proxy_logging_obj.pre_call_hook(
         user_api_key_dict=user_api_key_dict,
-        data=_parsed_body or {},
+        data=_parsed_body,
         call_type="pass_through_endpoint",
     )

@@ -576,7 +578,7 @@ def create_pass_through_route(
         adapter_id = str(uuid.uuid4())
         litellm.adapters = [{"id": adapter_id, "adapter": adapter}]

-        async def endpoint_func(
+        async def endpoint_func(  # type: ignore
             request: Request,
             fastapi_response: Response,
             user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),

@@ -647,9 +649,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
         verbose_proxy_logger.debug(
             "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
         )
-        app.add_api_route(
+        app.add_api_route(  # type: ignore
             path=_path,
-            endpoint=create_pass_through_route(
+            endpoint=create_pass_through_route(  # type: ignore
                 _path, _target, _custom_headers, _forward_headers, _dependencies
             ),
             methods=["GET", "POST", "PUT", "DELETE", "PATCH"],

View file

@@ -2229,9 +2229,12 @@ class ProxyConfig:
                 and _general_settings.get("alerting", None) is not None
                 and isinstance(_general_settings["alerting"], list)
             ):
-                for alert in _general_settings["alerting"]:
-                    if alert not in general_settings["alerting"]:
-                        general_settings["alerting"].append(alert)
+                verbose_proxy_logger.debug(
+                    "Overriding Default 'alerting' values with db 'alerting' values."
+                )
+                general_settings["alerting"] = _general_settings[
+                    "alerting"
+                ]  # override yaml values with db
                 proxy_logging_obj.alerting = general_settings["alerting"]
                 proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
                     "alerting"

@@ -7779,10 +7782,13 @@ async def alerting_settings(
     if db_general_settings is not None and db_general_settings.param_value is not None:
         db_general_settings_dict = dict(db_general_settings.param_value)
         alerting_args_dict: dict = db_general_settings_dict.get("alerting_args", {})  # type: ignore
+        alerting_values: Optional[list] = db_general_settings_dict.get("alerting")  # type: ignore
     else:
         alerting_args_dict = {}
+        alerting_values = None

     allowed_args = {
+        "slack_alerting": {"type": "Boolean"},
         "daily_report_frequency": {"type": "Integer"},
         "report_check_interval": {"type": "Integer"},
         "budget_alert_ttl": {"type": "Integer"},

@@ -7798,6 +7804,25 @@ async def alerting_settings(
     return_val = []

+    is_slack_enabled = False
+
+    if general_settings.get("alerting") and isinstance(
+        general_settings["alerting"], list
+    ):
+        if "slack" in general_settings["alerting"]:
+            is_slack_enabled = True
+
+    _response_obj = ConfigList(
+        field_name="slack_alerting",
+        field_type=allowed_args["slack_alerting"]["type"],
+        field_description="Enable slack alerting for monitoring proxy in production: llm outages, budgets, spend tracking failures.",
+        field_value=is_slack_enabled,
+        stored_in_db=True if alerting_values is not None else False,
+        field_default_value=None,
+        premium_field=False,
+    )
+    return_val.append(_response_obj)
+
     for field_name, field_info in SlackAlertingArgs.model_fields.items():
         if field_name in allowed_args:

View file

@@ -1,21 +1,52 @@
-import vertexai
-from vertexai.preview.generative_models import GenerativeModel
-
-LITE_LLM_ENDPOINT = "http://localhost:4000"
-
-vertexai.init(
-    project="adroit-crow-413218",
-    location="us-central1",
-    api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
-    api_transport="rest",
-)
-
-model = GenerativeModel(model_name="gemini-1.5-flash-001")
-response = model.generate_content(
-    "hi tell me a joke and a very long story", stream=True
-)
-
-print("response", response)
-
-for chunk in response:
-    print(chunk)
+# import datetime
+
+# import vertexai
+# from vertexai.generative_models import Part
+# from vertexai.preview import caching
+# from vertexai.preview.generative_models import GenerativeModel
+
+# LITE_LLM_ENDPOINT = "http://localhost:4000"
+
+# vertexai.init(
+#     project="adroit-crow-413218",
+#     location="us-central1",
+#     api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
+#     api_transport="rest",
+# )
+
+# # model = GenerativeModel(model_name="gemini-1.5-flash-001")
+# # response = model.generate_content(
+# #     "hi tell me a joke and a very long story", stream=True
+# # )
+
+# # print("response", response)
+
+# # for chunk in response:
+# #     print(chunk)
+
+# system_instruction = """
+# You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+# Now look at these research papers, and answer the following questions.
+# """
+
+# contents = [
+#     Part.from_uri(
+#         "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+#         mime_type="application/pdf",
+#     ),
+#     Part.from_uri(
+#         "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+#         mime_type="application/pdf",
+#     ),
+# ]
+
+# cached_content = caching.CachedContent.create(
+#     model_name="gemini-1.5-pro-001",
+#     system_instruction=system_instruction,
+#     contents=contents,
+#     ttl=datetime.timedelta(minutes=60),
+#     # display_name="example-cache",
+# )
+
+# print(cached_content.name)

View file

@@ -14,7 +14,7 @@ from datetime import datetime, timedelta
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from functools import wraps
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload

 import backoff
 import httpx

@@ -51,6 +51,10 @@ from litellm.proxy._types import (
     SpendLogsPayload,
     UserAPIKeyAuth,
 )
+from litellm.proxy.db.create_views import (
+    create_missing_views,
+    should_create_missing_views,
+)
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.parallel_request_limiter import (

@@ -365,6 +369,25 @@ class ProxyLogging:

         return data

+    # The actual implementation of the function
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: None,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> None:
+        pass
+
+    @overload
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,

@@ -380,6 +403,23 @@ class ProxyLogging:
             "rerank",
         ],
     ) -> dict:
+        pass
+
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: Optional[dict],
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Optional[dict]:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.

@@ -394,6 +434,9 @@ class ProxyLogging:
             self.slack_alerting_instance.response_taking_too_long(request_data=data)
         )

+        if data is None:
+            return None
+
         try:
             for callback in litellm.callbacks:
                 _callback = None

@@ -418,7 +461,7 @@ class ProxyLogging:
                     response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
-                        data=data,
+                        data=data,  # type: ignore
                         call_type=call_type,
                     )
                     if response is not None:

@@ -434,7 +477,7 @@ class ProxyLogging:
                     response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
-                        data=data,
+                        data=data,  # type: ignore
                         call_type=call_type,
                     )
                     if response is not None:

@@ -1021,20 +1064,24 @@ class PrismaClient:
                         "LiteLLM_VerificationTokenView Created in DB!"
                     )
                 else:
-                    # don't block execution if these views are missing
-                    # Convert lists to sets for efficient difference calculation
-                    ret_view_names_set = (
-                        set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
-                    )
-                    expected_views_set = set(expected_views)
-                    # Find missing views
-                    missing_views = expected_views_set - ret_view_names_set
-
-                    verbose_proxy_logger.warning(
-                        "\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
-                            missing_views
-                        )
-                    )
+                    should_create_views = await should_create_missing_views(db=self.db)
+                    if should_create_views:
+                        await create_missing_views(db=self.db)
+                    else:
+                        # don't block execution if these views are missing
+                        # Convert lists to sets for efficient difference calculation
+                        ret_view_names_set = (
+                            set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
+                        )
+                        expected_views_set = set(expected_views)
+                        # Find missing views
+                        missing_views = expected_views_set - ret_view_names_set
+
+                        verbose_proxy_logger.warning(
+                            "\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
+                                missing_views
+                            )
+                        )

             except Exception as e:
                 raise

View file

@@ -103,10 +103,20 @@ def rerank(
             )
         )

+        model_parameters = [
+            "top_n",
+            "rank_fields",
+            "return_documents",
+            "max_chunks_per_doc",
+        ]
+        model_params_dict = {}
+        for k, v in optional_params.model_fields.items():
+            if k in model_parameters:
+                model_params_dict[k] = v
+
         litellm_logging_obj.update_environment_variables(
             model=model,
             user=user,
-            optional_params=optional_params.model_dump(),
+            optional_params=model_params_dict,
             litellm_params={
                 "litellm_call_id": litellm_call_id,
                 "proxy_server_request": proxy_server_request,

@@ -114,6 +124,7 @@ def rerank(
                 "metadata": metadata,
                 "preset_cache_key": None,
                 "stream_response": {},
+                **optional_params.model_dump(exclude_unset=True),
             },
             custom_llm_provider=_custom_llm_provider,
         )

View file

@@ -3532,7 +3532,8 @@ class Router:
             elif isinstance(id, int):
                 id = str(id)

-            total_tokens = completion_response["usage"].get("total_tokens", 0)
+            _usage_obj = completion_response.get("usage")
+            total_tokens = _usage_obj.get("total_tokens", 0) if _usage_obj else 0

             # ------------
             # Setup values

View file

@@ -50,6 +50,20 @@ def str_to_bool(value: str) -> Optional[bool]:
         return None


+def get_secret_str(
+    secret_name: str,
+    default_value: Optional[Union[str, bool]] = None,
+) -> Optional[str]:
+    """
+    Guarantees response from 'get_secret' is either string or none. Used for fixing linting errors.
+    """
+    value = get_secret(secret_name=secret_name, default_value=default_value)
+    if value is not None and not isinstance(value, str):
+        return None
+
+    return value
+
+
 def get_secret(
     secret_name: str,
     default_value: Optional[Union[str, bool]] = None,

View file

@@ -1268,3 +1268,41 @@ def test_completion_cost_fireworks_ai():
     print(resp)

     cost = completion_cost(completion_response=resp)
+
+
+def test_completion_cost_vertex_llama3():
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    from litellm.utils import Choices, Message, ModelResponse, Usage
+
+    response = ModelResponse(
+        id="2024-09-19|14:52:01.823070-07|3.10.13.64|-333502972",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="My name is Litellm Bot, and I'm here to help you with any questions or tasks you may have. As for the weather, I'd be happy to provide you with the current conditions and forecast for your location. However, I'm a large language model, I don't have real-time access to your location, so I'll need you to tell me where you are or provide me with a specific location you're interested in knowing the weather for.\\n\\nOnce you provide me with that information, I can give you the current weather conditions, including temperature, humidity, wind speed, and more, as well as a forecast for the next few days. Just let me know how I can assist you!",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+        created=1726782721,
+        model="vertex_ai/meta/llama3-405b-instruct-maas",
+        object="chat.completion",
+        system_fingerprint="",
+        usage=Usage(
+            completion_tokens=152,
+            prompt_tokens=27,
+            total_tokens=179,
+            completion_tokens_details=None,
+        ),
+    )
+
+    model = "vertex_ai/meta/llama3-8b-instruct-maas"
+    cost = completion_cost(model=model, completion_response=response)
+
+    assert cost == 0

View file

@@ -118,6 +118,8 @@ class CallTypes(Enum):
     transcription = "transcription"
     aspeech = "aspeech"
     speech = "speech"
+    rerank = "rerank"
+    arerank = "arerank"


 class PassthroughCallTypes(Enum):

View file

@@ -550,6 +550,10 @@ def function_setup(
             or call_type == CallTypes.text_completion.value
         ):
             messages = args[0] if len(args) > 0 else kwargs["prompt"]
+        elif (
+            call_type == CallTypes.rerank.value or call_type == CallTypes.arerank.value
+        ):
+            messages = kwargs.get("query")
         elif (
             call_type == CallTypes.atranscription.value
             or call_type == CallTypes.transcription.value

View file

@@ -2350,6 +2350,26 @@
         "mode": "chat",
         "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
     },
+    "vertex_ai/meta/llama3-70b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
+    "vertex_ai/meta/llama3-8b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
     "vertex_ai/mistral-large@latest": {
         "max_tokens": 8191,
         "max_input_tokens": 128000,

View file

@@ -48,6 +48,7 @@ def load_vertex_ai_credentials():
         service_account_key_data["private_key_id"] = private_key_id
         service_account_key_data["private_key"] = private_key

+    # print(f"service_account_key_data: {service_account_key_data}")
     # Create a temporary file
     with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
         # Write the updated content to the temporary files

@@ -151,3 +152,46 @@ async def test_basic_vertex_ai_pass_through_streaming_with_spendlog():
         )

     pass
+
+
+@pytest.mark.asyncio
+async def test_vertex_ai_pass_through_endpoint_context_caching():
+    import vertexai
+    from vertexai.generative_models import Part
+    from vertexai.preview import caching
+    import datetime
+
+    load_vertex_ai_credentials()
+
+    vertexai.init(
+        project="adroit-crow-413218",
+        location="us-central1",
+        api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
+        api_transport="rest",
+    )
+
+    system_instruction = """
+    You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+    Now look at these research papers, and answer the following questions.
+    """
+
+    contents = [
+        Part.from_uri(
+            "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+            mime_type="application/pdf",
+        ),
+        Part.from_uri(
+            "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+            mime_type="application/pdf",
+        ),
+    ]
+
+    cached_content = caching.CachedContent.create(
+        model_name="gemini-1.5-pro-001",
+        system_instruction=system_instruction,
+        contents=contents,
+        ttl=datetime.timedelta(minutes=60),
+        # display_name="example-cache",
+    )
+
+    print(cached_content.name)

View file

@ -41,7 +41,6 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
alertingSettingsItem[] alertingSettingsItem[]
>([]); >([]);
console.log("INSIDE ALERTING SETTINGS");
useEffect(() => { useEffect(() => {
// get values // get values
if (!accessToken) { if (!accessToken) {
@ -59,6 +58,8 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
? { ...setting, field_value: newValue } ? { ...setting, field_value: newValue }
: setting : setting
); );
console.log(`updatedSettings: ${JSON.stringify(updatedSettings)}`)
setAlertingSettings(updatedSettings); setAlertingSettings(updatedSettings);
}; };
@ -67,6 +68,7 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
return; return;
} }
console.log(`formValues: ${formValues}`)
let fieldValue = formValues; let fieldValue = formValues;
if (fieldValue == null || fieldValue == undefined) { if (fieldValue == null || fieldValue == undefined) {
@ -74,14 +76,25 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
} }
const initialFormValues: Record<string, any> = {}; const initialFormValues: Record<string, any> = {};
alertingSettings.forEach((setting) => { alertingSettings.forEach((setting) => {
initialFormValues[setting.field_name] = setting.field_value; initialFormValues[setting.field_name] = setting.field_value;
}); });
// Merge initialFormValues with actual formValues // Merge initialFormValues with actual formValues
const mergedFormValues = { ...formValues, ...initialFormValues }; const mergedFormValues = { ...formValues, ...initialFormValues };
console.log(`mergedFormValues: ${JSON.stringify(mergedFormValues)}`)
const { slack_alerting, ...alertingArgs } = mergedFormValues;
console.log(`slack_alerting: ${slack_alerting}, alertingArgs: ${JSON.stringify(alertingArgs)}`)
try { try {
updateConfigFieldSetting(accessToken, "alerting_args", mergedFormValues); updateConfigFieldSetting(accessToken, "alerting_args", alertingArgs);
if (typeof slack_alerting === "boolean") {
if (slack_alerting == true) {
updateConfigFieldSetting(accessToken, "alerting", ["slack"]);
} else {
updateConfigFieldSetting(accessToken, "alerting", []);
}
}
// update value in state // update value in state
message.success("Wait 10s for proxy to update."); message.success("Wait 10s for proxy to update.");
} catch (error) { } catch (error) {
@ -107,7 +120,6 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
} }
: setting : setting
); );
console.log("INSIDE HANDLE RESET FIELD");
setAlertingSettings(updatedSettings); setAlertingSettings(updatedSettings);
} catch (error) { } catch (error) {
// do something // do something

View file

@@ -1,7 +1,7 @@
 import React from "react";
 import { Form, Input, InputNumber, Row, Col, Button as Button2 } from "antd";
 import { TrashIcon, CheckCircleIcon } from "@heroicons/react/outline";
-import { Button, Badge, Icon, Text, TableRow, TableCell } from "@tremor/react";
+import { Button, Badge, Icon, Text, TableRow, TableCell, Switch } from "@tremor/react";
 import Paragraph from "antd/es/typography/Paragraph";

 interface AlertingSetting {
   field_name: string;

@@ -30,10 +30,15 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
   const [form] = Form.useForm();

   const onFinish = () => {
+    console.log(`INSIDE ONFINISH`)
     const formData = form.getFieldsValue();
-    const isEmpty = Object.values(formData).some(
-      (value) => value === "" || value === null || value === undefined
-    );
+    const isEmpty = Object.entries(formData).every(([key, value]) => {
+      if (typeof value === 'boolean') {
+        return false; // Boolean values are never considered empty
+      }
+      return value === '' || value === null || value === undefined;
+    });
+    console.log(`formData: ${JSON.stringify(formData)}, isEmpty: ${isEmpty}`)
     if (!isEmpty) {
       handleSubmit(formData);
     } else {

@@ -68,6 +73,11 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
                   value={value.field_value}
                   onChange={(e) => handleInputChange(value.field_name, e)}
                 />
+              ) : value.field_type === "Boolean" ? (
+                <Switch
+                  checked={value.field_value}
+                  onChange={(checked) => handleInputChange(value.field_name, checked)}
+                />
               ) : (
                 <Input
                   value={value.field_value}

@@ -86,7 +96,7 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
             </TableCell>
           )
         ) : (
-          <Form.Item name={value.field_name} className="mb-0">
+          <Form.Item name={value.field_name} className="mb-0" valuePropName={value.field_type === "Boolean" ? "checked" : "value"}>
             <TableCell>
               {value.field_type === "Integer" ? (
                 <InputNumber

@@ -95,7 +105,17 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
                   onChange={(e) => handleInputChange(value.field_name, e)}
                   className="p-0"
                 />
-              ) : (
+              ) : value.field_type === "Boolean" ? (
+                <Switch
+                  checked={value.field_value}
+                  onChange={(checked) => {
+                    handleInputChange(value.field_name, checked);
+                    form.setFieldsValue({ [value.field_name]: checked });
+                  }}
+                />
+              ) : (
                 <Input
                   value={value.field_value}
                   onChange={(e) => handleInputChange(value.field_name, e)}