From 234185ec13675aaa89aa58a07c28f4ae3c0d1207 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Tue, 17 Sep 2024 08:05:52 -0700 Subject: [PATCH] LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) (#5731) * LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) * coverage (#5713) Signed-off-by: dbczumar * Move (#5714) Signed-off-by: dbczumar * fix(litellm_logging.py): fix logging client re-init (#5710) Fixes https://github.com/BerriAI/litellm/issues/5695 * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config Fixes https://github.com/BerriAI/litellm/issues/5682 * feat(o1_handler.py): fake streaming for openai o1 models Fixes https://github.com/BerriAI/litellm/issues/5694 * docs: deprecated traceloop integration in favor of native otel (#5249) * fix: fix linting errors * fix: fix linting errors * fix(main.py): fix o1 import --------- Signed-off-by: dbczumar Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730) * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it * fix(custom_logger.py): reset calltype * fix: fix linting errors * fix: fix linting error * fix: fix import * test(test_databricks.py): fix databricks tests --------- Signed-off-by: dbczumar Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com> Co-authored-by: Nir Gazit --- db_scripts/create_views.py | 33 +- .../opentelemetry_integration.md | 78 +++ .../observability/traceloop_integration.md | 36 -- docs/my-website/docs/proxy/logging.md | 92 ++-- docs/my-website/sidebars.js | 3 +- litellm/__init__.py | 4 +- .../SlackAlerting/slack_alerting.py | 3 + litellm/integrations/custom_guardrail.py | 9 +- litellm/integrations/traceloop.py | 5 + litellm/litellm_core_utils/litellm_logging.py | 211 ++++---- .../OpenAI/{ => chat}/gpt_transformation.py | 0 litellm/llms/OpenAI/chat/o1_handler.py | 95 ++++ .../OpenAI/{ => chat}/o1_transformation.py | 0 litellm/llms/databricks/chat.py | 156 +----- litellm/llms/databricks/exceptions.py | 12 + litellm/llms/databricks/streaming_utils.py | 145 +++++ litellm/main.py | 61 ++- litellm/proxy/_new_secret_config.yaml | 33 +- litellm/proxy/auth/user_api_key_auth.py | 1 - litellm/proxy/custom_callbacks.py | 16 +- .../guardrail_hooks/bedrock_guardrails.py | 3 +- .../guardrails/guardrail_hooks/lakera_ai.py | 1 + .../guardrails/guardrail_hooks/presidio.py | 20 +- litellm/proxy/guardrails/init_guardrails.py | 32 +- .../proxy/hooks/prompt_injection_detection.py | 13 +- .../internal_user_endpoints.py | 66 +-- litellm/proxy/proxy_server.py | 1 - .../spend_management_endpoints.py | 118 +++- litellm/proxy/utils.py | 42 +- litellm/tests/test_streaming.py | 52 +- litellm/types/guardrails.py | 3 +- litellm/utils.py | 42 +- tests/llm_translation/Readme.md | 1 + tests/llm_translation/test_databricks.py | 502 ++++++++++++++++++ 34 files changed, 1387 insertions(+), 502 deletions(-) create mode 100644 docs/my-website/docs/observability/opentelemetry_integration.md delete mode 100644 docs/my-website/docs/observability/traceloop_integration.md rename litellm/llms/OpenAI/{ => chat}/gpt_transformation.py (100%) create mode 100644 litellm/llms/OpenAI/chat/o1_handler.py rename litellm/llms/OpenAI/{ => chat}/o1_transformation.py (100%) create 
mode 100644 litellm/llms/databricks/exceptions.py create mode 100644 litellm/llms/databricks/streaming_utils.py create mode 100644 tests/llm_translation/Readme.md create mode 100644 tests/llm_translation/test_databricks.py diff --git a/db_scripts/create_views.py b/db_scripts/create_views.py index 510bd67c3..cbf00605f 100644 --- a/db_scripts/create_views.py +++ b/db_scripts/create_views.py @@ -6,10 +6,14 @@ import asyncio import os # Enter your DATABASE_URL here -os.environ["DATABASE_URL"] = "postgresql://xxxxxxx" + from prisma import Prisma -db = Prisma() +db = Prisma( + http={ + "timeout": 60000, + },) async def check_view_exists(): @@ -47,25 +51,22 @@ async def check_view_exists(): print("LiteLLM_VerificationTokenView Created!") # noqa - try: - await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""") - print("MonthlyGlobalSpend Exists!") # noqa - except Exception as e: - sql_query = """ - CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS + sql_query = """ + CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS SELECT - DATE("startTime") AS date, - SUM("spend") AS spend + DATE_TRUNC('day', "startTime") AS date, + SUM("spend") AS spend FROM - "LiteLLM_SpendLogs" + "LiteLLM_SpendLogs" WHERE - "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + "startTime" >= CURRENT_DATE - INTERVAL '30 days' GROUP BY - DATE("startTime"); - """ - await db.execute_raw(query=sql_query) + DATE_TRUNC('day', "startTime"); + """ + # Execute the queries + await db.execute_raw(query=sql_query) - print("MonthlyGlobalSpend Created!") # noqa + print("MonthlyGlobalSpend Created!") # noqa try: await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""") diff --git a/docs/my-website/docs/observability/opentelemetry_integration.md b/docs/my-website/docs/observability/opentelemetry_integration.md new file mode 100644 index 000000000..3a27ffc39 --- /dev/null +++ b/docs/my-website/docs/observability/opentelemetry_integration.md @@ -0,0 +1,78 @@ +import Image from '@theme/IdealImage'; +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# OpenTelemetry - Tracing LLMs with any observability tool + +OpenTelemetry is a CNCF standard for observability. It connects to any observability tool, such as Jaeger, Zipkin, Datadog, New Relic, Traceloop and others. + + + +## Getting Started + +Install the OpenTelemetry SDK: + +``` +pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp +``` + +Set the environment variables (different providers may require different variables): + + + + + + +```shell +OTEL_EXPORTER="otlp_http" +OTEL_ENDPOINT="https://api.traceloop.com" +OTEL_HEADERS="Authorization=Bearer%20" +``` + + + + + +```shell +OTEL_EXPORTER="otlp_http" +OTEL_ENDPOINT="http://0.0.0.0:4317" +``` + + + + + +```shell +OTEL_EXPORTER="otlp_grpc" +OTEL_ENDPOINT="http://0.0.0.0:4317" +``` + + + + + +Use just one line of code to instantly log your LLM responses **across all providers** with OpenTelemetry: + +```python +litellm.callbacks = ["otel"] +``` + +## Redacting Messages, Response Content from OpenTelemetry Logging + +### Redact Messages and Responses from all OpenTelemetry Logging + +Set `litellm.turn_off_message_logging=True`. This will prevent the messages and responses from being logged to OpenTelemetry, but request metadata will still be logged. + +### Redact Messages and Responses from specific OpenTelemetry Logging + +In the metadata typically passed for text completion or embedding calls, you can set specific keys to mask the messages and responses for this call.
+ +Setting `mask_input` to `True` will mask the input from being logged for this call. + +Setting `mask_output` to `True` will mask the output from being logged for this call. + +Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message. + +## Support + +For any question or issue with the integration, you can reach out to the OpenLLMetry maintainers on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com). diff --git a/docs/my-website/docs/observability/traceloop_integration.md b/docs/my-website/docs/observability/traceloop_integration.md deleted file mode 100644 index 6f02aa229..000000000 --- a/docs/my-website/docs/observability/traceloop_integration.md +++ /dev/null @@ -1,36 +0,0 @@ -import Image from '@theme/IdealImage'; - -# Traceloop (OpenLLMetry) - Tracing LLMs with OpenTelemetry - -[Traceloop](https://traceloop.com) is a platform for monitoring and debugging the quality of your LLM outputs. -It provides you with a way to track the performance of your LLM application; rollout changes with confidence; and debug issues in production. -It is based on [OpenTelemetry](https://opentelemetry.io), so it can provide full visibility to your LLM requests, as well vector DB usage, and other infra in your stack. - - - -## Getting Started - -Install the Traceloop SDK: - -``` -pip install traceloop-sdk -``` - -Use just 2 lines of code, to instantly log your LLM responses with OpenTelemetry: - -```python -Traceloop.init(app_name=, disable_batch=True) -litellm.success_callback = ["traceloop"] -``` - -Make sure to properly set a destination to your traces. See [OpenLLMetry docs](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for options. - -To get better visualizations on how your code behaves, you may want to annotate specific parts of your LLM chain. See [Traceloop docs on decorators](https://traceloop.com/docs/python-sdk/decorators) for more information. - -## Exporting traces to other systems (e.g. Datadog, New Relic, and others) - -Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information. - -## Support - -For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
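The new `opentelemetry_integration.md` page added above describes the SDK-side setup in prose; a minimal, illustrative sketch of how those pieces fit together is shown below. It assumes the `OTEL_*` environment variables (and a provider API key) are already exported, and the model name plus the per-call `mask_input`/`mask_output` metadata keys are used purely as examples of the redaction options described in that page.

```python
import litellm

# Register the OpenTelemetry callback once; LiteLLM then emits traces for every
# provider it routes to, using the OTEL_EXPORTER / OTEL_ENDPOINT settings
# already exported in the environment.
litellm.callbacks = ["otel"]

# A normal completion call, traced end to end.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)

# Per-call redaction, per the "Redact Messages and Responses from specific
# OpenTelemetry Logging" section: mask the input and output of this request only.
redacted = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this prompt contains PII"}],
    metadata={"mask_input": True, "mask_output": True},
)
```

Global redaction (`litellm.turn_off_message_logging = True`) remains available when every request should be masked.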
diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md index f7b650f7a..d20510ac7 100644 --- a/docs/my-website/docs/proxy/logging.md +++ b/docs/my-website/docs/proxy/logging.md @@ -600,6 +600,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ + + +#### Quick Start - Log to Traceloop + +**Step 1:** +Add the following to your env + +```shell +OTEL_EXPORTER="otlp_http" +OTEL_ENDPOINT="https://api.traceloop.com" +OTEL_HEADERS="Authorization=Bearer%20" +``` + +**Step 2:** Add `otel` as a callbacks + +```shell +litellm_settings: + callbacks: ["otel"] +``` + +**Step 3**: Start the proxy, make a test request + +Start proxy + +```shell +litellm --config config.yaml --detailed_debug +``` + +Test Request + +```shell +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Content-Type: application/json' \ + --data ' { + "model": "gpt-3.5-turbo", + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ] + }' +``` + + + #### Quick Start - Log to OTEL Collector @@ -694,52 +740,6 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ - - -#### Quick Start - Log to Traceloop - -**Step 1:** Install the `traceloop-sdk` SDK - -```shell -pip install traceloop-sdk==0.21.2 -``` - -**Step 2:** Add `traceloop` as a success_callback - -```shell -litellm_settings: - success_callback: ["traceloop"] - -environment_variables: - TRACELOOP_API_KEY: "XXXXX" -``` - -**Step 3**: Start the proxy, make a test request - -Start proxy - -```shell -litellm --config config.yaml --detailed_debug -``` - -Test Request - -```shell -curl --location 'http://0.0.0.0:4000/chat/completions' \ - --header 'Content-Type: application/json' \ - --data ' { - "model": "gpt-3.5-turbo", - "messages": [ - { - "role": "user", - "content": "what llm are you" - } - ] - }' -``` - - - ** 🎉 Expect to see this trace logged in your OTEL collector** diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 7e2a2050b..6dafb5478 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -255,6 +255,7 @@ const sidebars = { type: "category", label: "Logging & Observability", items: [ + "observability/opentelemetry_integration", "observability/langfuse_integration", "observability/logfire_integration", "observability/gcs_bucket_integration", @@ -271,7 +272,7 @@ const sidebars = { "observability/openmeter", "observability/promptlayer_integration", "observability/wandb_integration", - "observability/traceloop_integration", + "observability/slack_integration", "observability/athina_integration", "observability/lunary_integration", "observability/greenscale_integration", diff --git a/litellm/__init__.py b/litellm/__init__.py index 8c11ff1c4..222466162 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -948,10 +948,10 @@ from .llms.OpenAI.openai import ( AzureAIStudioConfig, ) from .llms.mistral.mistral_chat_transformation import MistralConfig -from .llms.OpenAI.o1_transformation import ( +from .llms.OpenAI.chat.o1_transformation import ( OpenAIO1Config, ) -from .llms.OpenAI.gpt_transformation import ( +from .llms.OpenAI.chat.gpt_transformation import ( OpenAIGPTConfig, ) from .llms.nvidia_nim import NvidiaNimConfig diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py index fa7d4bc90..bc1087149 100644 --- a/litellm/integrations/SlackAlerting/slack_alerting.py +++ b/litellm/integrations/SlackAlerting/slack_alerting.py @@ -1616,6 +1616,9 @@ Model Info: :param 
time_range: A string specifying the time range, e.g., "1d", "7d", "30d" """ + if self.alerting is None or "spend_reports" not in self.alert_types: + return + try: from litellm.proxy.spend_tracking.spend_management_endpoints import ( _get_spend_report_for_time_range, diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index 25512716c..39c8f2b1e 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -11,7 +11,7 @@ class CustomGuardrail(CustomLogger): self, guardrail_name: Optional[str] = None, event_hook: Optional[GuardrailEventHooks] = None, - **kwargs + **kwargs, ): self.guardrail_name = guardrail_name self.event_hook: Optional[GuardrailEventHooks] = event_hook @@ -28,10 +28,13 @@ class CustomGuardrail(CustomLogger): requested_guardrails, ) - if self.guardrail_name not in requested_guardrails: + if ( + self.guardrail_name not in requested_guardrails + and event_type.value != "logging_only" + ): return False - if self.event_hook != event_type: + if self.event_hook != event_type.value: return False return True diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py index e1c419c6f..d2168caab 100644 --- a/litellm/integrations/traceloop.py +++ b/litellm/integrations/traceloop.py @@ -4,6 +4,11 @@ import litellm class TraceloopLogger: + """ + WARNING: DEPRECATED + Use the OpenTelemetry standard integration instead + """ + def __init__(self): try: from traceloop.sdk.tracing.tracing import TracerWrapper diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 1a624c5f8..91e9274e8 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase from ..integrations.traceloop import TraceloopLogger from ..integrations.weights_biases import WeightsBiasesLogger +try: + from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import ( + GenericAPILogger, + ) +except Exception as e: + verbose_logger.debug(f"Exception import enterprise features {str(e)}") + _in_memory_loggers: List[Any] = [] ### GLOBAL VARIABLES ### @@ -145,7 +152,41 @@ class ServiceTraceIDCache: return None +import hashlib + + +class DynamicLoggingCache: + """ + Prevent memory leaks caused by initializing new logging clients on each request. + + Relevant Issue: https://github.com/BerriAI/litellm/issues/5695 + """ + + def __init__(self) -> None: + self.cache = InMemoryCache() + + def get_cache_key(self, args: dict) -> str: + args_str = json.dumps(args, sort_keys=True) + cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest() + return cache_key + + def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]: + key_name = self.get_cache_key( + args={**credentials, "service_name": service_name} + ) + response = self.cache.get_cache(key=key_name) + return response + + def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None: + key_name = self.get_cache_key( + args={**credentials, "service_name": service_name} + ) + self.cache.set_cache(key=key_name, value=logging_obj) + return None + + in_memory_trace_id_cache = ServiceTraceIDCache() +in_memory_dynamic_logger_cache = DynamicLoggingCache() class Logging: @@ -324,10 +365,10 @@ class Logging: print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG") # log raw request to provider (like LangFuse) -- if opted in. 
if log_raw_request_response is True: + _litellm_params = self.model_call_details.get("litellm_params", {}) + _metadata = _litellm_params.get("metadata", {}) or {} try: # [Non-blocking Extra Debug Information in metadata] - _litellm_params = self.model_call_details.get("litellm_params", {}) - _metadata = _litellm_params.get("metadata", {}) or {} if ( turn_off_message_logging is not None and turn_off_message_logging is True @@ -362,7 +403,7 @@ class Logging: callbacks = litellm.input_callback + self.dynamic_input_callbacks for callback in callbacks: try: - if callback == "supabase": + if callback == "supabase" and supabaseClient is not None: verbose_logger.debug("reaches supabase for logging!") model = self.model_call_details["model"] messages = self.model_call_details["input"] @@ -396,7 +437,9 @@ class Logging: messages=self.messages, kwargs=self.model_call_details, ) - elif callable(callback): # custom logger functions + elif ( + callable(callback) and customLogger is not None + ): # custom logger functions customLogger.log_input_event( model=self.model, messages=self.messages, @@ -615,7 +658,7 @@ class Logging: self.model_call_details["litellm_params"]["metadata"][ "hidden_params" - ] = result._hidden_params + ] = getattr(result, "_hidden_params", {}) ## STANDARDIZED LOGGING PAYLOAD self.model_call_details["standard_logging_object"] = ( @@ -645,6 +688,7 @@ class Logging: litellm.max_budget and self.stream is False and result is not None + and isinstance(result, dict) and "content" in result ): time_diff = (end_time - start_time).total_seconds() @@ -652,7 +696,7 @@ class Logging: litellm._current_cost += litellm.completion_cost( model=self.model, prompt="", - completion=result["content"], + completion=getattr(result, "content", ""), total_time=float_diff, ) @@ -758,7 +802,7 @@ class Logging: ): print_verbose("no-log request, skipping logging") continue - if callback == "lite_debugger": + if callback == "lite_debugger" and liteDebuggerClient is not None: print_verbose("reaches lite_debugger for logging!") print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") print_verbose( @@ -774,7 +818,7 @@ class Logging: call_type=self.call_type, stream=self.stream, ) - if callback == "promptlayer": + if callback == "promptlayer" and promptLayerLogger is not None: print_verbose("reaches promptlayer for logging!") promptLayerLogger.log_event( kwargs=self.model_call_details, @@ -783,7 +827,7 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) - if callback == "supabase": + if callback == "supabase" and supabaseClient is not None: print_verbose("reaches supabase for logging!") kwargs = self.model_call_details @@ -811,7 +855,7 @@ class Logging: ), print_verbose=print_verbose, ) - if callback == "wandb": + if callback == "wandb" and weightsBiasesLogger is not None: print_verbose("reaches wandb for logging!") weightsBiasesLogger.log_event( kwargs=self.model_call_details, @@ -820,8 +864,7 @@ class Logging: end_time=end_time, print_verbose=print_verbose, ) - if callback == "logfire": - global logfireLogger + if callback == "logfire" and logfireLogger is not None: verbose_logger.debug("reaches logfire for success logging!") kwargs = {} for k, v in self.model_call_details.items(): @@ -844,10 +887,10 @@ class Logging: start_time=start_time, end_time=end_time, print_verbose=print_verbose, - level=LogfireLevel.INFO.value, + level=LogfireLevel.INFO.value, # type: ignore ) - if callback == "lunary": + if callback == "lunary" and lunaryLogger is not None: print_verbose("reaches lunary for logging!") 
model = self.model kwargs = self.model_call_details @@ -882,7 +925,7 @@ class Logging: run_id=self.litellm_call_id, print_verbose=print_verbose, ) - if callback == "helicone": + if callback == "helicone" and heliconeLogger is not None: print_verbose("reaches helicone for logging!") model = self.model messages = self.model_call_details["input"] @@ -924,6 +967,7 @@ class Logging: else: print_verbose("reaches langfuse for streaming logging!") result = kwargs["complete_streaming_response"] + temp_langfuse_logger = langFuseLogger if langFuseLogger is None or ( ( @@ -941,27 +985,45 @@ class Logging: and self.langfuse_host != langFuseLogger.langfuse_host ) ): - temp_langfuse_logger = LangFuseLogger( - langfuse_public_key=self.langfuse_public_key, - langfuse_secret=self.langfuse_secret, - langfuse_host=self.langfuse_host, - ) - _response = temp_langfuse_logger.log_event( - kwargs=kwargs, - response_obj=result, - start_time=start_time, - end_time=end_time, - user_id=kwargs.get("user", None), - print_verbose=print_verbose, - ) - if _response is not None and isinstance(_response, dict): - _trace_id = _response.get("trace_id", None) - if _trace_id is not None: - in_memory_trace_id_cache.set_cache( - litellm_call_id=self.litellm_call_id, - service_name="langfuse", - trace_id=_trace_id, + credentials = { + "langfuse_public_key": self.langfuse_public_key, + "langfuse_secret": self.langfuse_secret, + "langfuse_host": self.langfuse_host, + } + temp_langfuse_logger = ( + in_memory_dynamic_logger_cache.get_cache( + credentials=credentials, service_name="langfuse" ) + ) + if temp_langfuse_logger is None: + temp_langfuse_logger = LangFuseLogger( + langfuse_public_key=self.langfuse_public_key, + langfuse_secret=self.langfuse_secret, + langfuse_host=self.langfuse_host, + ) + in_memory_dynamic_logger_cache.set_cache( + credentials=credentials, + service_name="langfuse", + logging_obj=temp_langfuse_logger, + ) + + if temp_langfuse_logger is not None: + _response = temp_langfuse_logger.log_event( + kwargs=kwargs, + response_obj=result, + start_time=start_time, + end_time=end_time, + user_id=kwargs.get("user", None), + print_verbose=print_verbose, + ) + if _response is not None and isinstance(_response, dict): + _trace_id = _response.get("trace_id", None) + if _trace_id is not None: + in_memory_trace_id_cache.set_cache( + litellm_call_id=self.litellm_call_id, + service_name="langfuse", + trace_id=_trace_id, + ) if callback == "generic": global genericAPILogger verbose_logger.debug("reaches langfuse for success logging!") @@ -982,7 +1044,7 @@ class Logging: print_verbose("reaches langfuse for streaming logging!") result = kwargs["complete_streaming_response"] if genericAPILogger is None: - genericAPILogger = GenericAPILogger() + genericAPILogger = GenericAPILogger() # type: ignore genericAPILogger.log_event( kwargs=kwargs, response_obj=result, @@ -1022,7 +1084,7 @@ class Logging: user_id=kwargs.get("user", None), print_verbose=print_verbose, ) - if callback == "greenscale": + if callback == "greenscale" and greenscaleLogger is not None: kwargs = {} for k, v in self.model_call_details.items(): if ( @@ -1066,7 +1128,7 @@ class Logging: result = kwargs["complete_streaming_response"] # only add to cache once we have a complete streaming response litellm.cache.add_cache(result, **kwargs) - if callback == "athina": + if callback == "athina" and athinaLogger is not None: deep_copy = {} for k, v in self.model_call_details.items(): deep_copy[k] = v @@ -1224,6 +1286,7 @@ class Logging: "atranscription", False ) is not True + 
and customLogger is not None ): # custom logger functions print_verbose( f"success callbacks: Running Custom Callback Function" @@ -1423,9 +1486,8 @@ class Logging: await litellm.cache.async_add_cache(result, **kwargs) else: litellm.cache.add_cache(result, **kwargs) - if callback == "openmeter": - global openMeterLogger - if self.stream == True: + if callback == "openmeter" and openMeterLogger is not None: + if self.stream is True: if ( "async_complete_streaming_response" in self.model_call_details @@ -1645,33 +1707,9 @@ class Logging: ) for callback in callbacks: try: - if callback == "lite_debugger": - print_verbose("reaches lite_debugger for logging!") - print_verbose(f"liteDebuggerClient: {liteDebuggerClient}") - result = { - "model": self.model, - "created": time.time(), - "error": traceback_exception, - "usage": { - "prompt_tokens": prompt_token_calculator( - self.model, messages=self.messages - ), - "completion_tokens": 0, - }, - } - liteDebuggerClient.log_event( - model=self.model, - messages=self.messages, - end_user=self.model_call_details.get("user", "default"), - response_obj=result, - start_time=start_time, - end_time=end_time, - litellm_call_id=self.litellm_call_id, - print_verbose=print_verbose, - call_type=self.call_type, - stream=self.stream, - ) - if callback == "lunary": + if callback == "lite_debugger" and liteDebuggerClient is not None: + pass + elif callback == "lunary" and lunaryLogger is not None: print_verbose("reaches lunary for logging error!") model = self.model @@ -1685,6 +1723,7 @@ class Logging: ) lunaryLogger.log_event( + kwargs=self.model_call_details, type=_type, event="error", user_id=self.model_call_details.get("user", "default"), @@ -1704,22 +1743,11 @@ class Logging: print_verbose( f"capture exception not initialized: {capture_exception}" ) - elif callback == "supabase": + elif callback == "supabase" and supabaseClient is not None: print_verbose("reaches supabase for logging!") print_verbose(f"supabaseClient: {supabaseClient}") - result = { - "model": model, - "created": time.time(), - "error": traceback_exception, - "usage": { - "prompt_tokens": prompt_token_calculator( - model, messages=self.messages - ), - "completion_tokens": 0, - }, - } supabaseClient.log_event( - model=self.model, + model=self.model if hasattr(self, "model") else "", messages=self.messages, end_user=self.model_call_details.get("user", "default"), response_obj=result, @@ -1728,7 +1756,9 @@ class Logging: litellm_call_id=self.model_call_details["litellm_call_id"], print_verbose=print_verbose, ) - if callable(callback): # custom logger functions + if ( + callable(callback) and customLogger is not None + ): # custom logger functions customLogger.log_event( kwargs=self.model_call_details, response_obj=result, @@ -1809,13 +1839,13 @@ class Logging: start_time=start_time, end_time=end_time, response_obj=None, - user_id=kwargs.get("user", None), + user_id=self.model_call_details.get("user", None), print_verbose=print_verbose, status_message=str(exception), level="ERROR", kwargs=self.model_call_details, ) - if callback == "logfire": + if callback == "logfire" and logfireLogger is not None: verbose_logger.debug("reaches logfire for failure logging!") kwargs = {} for k, v in self.model_call_details.items(): @@ -1830,7 +1860,7 @@ class Logging: response_obj=result, start_time=start_time, end_time=end_time, - level=LogfireLevel.ERROR.value, + level=LogfireLevel.ERROR.value, # type: ignore print_verbose=print_verbose, ) @@ -1873,7 +1903,9 @@ class Logging: start_time=start_time, 
end_time=end_time, ) # type: ignore - if callable(callback): # custom logger functions + if ( + callable(callback) and customLogger is not None + ): # custom logger functions await customLogger.async_log_event( kwargs=self.model_call_details, response_obj=result, @@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None): ) sentry_sdk_instance.init( dsn=os.environ.get("SENTRY_DSN"), - traces_sample_rate=float(sentry_trace_rate), + traces_sample_rate=float(sentry_trace_rate), # type: ignore ) capture_exception = sentry_sdk_instance.capture_exception add_breadcrumb = sentry_sdk_instance.add_breadcrumb @@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload( saved_cache_cost: Optional[float] = None if cache_hit is True: - import time id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id saved_cache_cost = logging_obj._response_cost_calculator( - result=init_response_obj, cache_hit=False + result=init_response_obj, cache_hit=False # type: ignore ) ## Get model cost information ## @@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload( model_id=_model_id, requester_ip_address=clean_metadata.get("requester_ip_address", None), messages=kwargs.get("messages"), - response=( + response=( # type: ignore response_obj if len(response_obj.keys()) > 0 else init_response_obj ), model_parameters=kwargs.get("optional_params", None), diff --git a/litellm/llms/OpenAI/gpt_transformation.py b/litellm/llms/OpenAI/chat/gpt_transformation.py similarity index 100% rename from litellm/llms/OpenAI/gpt_transformation.py rename to litellm/llms/OpenAI/chat/gpt_transformation.py diff --git a/litellm/llms/OpenAI/chat/o1_handler.py b/litellm/llms/OpenAI/chat/o1_handler.py new file mode 100644 index 000000000..55dfe3715 --- /dev/null +++ b/litellm/llms/OpenAI/chat/o1_handler.py @@ -0,0 +1,95 @@ +""" +Handler file for calls to OpenAI's o1 family of models + +Written separately to handle faking streaming for o1 models. 
+""" + +import asyncio +from typing import Any, Callable, List, Optional, Union + +from httpx._config import Timeout + +from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator +from litellm.llms.OpenAI.openai import OpenAIChatCompletion +from litellm.types.utils import ModelResponse +from litellm.utils import CustomStreamWrapper + + +class OpenAIO1ChatCompletion(OpenAIChatCompletion): + + async def mock_async_streaming( + self, + response: Any, + model: Optional[str], + logging_obj: Any, + ): + model_response = await response + completion_stream = MockResponseIterator(model_response=model_response) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + ) + return streaming_response + + def completion( + self, + model_response: ModelResponse, + timeout: Union[float, Timeout], + optional_params: dict, + logging_obj: Any, + model: Optional[str] = None, + messages: Optional[list] = None, + print_verbose: Optional[Callable[..., Any]] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + acompletion: bool = False, + litellm_params=None, + logger_fn=None, + headers: Optional[dict] = None, + custom_prompt_dict: dict = {}, + client=None, + organization: Optional[str] = None, + custom_llm_provider: Optional[str] = None, + drop_params: Optional[bool] = None, + ): + stream: Optional[bool] = optional_params.pop("stream", False) + response = super().completion( + model_response, + timeout, + optional_params, + logging_obj, + model, + messages, + print_verbose, + api_key, + api_base, + acompletion, + litellm_params, + logger_fn, + headers, + custom_prompt_dict, + client, + organization, + custom_llm_provider, + drop_params, + ) + + if stream is True: + if asyncio.iscoroutine(response): + return self.mock_async_streaming( + response=response, model=model, logging_obj=logging_obj # type: ignore + ) + + completion_stream = MockResponseIterator(model_response=response) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="openai", + logging_obj=logging_obj, + ) + + return streaming_response + else: + return response diff --git a/litellm/llms/OpenAI/o1_transformation.py b/litellm/llms/OpenAI/chat/o1_transformation.py similarity index 100% rename from litellm/llms/OpenAI/o1_transformation.py rename to litellm/llms/OpenAI/chat/o1_transformation.py diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat.py index 739abb91f..343cdd3ff 100644 --- a/litellm/llms/databricks/chat.py +++ b/litellm/llms/databricks/chat.py @@ -15,6 +15,8 @@ import requests # type: ignore import litellm from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.llms.databricks.exceptions import DatabricksError +from litellm.llms.databricks.streaming_utils import ModelResponseIterator from litellm.types.llms.openai import ( ChatCompletionDeltaChunk, ChatCompletionResponseMessage, @@ -33,17 +35,6 @@ from ..base import BaseLLM from ..prompt_templates.factory import custom_prompt, prompt_factory -class DatabricksError(Exception): - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - self.request = httpx.Request(method="POST", url="https://docs.databricks.com/") - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - 
self.message - ) # Call the base class constructor with the parameters it needs - - class DatabricksConfig: """ Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request @@ -367,7 +358,7 @@ class DatabricksChatCompletion(BaseLLM): status_code=e.response.status_code, message=e.response.text, ) - except httpx.TimeoutException as e: + except httpx.TimeoutException: raise DatabricksError(status_code=408, message="Timeout error occurred.") except Exception as e: raise DatabricksError(status_code=500, message=str(e)) @@ -380,7 +371,7 @@ class DatabricksChatCompletion(BaseLLM): ) response = ModelResponse(**response_json) - response.model = custom_llm_provider + "/" + response.model + response.model = custom_llm_provider + "/" + (response.model or "") if base_model is not None: response._hidden_params["model"] = base_model @@ -529,7 +520,7 @@ class DatabricksChatCompletion(BaseLLM): response_json = response.json() except httpx.HTTPStatusError as e: raise DatabricksError( - status_code=e.response.status_code, message=response.text + status_code=e.response.status_code, message=e.response.text ) except httpx.TimeoutException as e: raise DatabricksError( @@ -540,7 +531,7 @@ class DatabricksChatCompletion(BaseLLM): response = ModelResponse(**response_json) - response.model = custom_llm_provider + "/" + response.model + response.model = custom_llm_provider + "/" + (response.model or "") if base_model is not None: response._hidden_params["model"] = base_model @@ -657,7 +648,7 @@ class DatabricksChatCompletion(BaseLLM): except httpx.HTTPStatusError as e: raise DatabricksError( status_code=e.response.status_code, - message=response.text if response else str(e), + message=e.response.text, ) except httpx.TimeoutException as e: raise DatabricksError(status_code=408, message="Timeout error occurred.") @@ -673,136 +664,3 @@ class DatabricksChatCompletion(BaseLLM): ) return litellm.EmbeddingResponse(**response_json) - - -class ModelResponseIterator: - def __init__(self, streaming_response, sync_stream: bool): - self.streaming_response = streaming_response - - def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: - try: - processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore - - text = "" - tool_use: Optional[ChatCompletionToolCallChunk] = None - is_finished = False - finish_reason = "" - usage: Optional[ChatCompletionUsageBlock] = None - - if processed_chunk.choices[0].delta.content is not None: # type: ignore - text = processed_chunk.choices[0].delta.content # type: ignore - - if ( - processed_chunk.choices[0].delta.tool_calls is not None # type: ignore - and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore - and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore - and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore - is not None - ): - tool_use = ChatCompletionToolCallChunk( - id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore - type="function", - function=ChatCompletionToolCallFunctionChunk( - name=processed_chunk.choices[0] - .delta.tool_calls[0] # type: ignore - .function.name, - arguments=processed_chunk.choices[0] - .delta.tool_calls[0] # type: ignore - .function.arguments, - ), - index=processed_chunk.choices[0].index, - ) - - if processed_chunk.choices[0].finish_reason is not None: - is_finished = True - finish_reason = processed_chunk.choices[0].finish_reason - - if hasattr(processed_chunk, "usage") and isinstance( - 
processed_chunk.usage, litellm.Usage - ): - usage_chunk: litellm.Usage = processed_chunk.usage - - usage = ChatCompletionUsageBlock( - prompt_tokens=usage_chunk.prompt_tokens, - completion_tokens=usage_chunk.completion_tokens, - total_tokens=usage_chunk.total_tokens, - ) - - return GenericStreamingChunk( - text=text, - tool_use=tool_use, - is_finished=is_finished, - finish_reason=finish_reason, - usage=usage, - index=0, - ) - except json.JSONDecodeError: - raise ValueError(f"Failed to decode JSON from chunk: {chunk}") - - # Sync iterator - def __iter__(self): - self.response_iterator = self.streaming_response - return self - - def __next__(self): - try: - chunk = self.response_iterator.__next__() - except StopIteration: - raise StopIteration - except ValueError as e: - raise RuntimeError(f"Error receiving chunk from stream: {e}") - - try: - chunk = chunk.replace("data:", "") - chunk = chunk.strip() - if len(chunk) > 0: - json_chunk = json.loads(chunk) - return self.chunk_parser(chunk=json_chunk) - else: - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) - except StopIteration: - raise StopIteration - except ValueError as e: - raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") - - # Async iterator - def __aiter__(self): - self.async_response_iterator = self.streaming_response.__aiter__() - return self - - async def __anext__(self): - try: - chunk = await self.async_response_iterator.__anext__() - except StopAsyncIteration: - raise StopAsyncIteration - except ValueError as e: - raise RuntimeError(f"Error receiving chunk from stream: {e}") - - try: - chunk = chunk.replace("data:", "") - chunk = chunk.strip() - if chunk == "[DONE]": - raise StopAsyncIteration - if len(chunk) > 0: - json_chunk = json.loads(chunk) - return self.chunk_parser(chunk=json_chunk) - else: - return GenericStreamingChunk( - text="", - is_finished=False, - finish_reason="", - usage=None, - index=0, - tool_use=None, - ) - except StopAsyncIteration: - raise StopAsyncIteration - except ValueError as e: - raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") diff --git a/litellm/llms/databricks/exceptions.py b/litellm/llms/databricks/exceptions.py new file mode 100644 index 000000000..8bb3d435d --- /dev/null +++ b/litellm/llms/databricks/exceptions.py @@ -0,0 +1,12 @@ +import httpx + + +class DatabricksError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request(method="POST", url="https://docs.databricks.com/") + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs diff --git a/litellm/llms/databricks/streaming_utils.py b/litellm/llms/databricks/streaming_utils.py new file mode 100644 index 000000000..1b342f3c9 --- /dev/null +++ b/litellm/llms/databricks/streaming_utils.py @@ -0,0 +1,145 @@ +import json +from typing import Optional + +import litellm +from litellm.types.llms.openai import ( + ChatCompletionDeltaChunk, + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, + ChatCompletionUsageBlock, +) +from litellm.types.utils import GenericStreamingChunk + + +class ModelResponseIterator: + def __init__(self, streaming_response, sync_stream: bool): + self.streaming_response = streaming_response + + def chunk_parser(self, chunk: dict) -> 
GenericStreamingChunk: + try: + processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore + + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + is_finished = False + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + + if processed_chunk.choices[0].delta.content is not None: # type: ignore + text = processed_chunk.choices[0].delta.content # type: ignore + + if ( + processed_chunk.choices[0].delta.tool_calls is not None # type: ignore + and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore + and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore + and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore + is not None + ): + tool_use = ChatCompletionToolCallChunk( + id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore + type="function", + function=ChatCompletionToolCallFunctionChunk( + name=processed_chunk.choices[0] + .delta.tool_calls[0] # type: ignore + .function.name, + arguments=processed_chunk.choices[0] + .delta.tool_calls[0] # type: ignore + .function.arguments, + ), + index=processed_chunk.choices[0].index, + ) + + if processed_chunk.choices[0].finish_reason is not None: + is_finished = True + finish_reason = processed_chunk.choices[0].finish_reason + + if hasattr(processed_chunk, "usage") and isinstance( + processed_chunk.usage, litellm.Usage + ): + usage_chunk: litellm.Usage = processed_chunk.usage + + usage = ChatCompletionUsageBlock( + prompt_tokens=usage_chunk.prompt_tokens, + completion_tokens=usage_chunk.completion_tokens, + total_tokens=usage_chunk.total_tokens, + ) + + return GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=0, + ) + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + # Sync iterator + def __iter__(self): + self.response_iterator = self.streaming_response + return self + + def __next__(self): + try: + chunk = self.response_iterator.__next__() + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + chunk = chunk.replace("data:", "") + chunk = chunk.strip() + if len(chunk) > 0: + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + chunk = chunk.replace("data:", "") + chunk = chunk.strip() + if chunk == "[DONE]": + raise StopAsyncIteration + if len(chunk) > 0: + json_chunk = json.loads(chunk) + return self.chunk_parser(chunk=json_chunk) + else: + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived 
chunk: {chunk}") diff --git a/litellm/main.py b/litellm/main.py index 82d19a976..ee2ea3626 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -95,6 +95,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router from .llms.databricks.chat import DatabricksChatCompletion from .llms.huggingface_restapi import Huggingface from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription +from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion from .llms.predibase import PredibaseChatCompletion from .llms.prompt_templates.factory import ( @@ -161,6 +162,7 @@ from litellm.utils import ( ####### ENVIRONMENT VARIABLES ################### openai_chat_completions = OpenAIChatCompletion() openai_text_completions = OpenAITextCompletion() +openai_o1_chat_completions = OpenAIO1ChatCompletion() openai_audio_transcriptions = OpenAIAudioTranscription() databricks_chat_completions = DatabricksChatCompletion() anthropic_chat_completions = AnthropicChatCompletion() @@ -1366,25 +1368,46 @@ def completion( ## COMPLETION CALL try: - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - custom_prompt_dict=custom_prompt_dict, - client=client, # pass AsyncOpenAI, OpenAI client - organization=organization, - custom_llm_provider=custom_llm_provider, - ) + if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model): + response = openai_o1_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + organization=organization, + custom_llm_provider=custom_llm_provider, + ) + else: + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + organization=organization, + custom_llm_provider=custom_llm_provider, + ) except Exception as e: ## LOGGING - log the original exception returned logging.post_call( diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 4533bf911..3502a786b 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -3,7 +3,6 @@ model_list: litellm_params: model: anthropic.claude-3-sonnet-20240229-v1:0 api_base: https://exampleopenaiendpoint-production.up.railway.app - # aws_session_token: 
"IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0=" aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID - model_name: gemini-vision @@ -17,4 +16,34 @@ model_list: litellm_params: model: gpt-3.5-turbo api_base: https://exampleopenaiendpoint-production.up.railway.app - \ No newline at end of file + - model_name: o1-preview + litellm_params: + model: o1-preview + +litellm_settings: + drop_params: True + json_logs: True + store_audit_logs: True + log_raw_request_response: True + return_response_headers: True + num_retries: 5 + request_timeout: 200 + callbacks: ["custom_callbacks.proxy_handler_instance"] + +guardrails: + - guardrail_name: "presidio-pre-guard" + litellm_params: + guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio" + mode: "logging_only" + mock_redacted_text: { + "text": "My name is , who are you? Say my name in your response", + "items": [ + { + "start": 11, + "end": 19, + "entity_type": "PERSON", + "text": "", + "operator": "replace", + } + ], + } \ No newline at end of file diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index b2e63c256..36fe8ce73 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -446,7 +446,6 @@ async def user_api_key_auth( and request.headers.get(key=header_key) is not None # type: ignore ): api_key = request.headers.get(key=header_key) # type: ignore - if master_key is None: if isinstance(api_key, str): return UserAPIKeyAuth( diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py index 1516bfd24..445fea1d2 100644 --- a/litellm/proxy/custom_callbacks.py +++ b/litellm/proxy/custom_callbacks.py @@ -2,9 +2,19 @@ from litellm.integrations.custom_logger import CustomLogger class MyCustomHandler(CustomLogger): - async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - # print("Call failed") - pass + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + # input_tokens = response_obj.get("usage", {}).get("prompt_tokens", 0) + # output_tokens = response_obj.get("usage", {}).get("completion_tokens", 0) + input_tokens = ( + response_obj.usage.prompt_tokens + if hasattr(response_obj.usage, "prompt_tokens") + else 0 + ) + output_tokens = ( + response_obj.usage.completion_tokens + if hasattr(response_obj.usage, "completion_tokens") + else 0 + ) proxy_handler_instance = MyCustomHandler() diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py index a18d8db0e..c28cf5ec9 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py +++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): response = await self.async_handler.post( 
url=prepared_request.url, json=request_data, # type: ignore - headers=dict(prepared_request.headers), + headers=prepared_request.headers, # type: ignore ) verbose_proxy_logger.debug("Bedrock AI response: %s", response.text) if response.status_code == 200: @@ -254,7 +254,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): from litellm.proxy.common_utils.callback_utils import ( add_guardrail_to_applied_guardrails_header, ) - from litellm.types.guardrails import GuardrailEventHooks event_type: GuardrailEventHooks = GuardrailEventHooks.during_call if self.should_run_guardrail(data=data, event_type=event_type) is not True: diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py index d84966fd7..d15a4a7d5 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py +++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py @@ -189,6 +189,7 @@ class lakeraAI_Moderation(CustomGuardrail): # Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already) # Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked. # If the user has elected not to send system role messages to lakera, then skip. + if system_message is not None: if not litellm.add_function_to_prompt: content = system_message.get("content") diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index eb09e5203..e1b0d7cad 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -19,6 +19,7 @@ from fastapi import HTTPException from pydantic import BaseModel import litellm # noqa: E401 +from litellm import get_secret from litellm._logging import verbose_proxy_logger from litellm.caching import DualCache from litellm.integrations.custom_guardrail import CustomGuardrail @@ -58,7 +59,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): self.pii_tokens: dict = ( {} ) # mapping of PII token to original text - only used with Presidio `replace` operation - self.mock_redacted_text = mock_redacted_text self.output_parse_pii = output_parse_pii or False if mock_testing is True: # for testing purposes only @@ -92,8 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): presidio_anonymizer_api_base: Optional[str] = None, ): self.presidio_analyzer_api_base: Optional[str] = ( - presidio_analyzer_api_base - or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None) + presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore ) self.presidio_anonymizer_api_base: Optional[ str @@ -198,12 +197,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): else: raise Exception(f"Invalid anonymizer response: {redacted_text}") except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) raise e async def async_pre_call_hook( @@ -254,9 +247,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): ) return data except Exception as e: - verbose_proxy_logger.info( - f"An error occurred -", - ) raise e async def async_logging_hook( @@ -300,9 +290,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): ) kwargs["messages"] = messages - return kwargs, responses + return kwargs, result - 
async def async_post_call_success_hook( + async def async_post_call_success_hook( # type: ignore self, data: dict, user_api_key_dict: UserAPIKeyAuth, @@ -314,7 +304,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): verbose_proxy_logger.debug( f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}" ) - if self.output_parse_pii == False: + if self.output_parse_pii is False: return response if isinstance(response, ModelResponse) and not isinstance( diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py index cff9fca05..c46300990 100644 --- a/litellm/proxy/guardrails/init_guardrails.py +++ b/litellm/proxy/guardrails/init_guardrails.py @@ -1,10 +1,11 @@ import importlib import traceback -from typing import Dict, List, Literal +from typing import Dict, List, Literal, Optional from pydantic import BaseModel, RootModel import litellm +from litellm import get_secret from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy @@ -16,7 +17,6 @@ from litellm.types.guardrails import ( GuardrailItemSpec, LakeraCategoryThresholds, LitellmParams, - guardrailConfig, ) all_guardrails: List[GuardrailItem] = [] @@ -98,18 +98,13 @@ def init_guardrails_v2( # Init litellm params for guardrail litellm_params_data = guardrail["litellm_params"] verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data) - litellm_params = LitellmParams( - guardrail=litellm_params_data["guardrail"], - mode=litellm_params_data["mode"], - api_key=litellm_params_data.get("api_key"), - api_base=litellm_params_data.get("api_base"), - guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"), - guardrailVersion=litellm_params_data.get("guardrailVersion"), - output_parse_pii=litellm_params_data.get("output_parse_pii"), - presidio_ad_hoc_recognizers=litellm_params_data.get( - "presidio_ad_hoc_recognizers" - ), - ) + + _litellm_params_kwargs = { + k: litellm_params_data[k] if k in litellm_params_data else None + for k in LitellmParams.__annotations__.keys() + } + + litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore if ( "category_thresholds" in litellm_params_data @@ -122,15 +117,11 @@ def init_guardrails_v2( if litellm_params["api_key"]: if litellm_params["api_key"].startswith("os.environ/"): - litellm_params["api_key"] = litellm.get_secret( - litellm_params["api_key"] - ) + litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore if litellm_params["api_base"]: if litellm_params["api_base"].startswith("os.environ/"): - litellm_params["api_base"] = litellm.get_secret( - litellm_params["api_base"] - ) + litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore # Init guardrail CustomLoggerClass if litellm_params["guardrail"] == "aporia": @@ -182,6 +173,7 @@ def init_guardrails_v2( presidio_ad_hoc_recognizers=litellm_params[ "presidio_ad_hoc_recognizers" ], + mock_redacted_text=litellm_params.get("mock_redacted_text") or None, ) if litellm_params["output_parse_pii"] is True: diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py index 9c1f1eb95..73e83ec4a 100644 --- a/litellm/proxy/hooks/prompt_injection_detection.py +++ b/litellm/proxy/hooks/prompt_injection_detection.py @@ -167,11 +167,11 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): if self.prompt_injection_params is not None: # 1. 
check if heuristics check turned on - if self.prompt_injection_params.heuristics_check == True: + if self.prompt_injection_params.heuristics_check is True: is_prompt_attack = self.check_user_input_similarity( user_input=formatted_prompt ) - if is_prompt_attack == True: + if is_prompt_attack is True: raise HTTPException( status_code=400, detail={ @@ -179,14 +179,14 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): }, ) # 2. check if vector db similarity check turned on [TODO] Not Implemented yet - if self.prompt_injection_params.vector_db_check == True: + if self.prompt_injection_params.vector_db_check is True: pass else: is_prompt_attack = self.check_user_input_similarity( user_input=formatted_prompt ) - if is_prompt_attack == True: + if is_prompt_attack is True: raise HTTPException( status_code=400, detail={ @@ -201,19 +201,18 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger): if ( e.status_code == 400 and isinstance(e.detail, dict) - and "error" in e.detail + and "error" in e.detail # type: ignore and self.prompt_injection_params is not None and self.prompt_injection_params.reject_as_response ): return e.detail.get("error") raise e except Exception as e: - verbose_proxy_logger.error( + verbose_proxy_logger.exception( "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( str(e) ) ) - verbose_proxy_logger.debug(traceback.format_exc()) async def async_moderation_hook( # type: ignore self, diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index 859c8aeb8..3801d7465 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -195,7 +195,8 @@ async def user_auth(request: Request): - os.environ["SMTP_PASSWORD"] - os.environ["SMTP_SENDER_EMAIL"] """ - from litellm.proxy.proxy_server import prisma_client, send_email + from litellm.proxy.proxy_server import prisma_client + from litellm.proxy.utils import send_email data = await request.json() # type: ignore user_email = data["user_email"] @@ -212,7 +213,7 @@ async def user_auth(request: Request): ) ### if so - generate a 24 hr key with that user id if response is not None: - user_id = response.user_id + user_id = response.user_id # type: ignore response = await generate_key_helper_fn( request_type="key", **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id}, # type: ignore @@ -345,6 +346,7 @@ async def user_info( for team in teams_1: team_id_list.append(team.team_id) + teams_2: Optional[Any] = None if user_info is not None: # *NEW* get all teams in user 'teams' field teams_2 = await prisma_client.get_data( @@ -375,7 +377,7 @@ async def user_info( ), user_api_key_dict=user_api_key_dict, ) - else: + elif caller_user_info is not None: teams_2 = await prisma_client.get_data( team_id_list=caller_user_info.teams, table_name="team", @@ -395,7 +397,7 @@ async def user_info( query_type="find_all", ) - if user_info is None: + if user_info is None and keys is not None: ## make sure we still return a total spend ## spend = 0 for k in keys: @@ -404,32 +406,35 @@ async def user_info( ## REMOVE HASHED TOKEN INFO before returning ## returned_keys = [] - for key in keys: - if ( - key.token == litellm_master_key_hash - and general_settings.get("disable_master_key_return", False) - == True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui - ): - continue + 
if keys is None: + pass + else: + for key in keys: + if ( + key.token == litellm_master_key_hash + and general_settings.get("disable_master_key_return", False) + == True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui + ): + continue - try: - key = key.model_dump() # noqa - except: - # if using pydantic v1 - key = key.dict() - if ( - "team_id" in key - and key["team_id"] is not None - and key["team_id"] != "litellm-dashboard" - ): - team_info = await prisma_client.get_data( - team_id=key["team_id"], table_name="team" - ) - team_alias = getattr(team_info, "team_alias", None) - key["team_alias"] = team_alias - else: - key["team_alias"] = "None" - returned_keys.append(key) + try: + key = key.model_dump() # noqa + except: + # if using pydantic v1 + key = key.dict() + if ( + "team_id" in key + and key["team_id"] is not None + and key["team_id"] != "litellm-dashboard" + ): + team_info = await prisma_client.get_data( + team_id=key["team_id"], table_name="team" + ) + team_alias = getattr(team_info, "team_alias", None) + key["team_alias"] = team_alias + else: + key["team_alias"] = "None" + returned_keys.append(key) response_data = { "user_id": user_id, @@ -539,6 +544,7 @@ async def user_update( ## ADD USER, IF NEW ## verbose_proxy_logger.debug("/user/update: Received data = %s", data) + response: Optional[Any] = None if data.user_id is not None and len(data.user_id) > 0: non_default_values["user_id"] = data.user_id # type: ignore verbose_proxy_logger.debug("In update user, user_id condition block.") @@ -573,7 +579,7 @@ async def user_update( data=non_default_values, table_name="user", ) - return response + return response # type: ignore # update based on remaining passed in values except Exception as e: verbose_proxy_logger.error( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 581db4e32..de1530071 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -226,7 +226,6 @@ from litellm.proxy.utils import ( hash_token, log_to_opentelemetry, reset_budget, - send_email, update_spend, ) from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import ( diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index 86f959362..295546a3d 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -1434,7 +1434,7 @@ async def _get_spend_report_for_time_range( except Exception as e: verbose_proxy_logger.error( "Exception in _get_daily_spend_reports {}".format(str(e)) - ) # noqa + ) @router.post( @@ -1703,26 +1703,26 @@ async def view_spend_logs( result: dict = {} for record in response: dt_object = datetime.strptime( - str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" + str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore ) # type: ignore date = dt_object.date() if date not in result: result[date] = {"users": {}, "models": {}} - api_key = record["api_key"] - user_id = record["user"] - model = record["model"] - result[date]["spend"] = ( - result[date].get("spend", 0) + record["_sum"]["spend"] - ) - result[date][api_key] = ( - result[date].get(api_key, 0) + record["_sum"]["spend"] - ) - result[date]["users"][user_id] = ( - result[date]["users"].get(user_id, 0) + record["_sum"]["spend"] - ) - result[date]["models"][model] = ( - result[date]["models"].get(model, 0) + record["_sum"]["spend"] - ) + api_key = record["api_key"] # type: ignore + user_id = 
record["user"] # type: ignore + model = record["model"] # type: ignore + result[date]["spend"] = result[date].get("spend", 0) + record.get( + "_sum", {} + ).get("spend", 0) + result[date][api_key] = result[date].get(api_key, 0) + record.get( + "_sum", {} + ).get("spend", 0) + result[date]["users"][user_id] = result[date]["users"].get( + user_id, 0 + ) + record.get("_sum", {}).get("spend", 0) + result[date]["models"][model] = result[date]["models"].get( + model, 0 + ) + record.get("_sum", {}).get("spend", 0) return_list = [] final_date = None for k, v in sorted(result.items()): @@ -1784,7 +1784,7 @@ async def view_spend_logs( table_name="spend", query_type="find_all" ) - return spend_log + return spend_logs return None @@ -1843,6 +1843,88 @@ async def global_spend_reset(): } +@router.post( + "/global/spend/refresh", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend_refresh(): + """ + ADMIN ONLY / MASTER KEY Only Endpoint + + Globally refresh spend MonthlyGlobalSpend view + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise ProxyException( + message="Prisma Client is not initialized", + type="internal_error", + param="None", + code=status.HTTP_401_UNAUTHORIZED, + ) + + ## RESET GLOBAL SPEND VIEW ### + async def is_materialized_global_spend_view() -> bool: + """ + Return True if materialized view exists + + Else False + """ + sql_query = """ + SELECT relname, relkind + FROM pg_class + WHERE relname = 'MonthlyGlobalSpend'; + """ + try: + resp = await prisma_client.db.query_raw(sql_query) + + assert resp[0]["relkind"] == "m" + return True + except Exception: + return False + + view_exists = await is_materialized_global_spend_view() + + if view_exists: + # refresh materialized view + sql_query = """ + REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend"; + """ + try: + from litellm.proxy._types import CommonProxyErrors + from litellm.proxy.proxy_server import proxy_logging_obj + from litellm.proxy.utils import PrismaClient + + db_url = os.getenv("DATABASE_URL") + if db_url is None: + raise Exception(CommonProxyErrors.db_not_connected_error.value) + new_client = PrismaClient( + database_url=db_url, + proxy_logging_obj=proxy_logging_obj, + http_client={ + "timeout": 6000, + }, + ) + await new_client.db.connect() + await new_client.db.query_raw(sql_query) + verbose_proxy_logger.info("MonthlyGlobalSpend view refreshed") + return { + "message": "MonthlyGlobalSpend view refreshed", + "status": "success", + } + + except Exception as e: + verbose_proxy_logger.exception( + "Failed to refresh materialized view - {}".format(str(e)) + ) + return { + "message": "Failed to refresh materialized view", + "status": "failure", + } + + async def global_spend_for_internal_user( api_key: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8a9a6d707..f0ee7ea9f 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -92,6 +92,7 @@ def safe_deep_copy(data): if litellm.safe_memory_mode is True: return data + litellm_parent_otel_span: Optional[Any] = None # Step 1: Remove the litellm_parent_otel_span litellm_parent_otel_span = None if isinstance(data, dict): @@ -101,7 +102,7 @@ def safe_deep_copy(data): new_data = copy.deepcopy(data) # Step 2: re-add the litellm_parent_otel_span after doing a deep copy - if isinstance(data, dict): + if isinstance(data, dict) and litellm_parent_otel_span 
is not None: if "metadata" in data: data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span return new_data @@ -468,7 +469,7 @@ class ProxyLogging: # V1 implementation - backwards compatibility if callback.event_hook is None: - if callback.moderation_check == "pre_call": + if callback.moderation_check == "pre_call": # type: ignore return else: # Main - V2 Guardrails implementation @@ -881,7 +882,12 @@ class PrismaClient: org_list_transactons: dict = {} spend_log_transactions: List = [] - def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): + def __init__( + self, + database_url: str, + proxy_logging_obj: ProxyLogging, + http_client: Optional[Any] = None, + ): verbose_proxy_logger.debug( "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" ) @@ -912,7 +918,10 @@ class PrismaClient: # Now you can import the Prisma Client from prisma import Prisma # type: ignore verbose_proxy_logger.debug("Connecting Prisma Client to DB..") - self.db = Prisma() # Client to connect to Prisma db + if http_client is not None: + self.db = Prisma(http=http_client) + else: + self.db = Prisma() # Client to connect to Prisma db verbose_proxy_logger.debug("Success - Connected Prisma Client to DB") def hash_token(self, token: str): @@ -987,7 +996,7 @@ class PrismaClient: return else: ## check if required view exists ## - if required_view not in ret[0]["view_names"]: + if ret[0]["view_names"] and required_view not in ret[0]["view_names"]: await self.health_check() # make sure we can connect to db await self.db.execute_raw( """ @@ -1009,7 +1018,9 @@ class PrismaClient: else: # don't block execution if these views are missing # Convert lists to sets for efficient difference calculation - ret_view_names_set = set(ret[0]["view_names"]) + ret_view_names_set = ( + set(ret[0]["view_names"]) if ret[0]["view_names"] else set() + ) expected_views_set = set(expected_views) # Find missing views missing_views = expected_views_set - ret_view_names_set @@ -1291,13 +1302,13 @@ class PrismaClient: verbose_proxy_logger.debug( f"PrismaClient: get_data - args_passed_in: {args_passed_in}" ) + hashed_token: Optional[str] = None try: response: Any = None if (token is not None and table_name is None) or ( table_name is not None and table_name == "key" ): # check if plain text or hash - hashed_token = None if token is not None: if isinstance(token, str): hashed_token = token @@ -1306,7 +1317,7 @@ class PrismaClient: verbose_proxy_logger.debug( f"PrismaClient: find_unique for token: {hashed_token}" ) - if query_type == "find_unique": + if query_type == "find_unique" and hashed_token is not None: if token is None: raise HTTPException( status_code=400, @@ -1706,7 +1717,7 @@ class PrismaClient: updated_data = v updated_data = json.dumps(updated_data) updated_table_row = self.db.litellm_config.upsert( - where={"param_name": k}, + where={"param_name": k}, # type: ignore data={ "create": {"param_name": k, "param_value": updated_data}, # type: ignore "update": {"param_value": updated_data}, @@ -2302,7 +2313,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any: return instance except ImportError as e: # Re-raise the exception with a user-friendly message - raise ImportError(f"Could not import {instance_name} from {module_name}") from e + if instance_name and module_name: + raise ImportError( + f"Could not import {instance_name} from {module_name}" + ) from e + else: + raise e except Exception as e: raise e @@ -2377,12 +2393,12 @@ async def send_email(receiver_email, 
subject, html): try: # Establish a secure connection with the SMTP server - with smtplib.SMTP(smtp_host, smtp_port) as server: + with smtplib.SMTP(smtp_host, smtp_port) as server: # type: ignore if os.getenv("SMTP_TLS", "True") != "False": server.starttls() # Login to your email account - server.login(smtp_username, smtp_password) + server.login(smtp_username, smtp_password) # type: ignore # Send the email server.send_message(email_message) @@ -2945,7 +2961,7 @@ async def update_spend( if i >= n_retry_times: # If we've reached the maximum number of retries raise # Re-raise the last exception # Optionally, sleep for a bit before retrying - await asyncio.sleep(2**i) # Exponential backoff + await asyncio.sleep(2**i) # type: ignore except Exception as e: import traceback diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 1debf8171..c6d5ffc34 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -2110,7 +2110,6 @@ async def test_hf_completion_tgi_stream(): def test_openai_chat_completion_call(): litellm.set_verbose = False litellm.return_response_headers = True - print(f"making openai chat completion call") response = completion(model="gpt-3.5-turbo", messages=messages, stream=True) assert isinstance( response._hidden_params["additional_headers"][ @@ -2318,6 +2317,57 @@ def test_together_ai_completion_call_mistral(): pass +# # test on together ai completion call - starcoder +@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.asyncio +async def test_openai_o1_completion_call_streaming(sync_mode): + try: + litellm.set_verbose = False + if sync_mode: + response = completion( + model="o1-preview", + messages=messages, + stream=True, + ) + complete_response = "" + print(f"returned response object: {response}") + has_finish_reason = False + for idx, chunk in enumerate(response): + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + if finished: + break + complete_response += chunk + if has_finish_reason is False: + raise Exception("Finish reason not set for last chunk") + if complete_response == "": + raise Exception("Empty response received") + else: + response = await acompletion( + model="o1-preview", + messages=messages, + stream=True, + ) + complete_response = "" + print(f"returned response object: {response}") + has_finish_reason = False + idx = 0 + async for chunk in response: + chunk, finished = streaming_format_tests(idx, chunk) + has_finish_reason = finished + if finished: + break + complete_response += chunk + idx += 1 + if has_finish_reason is False: + raise Exception("Finish reason not set for last chunk") + if complete_response == "": + raise Exception("Empty response received") + print(f"complete response: {complete_response}") + except Exception: + pytest.fail(f"error occurred: {traceback.format_exc()}") + + def test_together_ai_completion_call_starcoder_bad_key(): try: api_key = "bad-key" diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index cb70de505..57be6b0c4 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -71,7 +71,7 @@ class LakeraCategoryThresholds(TypedDict, total=False): jailbreak: float -class LitellmParams(TypedDict, total=False): +class LitellmParams(TypedDict): guardrail: str mode: str api_key: str @@ -87,6 +87,7 @@ class LitellmParams(TypedDict, total=False): # Presidio params output_parse_pii: Optional[bool] presidio_ad_hoc_recognizers: Optional[str] + mock_redacted_text: Optional[dict] class 
Guardrail(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index 280691c8a..58ef4d49f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -120,11 +120,26 @@ with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") # Convert to str (if necessary) claude_json_str = json.dumps(json_data) import importlib.metadata +from concurrent.futures import ThreadPoolExecutor +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Literal, + Optional, + Tuple, + Type, + Union, + cast, + get_args, +) from openai import OpenAIError as OriginalError from ._logging import verbose_logger -from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache +from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache from .exceptions import ( APIConnectionError, APIError, @@ -150,31 +165,6 @@ from .types.llms.openai import ( ) from .types.router import LiteLLM_Params -try: - from .proxy.enterprise.enterprise_callbacks.generic_api_callback import ( - GenericAPILogger, - ) -except Exception as e: - verbose_logger.debug(f"Exception import enterprise features {str(e)}") - -from concurrent.futures import ThreadPoolExecutor -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Literal, - Optional, - Tuple, - Type, - Union, - cast, - get_args, -) - -from .caching import Cache - ####### ENVIRONMENT VARIABLES #################### # Adjust to your specific application needs / system capabilities. MAX_THREADS = 100 diff --git a/tests/llm_translation/Readme.md b/tests/llm_translation/Readme.md new file mode 100644 index 000000000..174c81b4e --- /dev/null +++ b/tests/llm_translation/Readme.md @@ -0,0 +1 @@ +More tests under `litellm/litellm/tests/*`. \ No newline at end of file diff --git a/tests/llm_translation/test_databricks.py b/tests/llm_translation/test_databricks.py new file mode 100644 index 000000000..067b188ed --- /dev/null +++ b/tests/llm_translation/test_databricks.py @@ -0,0 +1,502 @@ +import asyncio +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch + +import litellm +from litellm.exceptions import BadRequestError, InternalServerError +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import CustomStreamWrapper + + +def mock_chat_response() -> Dict[str, Any]: + return { + "id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891", + "object": "chat.completion", + "created": 1726285449, + "model": "dbrx-instruct-071224", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm an AI assistant. I'm doing well. 
How can I help?", + "function_call": None, + "tool_calls": None, + }, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 230, + "completion_tokens": 38, + "total_tokens": 268, + "completion_tokens_details": None, + }, + "system_fingerprint": None, + } + + +def mock_chat_streaming_response_chunks() -> List[str]: + return [ + json.dumps( + { + "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3", + "object": "chat.completion.chunk", + "created": 1726469651, + "model": "dbrx-instruct-071224", + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": "Hello"}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": 230, + "completion_tokens": 1, + "total_tokens": 231, + }, + } + ), + json.dumps( + { + "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3", + "object": "chat.completion.chunk", + "created": 1726469651, + "model": "dbrx-instruct-071224", + "choices": [ + { + "index": 0, + "delta": {"content": " world"}, + "finish_reason": None, + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": 230, + "completion_tokens": 1, + "total_tokens": 231, + }, + } + ), + json.dumps( + { + "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3", + "object": "chat.completion.chunk", + "created": 1726469651, + "model": "dbrx-instruct-071224", + "choices": [ + { + "index": 0, + "delta": {"content": "!"}, + "finish_reason": "stop", + "logprobs": None, + } + ], + "usage": { + "prompt_tokens": 230, + "completion_tokens": 1, + "total_tokens": 231, + }, + } + ), + ] + + +def mock_chat_streaming_response_chunks_bytes() -> List[bytes]: + string_chunks = mock_chat_streaming_response_chunks() + bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks] + # Simulate the end of the stream + bytes_chunks.append(b"") + return bytes_chunks + + +def mock_http_handler_chat_streaming_response() -> MagicMock: + mock_stream_chunks = mock_chat_streaming_response_chunks() + + def mock_iter_lines(): + for chunk in mock_stream_chunks: + for line in chunk.splitlines(): + yield line + + mock_response = MagicMock() + mock_response.iter_lines.side_effect = mock_iter_lines + mock_response.status_code = 200 + + return mock_response + + +def mock_http_handler_chat_async_streaming_response() -> MagicMock: + mock_stream_chunks = mock_chat_streaming_response_chunks() + + async def mock_iter_lines(): + for chunk in mock_stream_chunks: + for line in chunk.splitlines(): + yield line + + mock_response = MagicMock() + mock_response.aiter_lines.return_value = mock_iter_lines() + mock_response.status_code = 200 + + return mock_response + + +def mock_databricks_client_chat_streaming_response() -> MagicMock: + mock_stream_chunks = mock_chat_streaming_response_chunks_bytes() + + def mock_read_from_stream(size=-1): + if mock_stream_chunks: + return mock_stream_chunks.pop(0) + return b"" + + mock_response = MagicMock() + streaming_response_mock = MagicMock() + streaming_response_iterator_mock = MagicMock() + # Mock the __getitem__("content") method to return the streaming response + mock_response.__getitem__.return_value = streaming_response_mock + # Mock the streaming response __enter__ method to return the streaming response iterator + streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock + + streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream + streaming_response_iterator_mock.closed = False + + return mock_response + + +def mock_embedding_response() -> Dict[str, Any]: + return { + "object": "list", + "model": 
"bge-large-en-v1.5", + "data": [ + { + "index": 0, + "object": "embedding", + "embedding": [ + 0.06768798828125, + -0.01291656494140625, + -0.0501708984375, + 0.0245361328125, + -0.030364990234375, + ], + } + ], + "usage": { + "prompt_tokens": 8, + "total_tokens": 8, + "completion_tokens": 0, + "completion_tokens_details": None, + }, + } + + +@pytest.mark.parametrize("set_base", [True, False]) +def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base): + if set_base: + monkeypatch.setenv( + "DATABRICKS_API_BASE", + "https://my.workspace.cloud.databricks.com/serving-endpoints", + ) + monkeypatch.delenv( + "DATABRICKS_API_KEY", + ) + err_msg = "A call is being made to LLM Provider but no key is set" + else: + monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey") + monkeypatch.delenv("DATABRICKS_API_BASE") + err_msg = "A call is being made to LLM Provider but no api base is set" + + with pytest.raises(BadRequestError) as exc: + litellm.completion( + model="databricks/dbrx-instruct-071224", + messages={"role": "user", "content": "How are you?"}, + ) + assert err_msg in str(exc) + + with pytest.raises(BadRequestError) as exc: + litellm.embedding( + model="databricks/bge-12312", + input=["Hello", "World"], + ) + assert err_msg in str(exc) + + +def test_completions_with_sync_http_handler(monkeypatch): + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + sync_handler = HTTPHandler() + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = mock_chat_response() + + expected_response_json = { + **mock_chat_response(), + **{ + "model": "databricks/dbrx-instruct-071224", + }, + } + + messages = [{"role": "user", "content": "How are you?"}] + + with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post: + response = litellm.completion( + model="databricks/dbrx-instruct-071224", + messages=messages, + client=sync_handler, + temperature=0.5, + extraparam="testpassingextraparam", + ) + assert response.to_dict() == expected_response_json + + mock_post.assert_called_once_with( + f"{base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "extraparam": "testpassingextraparam", + "stream": False, + } + ), + ) + + +def test_completions_with_async_http_handler(monkeypatch): + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + async_handler = AsyncHTTPHandler() + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = mock_chat_response() + + expected_response_json = { + **mock_chat_response(), + **{ + "model": "databricks/dbrx-instruct-071224", + }, + } + + messages = [{"role": "user", "content": "How are you?"}] + + with patch.object( + AsyncHTTPHandler, "post", return_value=mock_response + ) as mock_post: + response = asyncio.run( + litellm.acompletion( + model="databricks/dbrx-instruct-071224", + messages=messages, + client=async_handler, + temperature=0.5, + extraparam="testpassingextraparam", + ) + ) + assert response.to_dict() == expected_response_json + + 
mock_post.assert_called_once_with( + f"{base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "extraparam": "testpassingextraparam", + "stream": False, + } + ), + ) + + +def test_completions_streaming_with_sync_http_handler(monkeypatch): + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + sync_handler = HTTPHandler() + + messages = [{"role": "user", "content": "How are you?"}] + mock_response = mock_http_handler_chat_streaming_response() + + with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post: + response_stream: CustomStreamWrapper = litellm.completion( + model="databricks/dbrx-instruct-071224", + messages=messages, + client=sync_handler, + temperature=0.5, + extraparam="testpassingextraparam", + stream=True, + ) + response = list(response_stream) + assert "dbrx-instruct-071224" in str(response) + assert "chatcmpl" in str(response) + assert len(response) == 4 + + mock_post.assert_called_once_with( + f"{base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "stream": True, + "extraparam": "testpassingextraparam", + } + ), + stream=True, + ) + + +def test_completions_streaming_with_async_http_handler(monkeypatch): + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + async_handler = AsyncHTTPHandler() + + messages = [{"role": "user", "content": "How are you?"}] + mock_response = mock_http_handler_chat_async_streaming_response() + + with patch.object( + AsyncHTTPHandler, "post", return_value=mock_response + ) as mock_post: + response_stream: CustomStreamWrapper = asyncio.run( + litellm.acompletion( + model="databricks/dbrx-instruct-071224", + messages=messages, + client=async_handler, + temperature=0.5, + extraparam="testpassingextraparam", + stream=True, + ) + ) + + # Use async list gathering for the response + async def gather_responses(): + return [item async for item in response_stream] + + response = asyncio.run(gather_responses()) + assert "dbrx-instruct-071224" in str(response) + assert "chatcmpl" in str(response) + assert len(response) == 4 + + mock_post.assert_called_once_with( + f"{base_url}/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "dbrx-instruct-071224", + "messages": messages, + "temperature": 0.5, + "stream": True, + "extraparam": "testpassingextraparam", + } + ), + stream=True, + ) + + +def test_embeddings_with_sync_http_handler(monkeypatch): + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + sync_handler = HTTPHandler() + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = mock_embedding_response() + + inputs = ["Hello", "World"] + + with patch.object(HTTPHandler, "post", 
return_value=mock_response) as mock_post: + response = litellm.embedding( + model="databricks/bge-large-en-v1.5", + input=inputs, + client=sync_handler, + extraparam="testpassingextraparam", + ) + assert response.to_dict() == mock_embedding_response() + + mock_post.assert_called_once_with( + f"{base_url}/embeddings", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "bge-large-en-v1.5", + "input": inputs, + "extraparam": "testpassingextraparam", + } + ), + ) + + +def test_embeddings_with_async_http_handler(monkeypatch): + base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints" + api_key = "dapimykey" + monkeypatch.setenv("DATABRICKS_API_BASE", base_url) + monkeypatch.setenv("DATABRICKS_API_KEY", api_key) + + async_handler = AsyncHTTPHandler() + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = mock_embedding_response() + + inputs = ["Hello", "World"] + + with patch.object( + AsyncHTTPHandler, "post", return_value=mock_response + ) as mock_post: + response = asyncio.run( + litellm.aembedding( + model="databricks/bge-large-en-v1.5", + input=inputs, + client=async_handler, + extraparam="testpassingextraparam", + ) + ) + assert response.to_dict() == mock_embedding_response() + + mock_post.assert_called_once_with( + f"{base_url}/embeddings", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + data=json.dumps( + { + "model": "bge-large-en-v1.5", + "input": inputs, + "extraparam": "testpassingextraparam", + } + ), + )
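
The new `/global/spend/refresh` route added in `litellm/proxy/spend_tracking/spend_management_endpoints.py` above can be exercised against a running proxy once this patch is applied. The snippet below is a minimal, hypothetical usage sketch, not part of the patch: the proxy address (`http://localhost:4000`) and the `LITELLM_MASTER_KEY` environment variable are illustrative assumptions. On a Postgres deployment where `MonthlyGlobalSpend` is a materialized view, the endpoint runs `REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend"` and returns a success message; otherwise it returns without refreshing anything.

    # Hypothetical usage sketch for the /global/spend/refresh endpoint added in this patch.
    # The proxy URL and master key below are illustrative assumptions, not values from the diff.
    import os

    import httpx

    PROXY_BASE_URL = "http://localhost:4000"       # assumed local LiteLLM proxy
    MASTER_KEY = os.environ["LITELLM_MASTER_KEY"]  # assumed admin / master key

    response = httpx.post(
        f"{PROXY_BASE_URL}/global/spend/refresh",
        headers={"Authorization": f"Bearer {MASTER_KEY}"},
        timeout=60.0,  # refreshing a large MonthlyGlobalSpend view can take a while
    )
    response.raise_for_status()
    # Expected body when the materialized view exists:
    # {"message": "MonthlyGlobalSpend view refreshed", "status": "success"}
    print(response.json())

Because the route is registered with `include_in_schema=False`, it will not appear in the proxy's OpenAPI docs; per the docstring in the hunk above, it is intended for admin / master-key use only.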