LiteLLM Minor Fixes & Improvements (09/19/2024) (#5793)

* fix(model_prices_and_context_window.json): add cost tracking for more vertex llama3.1 models

Adds the 8b and 70b models.
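
A minimal sketch of how the new cost map entries can be checked, assuming the local model cost map is loaded the same way as in the test added in this commit; the token counts are illustrative:

# Hedged sketch: check cost tracking for a new Vertex AI Llama entry.
# Token counts are illustrative; costs come out as 0.0 because the map lists 0.0 per token.
import os
import litellm

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")

prompt_cost, completion_cost = litellm.cost_per_token(
    model="vertex_ai/meta/llama3-8b-instruct-maas",
    prompt_tokens=27,
    completion_tokens=152,
)
print(prompt_cost, completion_cost)  # (0.0, 0.0)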

* fix(proxy/utils.py): handle data being None in pre-call hooks
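
The fix (see the proxy/utils.py hunk further down) lets pre_call_hook accept data=None and return None instead of assuming a dict. A minimal sketch of the overload pattern, with simplified names:

# Sketch of the typing.overload pattern used for pre_call_hook; names are simplified.
from typing import Optional, overload

@overload
async def pre_call_hook(data: None) -> None: ...
@overload
async def pre_call_hook(data: dict) -> dict: ...

async def pre_call_hook(data: Optional[dict]) -> Optional[dict]:
    if data is None:  # e.g. a pass-through endpoint request with no parsed body
        return None
    # ...run the usual pre-call callbacks on data here...
    return data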

* fix(proxy/): create views on initial proxy startup

Fixes the base case, where a user starts the proxy for the first time.

Fixes https://github.com/BerriAI/litellm/issues/5756
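
The startup path estimates the LiteLLM_SpendLogs row count from pg_class; an estimate of 0 (or -1 for a never-analyzed table) is treated as a first boot and the missing views get created. A hedged sketch of that check, assuming db is a connected Prisma client as in the diff:

# Sketch mirroring should_create_missing_views() from the new create_views.py below.
# Assumes `db` is a connected Prisma client.
async def is_first_startup(db) -> bool:
    rows = await db.query_raw(
        """
        SELECT reltuples::BIGINT
        FROM pg_class
        WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
        """
    )
    # reltuples is only an estimate; 0 or -1 means the table is empty / never analyzed.
    return bool(rows) and rows[0].get("reltuples") in (0, -1)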

* build(config.yml): fix vertex SDK (google-cloud-aiplatform) version for tests

* feat(ui/): support enabling/disabling slack alerting

Allows the admin to turn Slack alerting on/off through the UI.
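
On the backend, the toggle simply maps to the 'alerting' list in general settings; a sketch of how the proxy derives the flag (mirroring the alerting_settings hunk below), with an illustrative settings dict:

# Sketch: deriving the slack_alerting flag shown in the UI from general_settings.
general_settings = {"alerting": ["slack"]}  # illustrative value; normally loaded from yaml/db

is_slack_enabled = False
if general_settings.get("alerting") and isinstance(general_settings["alerting"], list):
    if "slack" in general_settings["alerting"]:
        is_slack_enabled = True

print(is_slack_enabled)  # True -> the UI renders the Slack toggle as enabled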

* feat(rerank/main.py): support langfuse logging
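
With this change, rerank calls flow through the standard logging path, so Langfuse can be attached as a success callback. A hedged example, assuming the rerank model from the config below and that COHERE_API_KEY plus the LANGFUSE_* keys are set:

# Hedged example: rerank request logged to Langfuse via the success callback.
# Assumes COHERE_API_KEY, LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY are set.
import litellm

litellm.success_callback = ["langfuse"]

response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)
print(response.results)  # RerankResponse.results is what the Langfuse logger records as output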

* fix(proxy/utils.py): fix linting errors

* fix(langfuse.py): log clean metadata
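
The metadata cleanup drops sensitive request headers (and comments out attaching the raw request block) before the trace is sent; a minimal sketch of the header scrub, following the langfuse.py hunk below:

# Sketch of the header scrub used when building clean metadata for the Langfuse trace.
def scrub_headers(headers: dict) -> dict:
    clean_headers = {}
    for key, value in headers.items():
        # drop auth/cookie/referer; everything else is considered safe to log
        if key.lower() not in ["authorization", "cookie", "referer"]:
            clean_headers[key] = value
    return clean_headers

print(scrub_headers({"Authorization": "Bearer sk-xxx", "User-Agent": "curl/8.0"}))
# {'User-Agent': 'curl/8.0'}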

* test(tests): replace deprecated OpenAI model
Author: Krish Dholakia, 2024-09-20 08:19:52 -07:00 (committed by GitHub)
Parent: 696fc387d2
Commit: 3933fba41f
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
22 changed files with 645 additions and 94 deletions

View file

@ -558,6 +558,7 @@ jobs:
pip install "anyio==3.7.1"
pip install "asyncio==3.4.3"
pip install "PyGithub==1.59.1"
pip install "google-cloud-aiplatform==1.59.0"
- run:
name: Build Docker image
command: docker build -t my-app:latest -f Dockerfile.database .

View file

@ -51,22 +51,25 @@ async def check_view_exists():
print("LiteLLM_VerificationTokenView Created!") # noqa
sql_query = """
CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
print("MonthlyGlobalSpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
SELECT
DATE_TRUNC('day', "startTime") AS date,
SUM("spend") AS spend
DATE("startTime") AS date,
SUM("spend") AS spend
FROM
"LiteLLM_SpendLogs"
"LiteLLM_SpendLogs"
WHERE
"startTime" >= CURRENT_DATE - INTERVAL '30 days'
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE_TRUNC('day', "startTime");
"""
# Execute the queries
await db.execute_raw(query=sql_query)
DATE("startTime");
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpend Created!") # noqa
print("MonthlyGlobalSpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")

View file

@ -10,6 +10,7 @@ from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
from litellm.secret_managers.main import str_to_bool
class LangFuseLogger:
@ -66,6 +67,11 @@ class LangFuseLogger:
project_id = None
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
upstream_langfuse_debug = (
str_to_bool(self.upstream_langfuse_debug)
if self.upstream_langfuse_debug is not None
else None
)
self.upstream_langfuse_secret_key = os.getenv(
"UPSTREAM_LANGFUSE_SECRET_KEY"
)
@ -80,7 +86,11 @@ class LangFuseLogger:
secret_key=self.upstream_langfuse_secret_key,
host=self.upstream_langfuse_host,
release=self.upstream_langfuse_release,
debug=self.upstream_langfuse_debug,
debug=(
upstream_langfuse_debug
if upstream_langfuse_debug is not None
else False
),
)
else:
self.upstream_langfuse = None
@ -175,6 +185,7 @@ class LangFuseLogger:
pass
# end of processing langfuse ########################
if (
level == "ERROR"
and status_message is not None
@ -208,6 +219,11 @@ class LangFuseLogger:
):
input = prompt
output = response_obj["text"]
elif response_obj is not None and isinstance(
response_obj, litellm.RerankResponse
):
input = prompt
output = response_obj.results
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
@ -283,14 +299,14 @@ class LangFuseLogger:
input,
response_obj,
):
from langfuse.model import CreateGeneration, CreateTrace
from langfuse.model import CreateGeneration, CreateTrace # type: ignore
verbose_logger.warning(
"Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
)
trace = self.Langfuse.trace(
CreateTrace(
trace = self.Langfuse.trace( # type: ignore
CreateTrace( # type: ignore
name=metadata.get("generation_name", "litellm-completion"),
input=input,
output=output,
@ -336,6 +352,7 @@ class LangFuseLogger:
try:
tags = []
try:
optional_params.pop("metadata")
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
@ -361,7 +378,7 @@ class LangFuseLogger:
langfuse.version.__version__
) >= Version("2.7.3")
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
print_verbose("Langfuse Layer Logging - logging to langfuse v2 ")
if supports_tags:
metadata_tags = metadata.pop("tags", [])
@ -519,11 +536,11 @@ class LangFuseLogger:
if key.lower() not in ["authorization", "cookie", "referer"]:
clean_headers[key] = value
clean_metadata["request"] = {
"method": method,
"url": url,
"headers": clean_headers,
}
# clean_metadata["request"] = {
# "method": method,
# "url": url,
# "headers": clean_headers,
# }
trace = self.Langfuse.trace(**trace_params)
# Log provider specific information as a span
@ -531,13 +548,19 @@ class LangFuseLogger:
generation_id = None
usage = None
if response_obj is not None and response_obj.get("id", None) is not None:
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
usage = {
"prompt_tokens": response_obj.usage.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens,
"total_cost": cost if supports_costs else None,
}
if response_obj is not None:
if response_obj.get("id", None) is not None:
generation_id = litellm.utils.get_logging_id(
start_time, response_obj
)
_usage_obj = getattr(response_obj, "usage", None)
if _usage_obj:
usage = {
"prompt_tokens": _usage_obj.prompt_tokens,
"completion_tokens": _usage_obj.completion_tokens,
"total_cost": cost if supports_costs else None,
}
generation_name = clean_metadata.pop("generation_name", None)
if generation_name is None:
# if `generation_name` is None, use sensible default values

View file

@ -66,9 +66,6 @@ class CohereRerank(BaseLLM):
request_data_dict = request_data.dict(exclude_none=True)
if _is_async:
return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
## LOGGING
litellm_logging_obj.pre_call(
input=request_data_dict,
@ -79,6 +76,10 @@ class CohereRerank(BaseLLM):
"headers": headers,
},
)
if _is_async:
return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers) # type: ignore # Call async method
client = _get_httpx_client()
response = client.post(
api_base,

View file

@ -175,7 +175,7 @@ class VertexAIPartnerModels(BaseLLM):
client=client,
timeout=timeout,
encoding=encoding,
custom_llm_provider="vertex_ai_beta",
custom_llm_provider="vertex_ai",
)
except Exception as e:

View file

@ -2350,6 +2350,26 @@
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/meta/llama3-70b-instruct-maas": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/meta/llama3-8b-instruct-maas": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/mistral-large@latest": {
"max_tokens": 8191,
"max_input_tokens": 128000,

View file

@ -19,11 +19,11 @@ model_list:
- model_name: o1-preview
litellm_params:
model: o1-preview
- model_name: rerank-english-v3.0
litellm_params:
model: cohere/rerank-english-v3.0
api_key: os.environ/COHERE_API_KEY
litellm_settings:
cache: true
# cache_params:
# type: "redis"
# service_name: "mymaster"
# sentinel_nodes:
# - ["localhost", 26379]
success_callback: ["langfuse"]

View file

@ -0,0 +1,232 @@
from typing import TYPE_CHECKING, Any
from litellm import verbose_logger
if TYPE_CHECKING:
from prisma import Prisma
_db = Prisma
else:
_db = Any
async def create_missing_views(db: _db):
"""
--------------------------------------------------
NOTE: Copy of `litellm/db_scripts/create_views.py`.
--------------------------------------------------
Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend views exist in the user's db.
LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth
MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month
If the view doesn't exist, one will be created.
"""
try:
# Try to select one row from the view
await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
print("LiteLLM_VerificationTokenView Exists!") # noqa
except Exception as e:
# If an error occurs, the view does not exist, so create it
await db.execute_raw(
"""
CREATE VIEW "LiteLLM_VerificationTokenView" AS
SELECT
v.*,
t.spend AS team_spend,
t.max_budget AS team_max_budget,
t.tpm_limit AS team_tpm_limit,
t.rpm_limit AS team_rpm_limit
FROM "LiteLLM_VerificationToken" v
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
"""
)
print("LiteLLM_VerificationTokenView Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
print("MonthlyGlobalSpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
SELECT
DATE("startTime") AS date,
SUM("spend") AS spend
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE("startTime");
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
print("Last30dKeysBySpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
SELECT
L."api_key",
V."key_alias",
V."key_name",
SUM(L."spend") AS total_spend
FROM
"LiteLLM_SpendLogs" L
LEFT JOIN
"LiteLLM_VerificationToken" V
ON
L."api_key" = V."token"
WHERE
L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
L."api_key", V."key_alias", V."key_name"
ORDER BY
total_spend DESC;
"""
await db.execute_raw(query=sql_query)
print("Last30dKeysBySpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
print("Last30dModelsBySpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
SELECT
"model",
SUM("spend") AS total_spend
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
AND "model" != ''
GROUP BY
"model"
ORDER BY
total_spend DESC;
"""
await db.execute_raw(query=sql_query)
print("Last30dModelsBySpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
print("MonthlyGlobalSpendPerKey Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
SELECT
DATE("startTime") AS date,
SUM("spend") AS spend,
api_key as api_key
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE("startTime"),
api_key;
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpendPerKey Created!") # noqa
try:
await db.query_raw(
"""SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
)
print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
SELECT
DATE("startTime") AS date,
SUM("spend") AS spend,
api_key as api_key,
"user" as "user"
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
GROUP BY
DATE("startTime"),
"user",
api_key;
"""
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM DailyTagSpend LIMIT 1""")
print("DailyTagSpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW DailyTagSpend AS
SELECT
jsonb_array_elements_text(request_tags) AS individual_request_tag,
DATE(s."startTime") AS spend_date,
COUNT(*) AS log_count,
SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs" s
GROUP BY individual_request_tag, DATE(s."startTime");
"""
await db.execute_raw(query=sql_query)
print("DailyTagSpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
print("Last30dTopEndUsersSpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE VIEW "Last30dTopEndUsersSpend" AS
SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
FROM "LiteLLM_SpendLogs"
WHERE end_user <> '' AND end_user <> user
AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
GROUP BY end_user
ORDER BY total_spend DESC
LIMIT 100;
"""
await db.execute_raw(query=sql_query)
print("Last30dTopEndUsersSpend Created!") # noqa
return
async def should_create_missing_views(db: _db) -> bool:
"""
Run only on first time startup.
If SpendLogs table already has values, then don't create views on startup.
"""
sql_query = """
SELECT reltuples::BIGINT
FROM pg_class
WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
"""
result = await db.query_raw(query=sql_query)
verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
if (
result
and isinstance(result, list)
and len(result) > 0
and isinstance(result[0], dict)
and "reltuples" in result[0]
and result[0]["reltuples"]
and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1)
):
verbose_logger.debug("Should create views")
return True
return False

View file

@ -34,6 +34,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.secret_managers.main import get_secret_str
from .streaming_handler import chunk_processor
from .success_handler import PassThroughEndpointLogging
@ -72,11 +73,11 @@ async def set_env_variables_in_header(custom_headers: dict):
if isinstance(
_langfuse_public_key, str
) and _langfuse_public_key.startswith("os.environ/"):
_langfuse_public_key = litellm.get_secret(_langfuse_public_key)
_langfuse_public_key = get_secret_str(_langfuse_public_key)
if isinstance(
_langfuse_secret_key, str
) and _langfuse_secret_key.startswith("os.environ/"):
_langfuse_secret_key = litellm.get_secret(_langfuse_secret_key)
_langfuse_secret_key = get_secret_str(_langfuse_secret_key)
headers["Authorization"] = "Basic " + b64encode(
f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
).decode("ascii")
@ -95,9 +96,10 @@ async def set_env_variables_in_header(custom_headers: dict):
"pass through endpoint - getting secret for variable name: %s",
_variable_name,
)
_secret_value = litellm.get_secret(_variable_name)
new_value = value.replace(_variable_name, _secret_value)
headers[key] = new_value
_secret_value = get_secret_str(_variable_name)
if _secret_value is not None:
new_value = value.replace(_variable_name, _secret_value)
headers[key] = new_value
return headers
@ -349,7 +351,7 @@ async def pass_through_request(
### CALL HOOKS ### - modify incoming data / reject request before calling the model
_parsed_body = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict,
data=_parsed_body or {},
data=_parsed_body,
call_type="pass_through_endpoint",
)
@ -576,7 +578,7 @@ def create_pass_through_route(
adapter_id = str(uuid.uuid4())
litellm.adapters = [{"id": adapter_id, "adapter": adapter}]
async def endpoint_func(
async def endpoint_func( # type: ignore
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
@ -647,9 +649,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
verbose_proxy_logger.debug(
"adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
)
app.add_api_route(
app.add_api_route( # type: ignore
path=_path,
endpoint=create_pass_through_route(
endpoint=create_pass_through_route( # type: ignore
_path, _target, _custom_headers, _forward_headers, _dependencies
),
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],

View file

@ -2229,9 +2229,12 @@ class ProxyConfig:
and _general_settings.get("alerting", None) is not None
and isinstance(_general_settings["alerting"], list)
):
for alert in _general_settings["alerting"]:
if alert not in general_settings["alerting"]:
general_settings["alerting"].append(alert)
verbose_proxy_logger.debug(
"Overriding Default 'alerting' values with db 'alerting' values."
)
general_settings["alerting"] = _general_settings[
"alerting"
] # override yaml values with db
proxy_logging_obj.alerting = general_settings["alerting"]
proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
"alerting"
@ -7779,10 +7782,13 @@ async def alerting_settings(
if db_general_settings is not None and db_general_settings.param_value is not None:
db_general_settings_dict = dict(db_general_settings.param_value)
alerting_args_dict: dict = db_general_settings_dict.get("alerting_args", {}) # type: ignore
alerting_values: Optional[list] = db_general_settings_dict.get("alerting") # type: ignore
else:
alerting_args_dict = {}
alerting_values = None
allowed_args = {
"slack_alerting": {"type": "Boolean"},
"daily_report_frequency": {"type": "Integer"},
"report_check_interval": {"type": "Integer"},
"budget_alert_ttl": {"type": "Integer"},
@ -7798,6 +7804,25 @@ async def alerting_settings(
return_val = []
is_slack_enabled = False
if general_settings.get("alerting") and isinstance(
general_settings["alerting"], list
):
if "slack" in general_settings["alerting"]:
is_slack_enabled = True
_response_obj = ConfigList(
field_name="slack_alerting",
field_type=allowed_args["slack_alerting"]["type"],
field_description="Enable slack alerting for monitoring proxy in production: llm outages, budgets, spend tracking failures.",
field_value=is_slack_enabled,
stored_in_db=True if alerting_values is not None else False,
field_default_value=None,
premium_field=False,
)
return_val.append(_response_obj)
for field_name, field_info in SlackAlertingArgs.model_fields.items():
if field_name in allowed_args:

View file

@ -1,21 +1,52 @@
import vertexai
from vertexai.preview.generative_models import GenerativeModel
# import datetime
LITE_LLM_ENDPOINT = "http://localhost:4000"
# import vertexai
# from vertexai.generative_models import Part
# from vertexai.preview import caching
# from vertexai.preview.generative_models import GenerativeModel
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
api_transport="rest",
)
# LITE_LLM_ENDPOINT = "http://localhost:4000"
model = GenerativeModel(model_name="gemini-1.5-flash-001")
response = model.generate_content(
"hi tell me a joke and a very long story", stream=True
)
# vertexai.init(
# project="adroit-crow-413218",
# location="us-central1",
# api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
# api_transport="rest",
# )
print("response", response)
# # model = GenerativeModel(model_name="gemini-1.5-flash-001")
# # response = model.generate_content(
# # "hi tell me a joke and a very long story", stream=True
# # )
for chunk in response:
print(chunk)
# # print("response", response)
# # for chunk in response:
# # print(chunk)
# system_instruction = """
# You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
# Now look at these research papers, and answer the following questions.
# """
# contents = [
# Part.from_uri(
# "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
# mime_type="application/pdf",
# ),
# Part.from_uri(
# "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
# mime_type="application/pdf",
# ),
# ]
# cached_content = caching.CachedContent.create(
# model_name="gemini-1.5-pro-001",
# system_instruction=system_instruction,
# contents=contents,
# ttl=datetime.timedelta(minutes=60),
# # display_name="example-cache",
# )
# print(cached_content.name)

View file

@ -14,7 +14,7 @@ from datetime import datetime, timedelta
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from functools import wraps
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload
import backoff
import httpx
@ -51,6 +51,10 @@ from litellm.proxy._types import (
SpendLogsPayload,
UserAPIKeyAuth,
)
from litellm.proxy.db.create_views import (
create_missing_views,
should_create_missing_views,
)
from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
from litellm.proxy.hooks.parallel_request_limiter import (
@ -365,6 +369,25 @@ class ProxyLogging:
return data
# The actual implementation of the function
@overload
async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: None,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> None:
pass
@overload
async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
@ -380,6 +403,23 @@ class ProxyLogging:
"rerank",
],
) -> dict:
pass
async def pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
data: Optional[dict],
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> Optional[dict]:
"""
Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@ -394,6 +434,9 @@ class ProxyLogging:
self.slack_alerting_instance.response_taking_too_long(request_data=data)
)
if data is None:
return None
try:
for callback in litellm.callbacks:
_callback = None
@ -418,7 +461,7 @@ class ProxyLogging:
response = await _callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
data=data, # type: ignore
call_type=call_type,
)
if response is not None:
@ -434,7 +477,7 @@ class ProxyLogging:
response = await _callback.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=self.call_details["user_api_key_cache"],
data=data,
data=data, # type: ignore
call_type=call_type,
)
if response is not None:
@ -1021,20 +1064,24 @@ class PrismaClient:
"LiteLLM_VerificationTokenView Created in DB!"
)
else:
# don't block execution if these views are missing
# Convert lists to sets for efficient difference calculation
ret_view_names_set = (
set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
)
expected_views_set = set(expected_views)
# Find missing views
missing_views = expected_views_set - ret_view_names_set
verbose_proxy_logger.warning(
"\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
missing_views
should_create_views = await should_create_missing_views(db=self.db)
if should_create_views:
await create_missing_views(db=self.db)
else:
# don't block execution if these views are missing
# Convert lists to sets for efficient difference calculation
ret_view_names_set = (
set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
)
expected_views_set = set(expected_views)
# Find missing views
missing_views = expected_views_set - ret_view_names_set
verbose_proxy_logger.warning(
"\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
missing_views
)
)
)
except Exception as e:
raise

View file

@ -103,10 +103,20 @@ def rerank(
)
)
model_parameters = [
"top_n",
"rank_fields",
"return_documents",
"max_chunks_per_doc",
]
model_params_dict = {}
for k, v in optional_params.model_fields.items():
if k in model_parameters:
model_params_dict[k] = v
litellm_logging_obj.update_environment_variables(
model=model,
user=user,
optional_params=optional_params.model_dump(),
optional_params=model_params_dict,
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
@ -114,6 +124,7 @@ def rerank(
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
**optional_params.model_dump(exclude_unset=True),
},
custom_llm_provider=_custom_llm_provider,
)

View file

@ -3532,7 +3532,8 @@ class Router:
elif isinstance(id, int):
id = str(id)
total_tokens = completion_response["usage"].get("total_tokens", 0)
_usage_obj = completion_response.get("usage")
total_tokens = _usage_obj.get("total_tokens", 0) if _usage_obj else 0
# ------------
# Setup values

View file

@ -50,6 +50,20 @@ def str_to_bool(value: str) -> Optional[bool]:
return None
def get_secret_str(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,
) -> Optional[str]:
"""
Guarantees the response from 'get_secret' is either a string or None. Used for fixing linting errors.
"""
value = get_secret(secret_name=secret_name, default_value=default_value)
if value is not None and not isinstance(value, str):
return None
return value
def get_secret(
secret_name: str,
default_value: Optional[Union[str, bool]] = None,

View file

@ -1268,3 +1268,41 @@ def test_completion_cost_fireworks_ai():
print(resp)
cost = completion_cost(completion_response=resp)
def test_completion_cost_vertex_llama3():
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")
from litellm.utils import Choices, Message, ModelResponse, Usage
response = ModelResponse(
id="2024-09-19|14:52:01.823070-07|3.10.13.64|-333502972",
choices=[
Choices(
finish_reason="stop",
index=0,
message=Message(
content="My name is Litellm Bot, and I'm here to help you with any questions or tasks you may have. As for the weather, I'd be happy to provide you with the current conditions and forecast for your location. However, I'm a large language model, I don't have real-time access to your location, so I'll need you to tell me where you are or provide me with a specific location you're interested in knowing the weather for.\\n\\nOnce you provide me with that information, I can give you the current weather conditions, including temperature, humidity, wind speed, and more, as well as a forecast for the next few days. Just let me know how I can assist you!",
role="assistant",
tool_calls=None,
function_call=None,
),
)
],
created=1726782721,
model="vertex_ai/meta/llama3-405b-instruct-maas",
object="chat.completion",
system_fingerprint="",
usage=Usage(
completion_tokens=152,
prompt_tokens=27,
total_tokens=179,
completion_tokens_details=None,
),
)
model = "vertex_ai/meta/llama3-8b-instruct-maas"
cost = completion_cost(model=model, completion_response=response)
assert cost == 0

View file

@ -118,6 +118,8 @@ class CallTypes(Enum):
transcription = "transcription"
aspeech = "aspeech"
speech = "speech"
rerank = "rerank"
arerank = "arerank"
class PassthroughCallTypes(Enum):

View file

@ -550,6 +550,10 @@ def function_setup(
or call_type == CallTypes.text_completion.value
):
messages = args[0] if len(args) > 0 else kwargs["prompt"]
elif (
call_type == CallTypes.rerank.value or call_type == CallTypes.arerank.value
):
messages = kwargs.get("query")
elif (
call_type == CallTypes.atranscription.value
or call_type == CallTypes.transcription.value

View file

@ -2350,6 +2350,26 @@
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/meta/llama3-70b-instruct-maas": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/meta/llama3-8b-instruct-maas": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/mistral-large@latest": {
"max_tokens": 8191,
"max_input_tokens": 128000,

View file

@ -48,6 +48,7 @@ def load_vertex_ai_credentials():
service_account_key_data["private_key_id"] = private_key_id
service_account_key_data["private_key"] = private_key
# print(f"service_account_key_data: {service_account_key_data}")
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary files
@ -151,3 +152,46 @@ async def test_basic_vertex_ai_pass_through_streaming_with_spendlog():
)
pass
@pytest.mark.asyncio
async def test_vertex_ai_pass_through_endpoint_context_caching():
import vertexai
from vertexai.generative_models import Part
from vertexai.preview import caching
import datetime
load_vertex_ai_credentials()
vertexai.init(
project="adroit-crow-413218",
location="us-central1",
api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
api_transport="rest",
)
system_instruction = """
You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
Now look at these research papers, and answer the following questions.
"""
contents = [
Part.from_uri(
"gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
mime_type="application/pdf",
),
Part.from_uri(
"gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
mime_type="application/pdf",
),
]
cached_content = caching.CachedContent.create(
model_name="gemini-1.5-pro-001",
system_instruction=system_instruction,
contents=contents,
ttl=datetime.timedelta(minutes=60),
# display_name="example-cache",
)
print(cached_content.name)

View file

@ -41,7 +41,6 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
alertingSettingsItem[]
>([]);
console.log("INSIDE ALERTING SETTINGS");
useEffect(() => {
// get values
if (!accessToken) {
@ -59,6 +58,8 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
? { ...setting, field_value: newValue }
: setting
);
console.log(`updatedSettings: ${JSON.stringify(updatedSettings)}`)
setAlertingSettings(updatedSettings);
};
@ -67,6 +68,7 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
return;
}
console.log(`formValues: ${formValues}`)
let fieldValue = formValues;
if (fieldValue == null || fieldValue == undefined) {
@ -74,14 +76,25 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
}
const initialFormValues: Record<string, any> = {};
alertingSettings.forEach((setting) => {
initialFormValues[setting.field_name] = setting.field_value;
});
// Merge initialFormValues with actual formValues
const mergedFormValues = { ...formValues, ...initialFormValues };
console.log(`mergedFormValues: ${JSON.stringify(mergedFormValues)}`)
const { slack_alerting, ...alertingArgs } = mergedFormValues;
console.log(`slack_alerting: ${slack_alerting}, alertingArgs: ${JSON.stringify(alertingArgs)}`)
try {
updateConfigFieldSetting(accessToken, "alerting_args", mergedFormValues);
updateConfigFieldSetting(accessToken, "alerting_args", alertingArgs);
if (typeof slack_alerting === "boolean") {
if (slack_alerting == true) {
updateConfigFieldSetting(accessToken, "alerting", ["slack"]);
} else {
updateConfigFieldSetting(accessToken, "alerting", []);
}
}
// update value in state
message.success("Wait 10s for proxy to update.");
} catch (error) {
@ -107,7 +120,6 @@ const AlertingSettings: React.FC<AlertingSettingsProps> = ({
}
: setting
);
console.log("INSIDE HANDLE RESET FIELD");
setAlertingSettings(updatedSettings);
} catch (error) {
// do something

View file

@ -1,7 +1,7 @@
import React from "react";
import { Form, Input, InputNumber, Row, Col, Button as Button2 } from "antd";
import { TrashIcon, CheckCircleIcon } from "@heroicons/react/outline";
import { Button, Badge, Icon, Text, TableRow, TableCell } from "@tremor/react";
import { Button, Badge, Icon, Text, TableRow, TableCell, Switch } from "@tremor/react";
import Paragraph from "antd/es/typography/Paragraph";
interface AlertingSetting {
field_name: string;
@ -30,10 +30,15 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
const [form] = Form.useForm();
const onFinish = () => {
console.log(`INSIDE ONFINISH`)
const formData = form.getFieldsValue();
const isEmpty = Object.values(formData).some(
(value) => value === "" || value === null || value === undefined
);
const isEmpty = Object.entries(formData).every(([key, value]) => {
if (typeof value === 'boolean') {
return false; // Boolean values are never considered empty
}
return value === '' || value === null || value === undefined;
});
console.log(`formData: ${JSON.stringify(formData)}, isEmpty: ${isEmpty}`)
if (!isEmpty) {
handleSubmit(formData);
} else {
@ -68,6 +73,11 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
value={value.field_value}
onChange={(e) => handleInputChange(value.field_name, e)}
/>
) : value.field_type === "Boolean" ? (
<Switch
checked={value.field_value}
onChange={(checked) => handleInputChange(value.field_name, checked)}
/>
) : (
<Input
value={value.field_value}
@ -86,7 +96,7 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
</TableCell>
)
) : (
<Form.Item name={value.field_name} className="mb-0">
<Form.Item name={value.field_name} className="mb-0" valuePropName={value.field_type === "Boolean" ? "checked" : "value"}>
<TableCell>
{value.field_type === "Integer" ? (
<InputNumber
@ -95,7 +105,17 @@ const DynamicForm: React.FC<DynamicFormProps> = ({
onChange={(e) => handleInputChange(value.field_name, e)}
className="p-0"
/>
) : (
) : value.field_type === "Boolean" ? (
<Switch
checked={value.field_value}
onChange={(checked) => {
handleInputChange(value.field_name, checked);
form.setFieldsValue({ [value.field_name]: checked });
}}
/>
) :(
<Input
value={value.field_value}
onChange={(e) => handleInputChange(value.field_name, e)}