diff --git a/.circleci/config.yml b/.circleci/config.yml
index acd3a8058..6bddd80f1 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -558,6 +558,7 @@ jobs:
             pip install "anyio==3.7.1"
             pip install "asyncio==3.4.3"
             pip install "PyGithub==1.59.1"
+            pip install "google-cloud-aiplatform==1.59.0"
       - run:
           name: Build Docker image
           command: docker build -t my-app:latest -f Dockerfile.database .
diff --git a/db_scripts/create_views.py b/db_scripts/create_views.py
index cbf00605f..1444ab96d 100644
--- a/db_scripts/create_views.py
+++ b/db_scripts/create_views.py
@@ -51,22 +51,25 @@ async def check_view_exists():
         print("LiteLLM_VerificationTokenView Created!")  # noqa

-    sql_query = """
-    CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
+    try:
+        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
+        print("MonthlyGlobalSpend Exists!")  # noqa
+    except Exception as e:
+        sql_query = """
+        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
         SELECT
-        DATE_TRUNC('day', "startTime") AS date,
-        SUM("spend") AS spend
+            DATE("startTime") AS date,
+            SUM("spend") AS spend
         FROM
-        "LiteLLM_SpendLogs"
+            "LiteLLM_SpendLogs"
         WHERE
-        "startTime" >= CURRENT_DATE - INTERVAL '30 days'
+            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
         GROUP BY
-        DATE_TRUNC('day', "startTime");
-    """
-    # Execute the queries
-    await db.execute_raw(query=sql_query)
+            DATE("startTime");
+        """
+        await db.execute_raw(query=sql_query)

-    print("MonthlyGlobalSpend Created!")  # noqa
+        print("MonthlyGlobalSpend Created!")  # noqa

     try:
         await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index e04230e7e..8c6879424 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel
 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
+from litellm.secret_managers.main import str_to_bool


 class LangFuseLogger:
@@ -66,6 +67,11 @@ class LangFuseLogger:
         project_id = None

         if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
+            upstream_langfuse_debug = (
+                str_to_bool(self.upstream_langfuse_debug)
+                if self.upstream_langfuse_debug is not None
+                else None
+            )
             self.upstream_langfuse_secret_key = os.getenv(
                 "UPSTREAM_LANGFUSE_SECRET_KEY"
             )
@@ -80,7 +86,11 @@ class LangFuseLogger:
                 secret_key=self.upstream_langfuse_secret_key,
                 host=self.upstream_langfuse_host,
                 release=self.upstream_langfuse_release,
-                debug=self.upstream_langfuse_debug,
+                debug=(
+                    upstream_langfuse_debug
+                    if upstream_langfuse_debug is not None
+                    else False
+                ),
             )
         else:
             self.upstream_langfuse = None
@@ -175,6 +185,7 @@ class LangFuseLogger:
                 pass

             # end of processing langfuse ########################
+
             if (
                 level == "ERROR"
                 and status_message is not None
@@ -208,6 +219,11 @@ class LangFuseLogger:
             ):
                 input = prompt
                 output = response_obj["text"]
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.RerankResponse
+            ):
+                input = prompt
+                output = response_obj.results
             elif (
                 kwargs.get("call_type") is not None
                 and kwargs.get("call_type") == "pass_through_endpoint"
@@ -283,14 +299,14 @@ class LangFuseLogger:
         input,
         response_obj,
     ):
-        from langfuse.model import CreateGeneration, CreateTrace
+        from langfuse.model import CreateGeneration, CreateTrace  # type: ignore

         verbose_logger.warning(
             "Please upgrade langfuse to v2.0.0 or higher: https://github.com/langfuse/langfuse-python/releases/tag/v2.0.1"
         )

-        trace = self.Langfuse.trace(
-            CreateTrace(
+        trace = self.Langfuse.trace(  # type: ignore
+            CreateTrace(  # type: ignore
                 name=metadata.get("generation_name", "litellm-completion"),
                 input=input,
                 output=output,
@@ -336,6 +352,7 @@ class LangFuseLogger:
         try:
             tags = []
             try:
+                optional_params.pop("metadata")
                 metadata = copy.deepcopy(
                     metadata
                 )  # Avoid modifying the original metadata
@@ -361,7 +378,7 @@ class LangFuseLogger:
                 langfuse.version.__version__
             ) >= Version("2.7.3")

-            print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
+            print_verbose("Langfuse Layer Logging - logging to langfuse v2 ")

             if supports_tags:
                 metadata_tags = metadata.pop("tags", [])
@@ -519,11 +536,11 @@ class LangFuseLogger:
                 if key.lower() not in ["authorization", "cookie", "referer"]:
                     clean_headers[key] = value

-        clean_metadata["request"] = {
-            "method": method,
-            "url": url,
-            "headers": clean_headers,
-        }
+        # clean_metadata["request"] = {
+        #     "method": method,
+        #     "url": url,
+        #     "headers": clean_headers,
+        # }
         trace = self.Langfuse.trace(**trace_params)

         # Log provider specific information as a span
@@ -531,13 +548,19 @@ class LangFuseLogger:
         generation_id = None
         usage = None
-        if response_obj is not None and response_obj.get("id", None) is not None:
-            generation_id = litellm.utils.get_logging_id(start_time, response_obj)
-            usage = {
-                "prompt_tokens": response_obj.usage.prompt_tokens,
-                "completion_tokens": response_obj.usage.completion_tokens,
-                "total_cost": cost if supports_costs else None,
-            }
+        if response_obj is not None:
+            if response_obj.get("id", None) is not None:
+                generation_id = litellm.utils.get_logging_id(
+                    start_time, response_obj
+                )
+            _usage_obj = getattr(response_obj, "usage", None)
+
+            if _usage_obj:
+                usage = {
+                    "prompt_tokens": _usage_obj.prompt_tokens,
+                    "completion_tokens": _usage_obj.completion_tokens,
+                    "total_cost": cost if supports_costs else None,
+                }
         generation_name = clean_metadata.pop("generation_name", None)
         if generation_name is None:
             # if `generation_name` is None, use sensible default values
diff --git a/litellm/llms/cohere/rerank.py b/litellm/llms/cohere/rerank.py
index 069cf3968..5332be00c 100644
--- a/litellm/llms/cohere/rerank.py
+++ b/litellm/llms/cohere/rerank.py
@@ -66,9 +66,6 @@ class CohereRerank(BaseLLM):

         request_data_dict = request_data.dict(exclude_none=True)

-        if _is_async:
-            return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers)  # type: ignore # Call async method
-
         ## LOGGING
         litellm_logging_obj.pre_call(
             input=request_data_dict,
@@ -79,6 +76,10 @@ class CohereRerank(BaseLLM):
                 "headers": headers,
             },
         )
+
+        if _is_async:
+            return self.async_rerank(request_data_dict=request_data_dict, api_key=api_key, api_base=api_base, headers=headers)  # type: ignore # Call async method
+
         client = _get_httpx_client()
         response = client.post(
             api_base,
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
index e26a95e8b..da54f6e1b 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
@@ -175,7 +175,7 @@ class VertexAIPartnerModels(BaseLLM):
                 client=client,
                 timeout=timeout,
                 encoding=encoding,
-                custom_llm_provider="vertex_ai_beta",
+                custom_llm_provider="vertex_ai",
             )

         except Exception as e:
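Note: the Cohere rerank change above is ordering-only. The `_is_async` early return now sits after `litellm_logging_obj.pre_call(...)`, so async rerank calls hit the pre-call logger too. A minimal sketch of the pattern, with toy stand-ins for the logger and handlers (none of these names come from the patch itself):

```python
import asyncio


class ToyLogger:
    def pre_call(self, payload: dict) -> None:
        # Runs for BOTH sync and async paths now that it precedes the branch.
        print(f"pre_call logged: {payload}")


async def async_rerank(payload: dict) -> str:
    return f"async result for {payload['query']}"


def sync_rerank(payload: dict) -> str:
    return f"sync result for {payload['query']}"


def rerank(payload: dict, logger: ToyLogger, is_async: bool):
    logger.pre_call(payload)  # log first, then dispatch
    if is_async:
        return async_rerank(payload)  # caller awaits the coroutine
    return sync_rerank(payload)


print(rerank({"query": "hi"}, ToyLogger(), is_async=False))
print(asyncio.run(rerank({"query": "hi"}, ToyLogger(), is_async=True)))
```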
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index 2ec96c7b2..d83aeaaf7 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -2350,6 +2350,26 @@
         "mode": "chat",
         "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
     },
+    "vertex_ai/meta/llama3-70b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
+    "vertex_ai/meta/llama3-8b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
     "vertex_ai/mistral-large@latest": {
         "max_tokens": 8191,
         "max_input_tokens": 128000,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 5773f9f51..1b811fe23 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -19,11 +19,11 @@ model_list:
   - model_name: o1-preview
     litellm_params:
       model: o1-preview
-
+  - model_name: rerank-english-v3.0
+    litellm_params:
+      model: cohere/rerank-english-v3.0
+      api_key: os.environ/COHERE_API_KEY
+
+
 litellm_settings:
-  cache: true
-  # cache_params:
-  #   type: "redis"
-  #   service_name: "mymaster"
-  #   sentinel_nodes:
-  #   - ["localhost", 26379]
\ No newline at end of file
+  success_callback: ["langfuse"]
\ No newline at end of file
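Note: with the `rerank-english-v3.0` entry and the `langfuse` success callback wired up, the new route can be exercised directly. A hedged sketch, assuming `litellm.rerank` accepts Cohere-style `query`/`documents`/`top_n` parameters (consistent with the `litellm/rerank_api/main.py` changes later in this diff) and that `COHERE_API_KEY` is exported:

```python
import litellm

# Parameter names mirror Cohere's rerank API; assumes COHERE_API_KEY is set.
response = litellm.rerank(
    model="cohere/rerank-english-v3.0",
    query="What is the capital of France?",
    documents=["Paris is the capital of France.", "Berlin is in Germany."],
    top_n=1,
)
# RerankResponse.results is what the Langfuse logger above records as output.
print(response.results)
```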
+ """ + try: + # Try to select one row from the view + await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""") + print("LiteLLM_VerificationTokenView Exists!") # noqa + except Exception as e: + # If an error occurs, the view does not exist, so create it + await db.execute_raw( + """ + CREATE VIEW "LiteLLM_VerificationTokenView" AS + SELECT + v.*, + t.spend AS team_spend, + t.max_budget AS team_max_budget, + t.tpm_limit AS team_tpm_limit, + t.rpm_limit AS team_rpm_limit + FROM "LiteLLM_VerificationToken" v + LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; + """ + ) + + print("LiteLLM_VerificationTokenView Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""") + print("MonthlyGlobalSpend Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS + SELECT + DATE("startTime") AS date, + SUM("spend") AS spend + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + DATE("startTime"); + """ + await db.execute_raw(query=sql_query) + + print("MonthlyGlobalSpend Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""") + print("Last30dKeysBySpend Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS + SELECT + L."api_key", + V."key_alias", + V."key_name", + SUM(L."spend") AS total_spend + FROM + "LiteLLM_SpendLogs" L + LEFT JOIN + "LiteLLM_VerificationToken" V + ON + L."api_key" = V."token" + WHERE + L."startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + L."api_key", V."key_alias", V."key_name" + ORDER BY + total_spend DESC; + """ + await db.execute_raw(query=sql_query) + + print("Last30dKeysBySpend Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""") + print("Last30dModelsBySpend Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS + SELECT + "model", + SUM("spend") AS total_spend + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + AND "model" != '' + GROUP BY + "model" + ORDER BY + total_spend DESC; + """ + await db.execute_raw(query=sql_query) + + print("Last30dModelsBySpend Created!") # noqa + try: + await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""") + print("MonthlyGlobalSpendPerKey Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS + SELECT + DATE("startTime") AS date, + SUM("spend") AS spend, + api_key as api_key + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + DATE("startTime"), + api_key; + """ + await db.execute_raw(query=sql_query) + + print("MonthlyGlobalSpendPerKey Created!") # noqa + try: + await db.query_raw( + """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1""" + ) + print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS + SELECT + DATE("startTime") AS date, + SUM("spend") AS spend, + api_key as api_key, + "user" as "user" + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + DATE("startTime"), + "user", + api_key; + """ + await db.execute_raw(query=sql_query) + + print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM 
DailyTagSpend LIMIT 1""") + print("DailyTagSpend Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE OR REPLACE VIEW DailyTagSpend AS + SELECT + jsonb_array_elements_text(request_tags) AS individual_request_tag, + DATE(s."startTime") AS spend_date, + COUNT(*) AS log_count, + SUM(spend) AS total_spend + FROM "LiteLLM_SpendLogs" s + GROUP BY individual_request_tag, DATE(s."startTime"); + """ + await db.execute_raw(query=sql_query) + + print("DailyTagSpend Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""") + print("Last30dTopEndUsersSpend Exists!") # noqa + except Exception as e: + sql_query = """ + CREATE VIEW "Last30dTopEndUsersSpend" AS + SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend + FROM "LiteLLM_SpendLogs" + WHERE end_user <> '' AND end_user <> user + AND "startTime" >= CURRENT_DATE - INTERVAL '30 days' + GROUP BY end_user + ORDER BY total_spend DESC + LIMIT 100; + """ + await db.execute_raw(query=sql_query) + + print("Last30dTopEndUsersSpend Created!") # noqa + + return + + +async def should_create_missing_views(db: _db) -> bool: + """ + Run only on first time startup. + + If SpendLogs table already has values, then don't create views on startup. + """ + + sql_query = """ + SELECT reltuples::BIGINT + FROM pg_class + WHERE oid = '"LiteLLM_SpendLogs"'::regclass; + """ + + result = await db.query_raw(query=sql_query) + + verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result)) + if ( + result + and isinstance(result, list) + and len(result) > 0 + and isinstance(result[0], dict) + and "reltuples" in result[0] + and result[0]["reltuples"] + and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1) + ): + verbose_logger.debug("Should create views") + return True + + return False diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 388f91ed2..510bec43e 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -34,6 +34,7 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.secret_managers.main import get_secret_str from .streaming_handler import chunk_processor from .success_handler import PassThroughEndpointLogging @@ -72,11 +73,11 @@ async def set_env_variables_in_header(custom_headers: dict): if isinstance( _langfuse_public_key, str ) and _langfuse_public_key.startswith("os.environ/"): - _langfuse_public_key = litellm.get_secret(_langfuse_public_key) + _langfuse_public_key = get_secret_str(_langfuse_public_key) if isinstance( _langfuse_secret_key, str ) and _langfuse_secret_key.startswith("os.environ/"): - _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key) + _langfuse_secret_key = get_secret_str(_langfuse_secret_key) headers["Authorization"] = "Basic " + b64encode( f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8") ).decode("ascii") @@ -95,9 +96,10 @@ async def set_env_variables_in_header(custom_headers: dict): "pass through endpoint - getting secret for variable name: %s", _variable_name, ) - _secret_value = litellm.get_secret(_variable_name) - new_value = value.replace(_variable_name, _secret_value) - headers[key] = new_value + _secret_value = get_secret_str(_variable_name) + if _secret_value is not None: + new_value = value.replace(_variable_name, _secret_value) + 
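Note: every view in the new module follows the same probe-then-create shape: attempt a `SELECT 1` against the view and treat any exception as "view missing". A generic sketch of that pattern, written against a hypothetical asyncpg-style connection rather than Prisma:

```python
async def ensure_view(db, view_name: str, create_sql: str) -> None:
    """Create `view_name` only if probing it fails (hypothetical helper)."""
    try:
        await db.fetch(f'SELECT 1 FROM "{view_name}" LIMIT 1')  # probe the view
        print(f"{view_name} Exists!")
    except Exception:
        await db.execute(create_sql)  # probe failed -> assume missing, create it
        print(f"{view_name} Created!")
```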
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 388f91ed2..510bec43e 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -34,6 +34,7 @@ from litellm.proxy._types import (
     UserAPIKeyAuth,
 )
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.secret_managers.main import get_secret_str

 from .streaming_handler import chunk_processor
 from .success_handler import PassThroughEndpointLogging
@@ -72,11 +73,11 @@ async def set_env_variables_in_header(custom_headers: dict):
             if isinstance(
                 _langfuse_public_key, str
             ) and _langfuse_public_key.startswith("os.environ/"):
-                _langfuse_public_key = litellm.get_secret(_langfuse_public_key)
+                _langfuse_public_key = get_secret_str(_langfuse_public_key)
             if isinstance(
                 _langfuse_secret_key, str
             ) and _langfuse_secret_key.startswith("os.environ/"):
-                _langfuse_secret_key = litellm.get_secret(_langfuse_secret_key)
+                _langfuse_secret_key = get_secret_str(_langfuse_secret_key)
             headers["Authorization"] = "Basic " + b64encode(
                 f"{_langfuse_public_key}:{_langfuse_secret_key}".encode("utf-8")
             ).decode("ascii")
@@ -95,9 +96,10 @@ async def set_env_variables_in_header(custom_headers: dict):
                     "pass through endpoint - getting secret for variable name: %s",
                     _variable_name,
                 )
-                _secret_value = litellm.get_secret(_variable_name)
-                new_value = value.replace(_variable_name, _secret_value)
-                headers[key] = new_value
+                _secret_value = get_secret_str(_variable_name)
+                if _secret_value is not None:
+                    new_value = value.replace(_variable_name, _secret_value)
+                    headers[key] = new_value
     return headers

@@ -349,7 +351,7 @@ async def pass_through_request(
     ### CALL HOOKS ### - modify incoming data / reject request before calling the model
     _parsed_body = await proxy_logging_obj.pre_call_hook(
         user_api_key_dict=user_api_key_dict,
-        data=_parsed_body or {},
+        data=_parsed_body,
         call_type="pass_through_endpoint",
     )

@@ -576,7 +578,7 @@ def create_pass_through_route(
             adapter_id = str(uuid.uuid4())
             litellm.adapters = [{"id": adapter_id, "adapter": adapter}]

-            async def endpoint_func(
+            async def endpoint_func(  # type: ignore
                 request: Request,
                 fastapi_response: Response,
                 user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
@@ -647,9 +649,9 @@ async def initialize_pass_through_endpoints(pass_through_endpoints: list):
         verbose_proxy_logger.debug(
             "adding pass through endpoint: %s, dependencies: %s", _path, _dependencies
         )
-        app.add_api_route(
+        app.add_api_route(  # type: ignore
             path=_path,
-            endpoint=create_pass_through_route(
+            endpoint=create_pass_through_route(  # type: ignore
                 _path, _target, _custom_headers, _forward_headers, _dependencies
             ),
             methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 8c6bc618c..2cb76eb92 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -2229,9 +2229,12 @@ class ProxyConfig:
                 and _general_settings.get("alerting", None) is not None
                 and isinstance(_general_settings["alerting"], list)
             ):
-                for alert in _general_settings["alerting"]:
-                    if alert not in general_settings["alerting"]:
-                        general_settings["alerting"].append(alert)
+                verbose_proxy_logger.debug(
+                    "Overriding Default 'alerting' values with db 'alerting' values."
+                )
+                general_settings["alerting"] = _general_settings[
+                    "alerting"
+                ]  # override yaml values with db
                 proxy_logging_obj.alerting = general_settings["alerting"]
                 proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
                     "alerting"
@@ -7779,10 +7782,13 @@ async def alerting_settings(
     if db_general_settings is not None and db_general_settings.param_value is not None:
         db_general_settings_dict = dict(db_general_settings.param_value)
         alerting_args_dict: dict = db_general_settings_dict.get("alerting_args", {})  # type: ignore
+        alerting_values: Optional[list] = db_general_settings_dict.get("alerting")  # type: ignore
     else:
         alerting_args_dict = {}
+        alerting_values = None

     allowed_args = {
+        "slack_alerting": {"type": "Boolean"},
         "daily_report_frequency": {"type": "Integer"},
         "report_check_interval": {"type": "Integer"},
         "budget_alert_ttl": {"type": "Integer"},
@@ -7798,6 +7804,25 @@ async def alerting_settings(

     return_val = []

+    is_slack_enabled = False
+
+    if general_settings.get("alerting") and isinstance(
+        general_settings["alerting"], list
+    ):
+        if "slack" in general_settings["alerting"]:
+            is_slack_enabled = True
+
+    _response_obj = ConfigList(
+        field_name="slack_alerting",
+        field_type=allowed_args["slack_alerting"]["type"],
+        field_description="Enable slack alerting for monitoring proxy in production: llm outages, budgets, spend tracking failures.",
+        field_value=is_slack_enabled,
+        stored_in_db=True if alerting_values is not None else False,
+        field_default_value=None,
+        premium_field=False,
+    )
+    return_val.append(_response_obj)
+
     for field_name, field_info in SlackAlertingArgs.model_fields.items():
         if field_name in allowed_args:
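Note: the `ProxyConfig` change flips alerting from merge semantics to override semantics, so DB values now replace the YAML list instead of being appended to it. Under the old merge, an empty DB list could never disable a YAML-enabled alert. A toy illustration of the difference:

```python
yaml_alerting = ["slack"]
db_alerting: list = []  # e.g. the UI just disabled slack alerting

# Old behavior: append db values missing from yaml -> slack stays enabled.
merged = yaml_alerting + [a for a in db_alerting if a not in yaml_alerting]

# New behavior: db wins outright -> slack is actually disabled.
overridden = db_alerting

print(merged)      # ['slack']
print(overridden)  # []
```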
diff --git a/litellm/proxy/tests/test_vertex_sdk_forward_headers.py b/litellm/proxy/tests/test_vertex_sdk_forward_headers.py
index 7aa87905a..a5e59d5a1 100644
--- a/litellm/proxy/tests/test_vertex_sdk_forward_headers.py
+++ b/litellm/proxy/tests/test_vertex_sdk_forward_headers.py
@@ -1,21 +1,52 @@
-import vertexai
-from vertexai.preview.generative_models import GenerativeModel
+# import datetime

-LITE_LLM_ENDPOINT = "http://localhost:4000"
+# import vertexai
+# from vertexai.generative_models import Part
+# from vertexai.preview import caching
+# from vertexai.preview.generative_models import GenerativeModel

-vertexai.init(
-    project="adroit-crow-413218",
-    location="us-central1",
-    api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
-    api_transport="rest",
-)
+# LITE_LLM_ENDPOINT = "http://localhost:4000"

-model = GenerativeModel(model_name="gemini-1.5-flash-001")
-response = model.generate_content(
-    "hi tell me a joke and a very long story", stream=True
-)
+# vertexai.init(
+#     project="adroit-crow-413218",
+#     location="us-central1",
+#     api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
+#     api_transport="rest",
+# )

-print("response", response)
+# # model = GenerativeModel(model_name="gemini-1.5-flash-001")
+# # response = model.generate_content(
+# #     "hi tell me a joke and a very long story", stream=True
+# # )

-for chunk in response:
-    print(chunk)
+# # print("response", response)
+
+# # for chunk in response:
+# #     print(chunk)
+
+
+# system_instruction = """
+# You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+# Now look at these research papers, and answer the following questions.
+# """
+
+# contents = [
+#     Part.from_uri(
+#         "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+#         mime_type="application/pdf",
+#     ),
+#     Part.from_uri(
+#         "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+#         mime_type="application/pdf",
+#     ),
+# ]
+
+# cached_content = caching.CachedContent.create(
+#     model_name="gemini-1.5-pro-001",
+#     system_instruction=system_instruction,
+#     contents=contents,
+#     ttl=datetime.timedelta(minutes=60),
+#     # display_name="example-cache",
+# )
+
+# print(cached_content.name)
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 389cc3aa2..c2948d41e 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -14,7 +14,7 @@ from datetime import datetime, timedelta
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from functools import wraps
-from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, overload

 import backoff
 import httpx
@@ -51,6 +51,10 @@ from litellm.proxy._types import (
     SpendLogsPayload,
     UserAPIKeyAuth,
 )
+from litellm.proxy.db.create_views import (
+    create_missing_views,
+    should_create_missing_views,
+)
 from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
 from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
 from litellm.proxy.hooks.parallel_request_limiter import (
@@ -365,6 +369,25 @@ class ProxyLogging:
         return data

     # The actual implementation of the function
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: None,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> None:
+        pass
+
+    @overload
     async def pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
@@ -380,6 +403,23 @@ class ProxyLogging:
             "rerank",
         ],
     ) -> dict:
+        pass
+
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: Optional[dict],
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Optional[dict]:
         """
         Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing Request body.
@@ -394,6 +434,9 @@ class ProxyLogging:
             self.slack_alerting_instance.response_taking_too_long(request_data=data)
         )

+        if data is None:
+            return None
+
         try:
             for callback in litellm.callbacks:
                 _callback = None
@@ -418,7 +461,7 @@ class ProxyLogging:
                     response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
-                        data=data,
+                        data=data,  # type: ignore
                         call_type=call_type,
                     )
                     if response is not None:
@@ -434,7 +477,7 @@ class ProxyLogging:
                     response = await _callback.async_pre_call_hook(
                         user_api_key_dict=user_api_key_dict,
                         cache=self.call_details["user_api_key_cache"],
-                        data=data,
+                        data=data,  # type: ignore
                         call_type=call_type,
                     )
                     if response is not None:
@@ -1021,20 +1064,24 @@ class PrismaClient:
                         "LiteLLM_VerificationTokenView Created in DB!"
                     )
                 else:
-                    # don't block execution if these views are missing
-                    # Convert lists to sets for efficient difference calculation
-                    ret_view_names_set = (
-                        set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
-                    )
-                    expected_views_set = set(expected_views)
-                    # Find missing views
-                    missing_views = expected_views_set - ret_view_names_set
-
-                    verbose_proxy_logger.warning(
-                        "\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
-                            missing_views
+                    should_create_views = await should_create_missing_views(db=self.db)
+                    if should_create_views:
+                        await create_missing_views(db=self.db)
+                    else:
+                        # don't block execution if these views are missing
+                        # Convert lists to sets for efficient difference calculation
+                        ret_view_names_set = (
+                            set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
+                        )
+                        expected_views_set = set(expected_views)
+                        # Find missing views
+                        missing_views = expected_views_set - ret_view_names_set
+
+                        verbose_proxy_logger.warning(
+                            "\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
+                                missing_views
+                            )
                         )
-                    )
         except Exception as e:
             raise
diff --git a/litellm/rerank_api/main.py b/litellm/rerank_api/main.py
index d58e3c34f..1498e8b76 100644
--- a/litellm/rerank_api/main.py
+++ b/litellm/rerank_api/main.py
@@ -103,10 +103,20 @@ def rerank(
             )
         )

+    model_parameters = [
+        "top_n",
+        "rank_fields",
+        "return_documents",
+        "max_chunks_per_doc",
+    ]
+    model_params_dict = {}
+    for k, v in optional_params.model_fields.items():
+        if k in model_parameters:
+            model_params_dict[k] = v
     litellm_logging_obj.update_environment_variables(
         model=model,
         user=user,
-        optional_params=optional_params.model_dump(),
+        optional_params=model_params_dict,
         litellm_params={
             "litellm_call_id": litellm_call_id,
             "proxy_server_request": proxy_server_request,
@@ -114,6 +124,7 @@ def rerank(
             "metadata": metadata,
             "preset_cache_key": None,
             "stream_response": {},
+            **optional_params.model_dump(exclude_unset=True),
         },
         custom_llm_provider=_custom_llm_provider,
     )
diff --git a/litellm/router.py b/litellm/router.py
index 780eeb3e7..0159a0b17 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -3532,7 +3532,8 @@ class Router:
             elif isinstance(id, int):
                 id = str(id)

-            total_tokens = completion_response["usage"].get("total_tokens", 0)
+            _usage_obj = completion_response.get("usage")
+            total_tokens = _usage_obj.get("total_tokens", 0) if _usage_obj else 0

             # ------------
             # Setup values
diff --git a/litellm/secret_managers/main.py b/litellm/secret_managers/main.py
index 13a61c665..e98140768 100644
--- a/litellm/secret_managers/main.py
+++ b/litellm/secret_managers/main.py
@@ -50,6 +50,20 @@ def str_to_bool(value: str) -> Optional[bool]:
         return None


+def get_secret_str(
+    secret_name: str,
+    default_value: Optional[Union[str, bool]] = None,
+) -> Optional[str]:
+    """
+    Guarantees response from 'get_secret' is either string or none. Used for fixing linting errors.
+    """
+    value = get_secret(secret_name=secret_name, default_value=default_value)
+    if value is not None and not isinstance(value, str):
+        return None
+
+    return value
+
+
 def get_secret(
     secret_name: str,
     default_value: Optional[Union[str, bool]] = None,
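Note: `get_secret_str` exists purely to narrow `get_secret`'s `Optional[Union[str, bool]]` return down to `Optional[str]` for type checkers; callers like `set_env_variables_in_header` above rely on that. A self-contained sketch of the narrowing, with a stub standing in for the real `get_secret`:

```python
import os
from typing import Optional, Union


def get_secret(secret_name: str) -> Optional[Union[str, bool]]:
    # Stub: the real get_secret can also return bools for flag-like secrets.
    return os.getenv(secret_name)


def get_secret_str(secret_name: str) -> Optional[str]:
    value = get_secret(secret_name)
    if value is not None and not isinstance(value, str):
        return None  # drop non-str values so callers see Optional[str]
    return value


# Callers can now annotate without casts or type: ignore comments.
key: Optional[str] = get_secret_str("LANGFUSE_PUBLIC_KEY")
```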
diff --git a/litellm/tests/test_completion_cost.py b/litellm/tests/test_completion_cost.py
index 82257fad8..d2ffaa4c9 100644
--- a/litellm/tests/test_completion_cost.py
+++ b/litellm/tests/test_completion_cost.py
@@ -1268,3 +1268,41 @@ def test_completion_cost_fireworks_ai():
     print(resp)

     cost = completion_cost(completion_response=resp)
+
+
+def test_completion_cost_vertex_llama3():
+    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
+    litellm.model_cost = litellm.get_model_cost_map(url="")
+
+    from litellm.utils import Choices, Message, ModelResponse, Usage
+
+    response = ModelResponse(
+        id="2024-09-19|14:52:01.823070-07|3.10.13.64|-333502972",
+        choices=[
+            Choices(
+                finish_reason="stop",
+                index=0,
+                message=Message(
+                    content="My name is Litellm Bot, and I'm here to help you with any questions or tasks you may have. As for the weather, I'd be happy to provide you with the current conditions and forecast for your location. However, I'm a large language model, I don't have real-time access to your location, so I'll need you to tell me where you are or provide me with a specific location you're interested in knowing the weather for.\\n\\nOnce you provide me with that information, I can give you the current weather conditions, including temperature, humidity, wind speed, and more, as well as a forecast for the next few days. Just let me know how I can assist you!",
+                    role="assistant",
+                    tool_calls=None,
+                    function_call=None,
+                ),
+            )
+        ],
+        created=1726782721,
+        model="vertex_ai/meta/llama3-405b-instruct-maas",
+        object="chat.completion",
+        system_fingerprint="",
+        usage=Usage(
+            completion_tokens=152,
+            prompt_tokens=27,
+            total_tokens=179,
+            completion_tokens_details=None,
+        ),
+    )
+
+    model = "vertex_ai/meta/llama3-8b-instruct-maas"
+    cost = completion_cost(model=model, completion_response=response)
+
+    assert cost == 0
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index d606ffeef..54a4a920a 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -118,6 +118,8 @@ class CallTypes(Enum):
     transcription = "transcription"
     aspeech = "aspeech"
     speech = "speech"
+    rerank = "rerank"
+    arerank = "arerank"


 class PassthroughCallTypes(Enum):
diff --git a/litellm/utils.py b/litellm/utils.py
index 7d6d5223c..a66a7ff70 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -550,6 +550,10 @@ def function_setup(
             or call_type == CallTypes.text_completion.value
         ):
             messages = args[0] if len(args) > 0 else kwargs["prompt"]
+        elif (
+            call_type == CallTypes.rerank.value or call_type == CallTypes.arerank.value
+        ):
+            messages = kwargs.get("query")
         elif (
             call_type == CallTypes.atranscription.value
             or call_type == CallTypes.transcription.value
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 2ec96c7b2..d83aeaaf7 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -2350,6 +2350,26 @@
         "mode": "chat",
         "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
     },
+    "vertex_ai/meta/llama3-70b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
+    "vertex_ai/meta/llama3-8b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
     "vertex_ai/mistral-large@latest": {
         "max_tokens": 8191,
         "max_input_tokens": 128000,
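Note: both pricing files list the new llama3 MaaS entries with zero per-token cost, which is what lets `test_completion_cost_vertex_llama3` assert a cost of 0. A toy recomputation of that assertion (the rates are copied from the JSON above; the arithmetic is the standard tokens-times-rate formula):

```python
entry = {"input_cost_per_token": 0.0, "output_cost_per_token": 0.0}

prompt_tokens, completion_tokens = 27, 152  # usage from the test above
cost = (
    prompt_tokens * entry["input_cost_per_token"]
    + completion_tokens * entry["output_cost_per_token"]
)
assert cost == 0
```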
diff --git a/tests/pass_through_tests/test_vertex_ai.py b/tests/pass_through_tests/test_vertex_ai.py
index 40998dc2f..d0c5088d9 100644
--- a/tests/pass_through_tests/test_vertex_ai.py
+++ b/tests/pass_through_tests/test_vertex_ai.py
@@ -48,6 +48,7 @@ def load_vertex_ai_credentials():
     service_account_key_data["private_key_id"] = private_key_id
     service_account_key_data["private_key"] = private_key

+    # print(f"service_account_key_data: {service_account_key_data}")
     # Create a temporary file
     with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
         # Write the updated content to the temporary files
@@ -151,3 +152,46 @@ async def test_basic_vertex_ai_pass_through_streaming_with_spendlog():
     )

     pass
+
+
+@pytest.mark.asyncio
+async def test_vertex_ai_pass_through_endpoint_context_caching():
+    import vertexai
+    from vertexai.generative_models import Part
+    from vertexai.preview import caching
+    import datetime
+
+    load_vertex_ai_credentials()
+
+    vertexai.init(
+        project="adroit-crow-413218",
+        location="us-central1",
+        api_endpoint=f"{LITE_LLM_ENDPOINT}/vertex-ai",
+        api_transport="rest",
+    )
+
+    system_instruction = """
+    You are an expert researcher. You always stick to the facts in the sources provided, and never make up new facts.
+    Now look at these research papers, and answer the following questions.
+    """
+
+    contents = [
+        Part.from_uri(
+            "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
+            mime_type="application/pdf",
+        ),
+        Part.from_uri(
+            "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
+            mime_type="application/pdf",
+        ),
+    ]
+
+    cached_content = caching.CachedContent.create(
+        model_name="gemini-1.5-pro-001",
+        system_instruction=system_instruction,
+        contents=contents,
+        ttl=datetime.timedelta(minutes=60),
+        # display_name="example-cache",
+    )
+
+    print(cached_content.name)
diff --git a/ui/litellm-dashboard/src/components/alerting/alerting_settings.tsx b/ui/litellm-dashboard/src/components/alerting/alerting_settings.tsx
index 2941f133c..1d0ec677d 100644
--- a/ui/litellm-dashboard/src/components/alerting/alerting_settings.tsx
+++ b/ui/litellm-dashboard/src/components/alerting/alerting_settings.tsx
@@ -41,7 +41,6 @@ const AlertingSettings: React.FC = ({
     alertingSettingsItem[]
   >([]);

-  console.log("INSIDE ALERTING SETTINGS");
   useEffect(() => {
     // get values
     if (!accessToken) {
@@ -59,6 +58,8 @@ const AlertingSettings: React.FC = ({
         ? { ...setting, field_value: newValue }
         : setting
     );
+
+    console.log(`updatedSettings: ${JSON.stringify(updatedSettings)}`)
     setAlertingSettings(updatedSettings);
   };

@@ -67,6 +68,7 @@ const AlertingSettings: React.FC = ({
       return;
     }

+    console.log(`formValues: ${formValues}`)
    let fieldValue = formValues;

     if (fieldValue == null || fieldValue == undefined) {
@@ -74,14 +76,25 @@ const AlertingSettings: React.FC = ({
     }

     const initialFormValues: Record<string, any> = {};
+
     alertingSettings.forEach((setting) => {
       initialFormValues[setting.field_name] = setting.field_value;
     });

     // Merge initialFormValues with actual formValues
     const mergedFormValues = { ...formValues, ...initialFormValues };
+    console.log(`mergedFormValues: ${JSON.stringify(mergedFormValues)}`)
+    const { slack_alerting, ...alertingArgs } = mergedFormValues;
+    console.log(`slack_alerting: ${slack_alerting}, alertingArgs: ${JSON.stringify(alertingArgs)}`)
     try {
-      updateConfigFieldSetting(accessToken, "alerting_args", mergedFormValues);
+      updateConfigFieldSetting(accessToken, "alerting_args", alertingArgs);
+      if (typeof slack_alerting === "boolean") {
+        if (slack_alerting == true) {
+          updateConfigFieldSetting(accessToken, "alerting", ["slack"]);
+        } else {
+          updateConfigFieldSetting(accessToken, "alerting", []);
+        }
+      }
       // update value in state
       message.success("Wait 10s for proxy to update.");
     } catch (error) {
@@ -107,7 +120,6 @@ const AlertingSettings: React.FC = ({
             }
           : setting
       );
-      console.log("INSIDE HANDLE RESET FIELD");
       setAlertingSettings(updatedSettings);
     } catch (error) {
       // do something
diff --git a/ui/litellm-dashboard/src/components/alerting/dynamic_form.tsx b/ui/litellm-dashboard/src/components/alerting/dynamic_form.tsx
index 8804f623d..673f63b3b 100644
--- a/ui/litellm-dashboard/src/components/alerting/dynamic_form.tsx
+++ b/ui/litellm-dashboard/src/components/alerting/dynamic_form.tsx
@@ -1,7 +1,7 @@
 import React from "react";
 import { Form, Input, InputNumber, Row, Col, Button as Button2 } from "antd";
 import { TrashIcon, CheckCircleIcon } from "@heroicons/react/outline";
-import { Button, Badge, Icon, Text, TableRow, TableCell } from "@tremor/react";
+import { Button, Badge, Icon, Text, TableRow, TableCell, Switch } from "@tremor/react";
 import Paragraph from "antd/es/typography/Paragraph";

 interface AlertingSetting {
   field_name: string;
@@ -30,10 +30,15 @@ const DynamicForm: React.FC = ({
   const [form] = Form.useForm();

   const onFinish = () => {
+    console.log(`INSIDE ONFINISH`)
     const formData = form.getFieldsValue();
-    const isEmpty = Object.values(formData).some(
-      (value) => value === "" || value === null || value === undefined
-    );
+    const isEmpty = Object.entries(formData).every(([key, value]) => {
+      if (typeof value === 'boolean') {
+        return false; // Boolean values are never considered empty
+      }
+      return value === '' || value === null || value === undefined;
+    });
+    console.log(`formData: ${JSON.stringify(formData)}, isEmpty: ${isEmpty}`)
     if (!isEmpty) {
       handleSubmit(formData);
     } else {
@@ -68,6 +73,11 @@ const DynamicForm: React.FC = ({
                   value={value.field_value}
                   onChange={(e) => handleInputChange(value.field_name, e)}
                 />
+              ) : value.field_type === "Boolean" ? (
+                <Switch
+                  checked={value.field_value}
+                  onChange={(checked) => handleInputChange(value.field_name, checked)}
+                />
               ) : (
@@ -93,7 +103,7 @@ const DynamicForm: React.FC = ({
               )
             ) : (
-
+
               {value.field_type === "Integer" ? (
@@ -103,7 +113,16 @@ const DynamicForm: React.FC = ({
                   onChange={(e) => handleInputChange(value.field_name, e)}
                   className="p-0"
                 />
-                ) : (
+                ) : value.field_type === "Boolean" ? (
+                  <Switch
+                    checked={value.field_value}
+                    onChange={(checked) => {
+                      handleInputChange(value.field_name, checked);
+                      form.setFieldsValue({ [value.field_name]: checked });
+                    }}
+
+                  />
+                ) :(
                   <TextInput
                     onChange={(e) => handleInputChange(value.field_name, e)}
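Note: the revised `isEmpty` check in `dynamic_form.tsx` is what lets a form containing only the new Boolean `slack_alerting` switch submit: booleans are never counted as empty, and `every` (rather than `some`) only blocks submission when all fields are empty. A Python analogue of the logic, for illustration only:

```python
def is_empty(form_data: dict) -> bool:
    # Mirrors the TSX: booleans never count as empty; block only if ALL empty.
    return all(
        False if isinstance(value, bool) else value in ("", None)
        for value in form_data.values()
    )


print(is_empty({"slack_alerting": False}))       # False -> form submits
print(is_empty({"daily_report_frequency": ""}))  # True  -> submission blocked
```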