diff --git a/db_scripts/create_views.py b/db_scripts/create_views.py
index 510bd67c3f..cbf00605f4 100644
--- a/db_scripts/create_views.py
+++ b/db_scripts/create_views.py
@@ -6,10 +6,14 @@ import asyncio
import os
# Enter your DATABASE_URL here
-os.environ["DATABASE_URL"] = "postgresql://xxxxxxx"
+
from prisma import Prisma
-db = Prisma()
+db = Prisma(
+ http={
+ "timeout": 60000,
+ },
+)
async def check_view_exists():
@@ -47,25 +51,22 @@ async def check_view_exists():
print("LiteLLM_VerificationTokenView Created!") # noqa
- try:
- await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
- print("MonthlyGlobalSpend Exists!") # noqa
- except Exception as e:
- sql_query = """
- CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
+ sql_query = """
+ CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
SELECT
- DATE("startTime") AS date,
- SUM("spend") AS spend
+ DATE_TRUNC('day', "startTime") AS date,
+ SUM("spend") AS spend
FROM
- "LiteLLM_SpendLogs"
+ "LiteLLM_SpendLogs"
WHERE
- "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+ "startTime" >= CURRENT_DATE - INTERVAL '30 days'
GROUP BY
- DATE("startTime");
- """
- await db.execute_raw(query=sql_query)
+ DATE_TRUNC('day', "startTime");
+ """
+ # Execute the queries
+ await db.execute_raw(query=sql_query)
- print("MonthlyGlobalSpend Created!") # noqa
+ print("MonthlyGlobalSpend Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
diff --git a/docs/my-website/docs/observability/opentelemetry_integration.md b/docs/my-website/docs/observability/opentelemetry_integration.md
new file mode 100644
index 0000000000..3a27ffc391
--- /dev/null
+++ b/docs/my-website/docs/observability/opentelemetry_integration.md
@@ -0,0 +1,78 @@
+import Image from '@theme/IdealImage';
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# OpenTelemetry - Tracing LLMs with any observability tool
+
+OpenTelemetry is a CNCF standard for observability. It connects to any OpenTelemetry-compatible observability tool, such as Jaeger, Zipkin, Datadog, New Relic, Traceloop, and others.
+
+
+
+## Getting Started
+
+Install the OpenTelemetry SDK:
+
+```
+pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp
+```
+
+Set the environment variables (different providers may require different variables):
+
+
+
+
+
+
+```shell
+OTEL_EXPORTER="otlp_http"
+OTEL_ENDPOINT="https://api.traceloop.com"
+OTEL_HEADERS="Authorization=Bearer%20"
+```
+
+
+
+
+
+```shell
+OTEL_EXPORTER="otlp_http"
+OTEL_ENDPOINT="http:/0.0.0.0:4317"
+```
+
+
+
+
+
+```shell
+OTEL_EXPORTER="otlp_grpc"
+OTEL_ENDPOINT="http:/0.0.0.0:4317"
+```
+
+
+
+
+
+Then add the following to your code to instantly log your LLM responses **across all providers** with OpenTelemetry:
+
+```python
+litellm.callbacks = ["otel"]
+```
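+
+A minimal end-to-end sketch (model name and prompt are illustrative; the `OTEL_*` variables above are read from the environment):
+
+```python
+import litellm
+
+# Register the OpenTelemetry callback once at startup
+litellm.callbacks = ["otel"]
+
+# Any completion call is now traced across all providers
+response = litellm.completion(
+    model="gpt-3.5-turbo",
+    messages=[{"role": "user", "content": "what llm are you"}],
+)
+```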
+
+## Redacting Messages, Response Content from OpenTelemetry Logging
+
+### Redact Messages and Responses from all OpenTelemetry Logging
+
+Set `litellm.turn_off_message_logging=True`. This prevents messages and responses from being logged to OpenTelemetry; request metadata will still be logged.
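+
+For example, a minimal sketch of turning this on globally (callback name as configured above):
+
+```python
+import litellm
+
+# Drop message and response content from OTEL spans; request metadata is still logged
+litellm.turn_off_message_logging = True
+litellm.callbacks = ["otel"]
+```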
+
+### Redact Messages and Responses from specific OpenTelemetry Logging
+
+In the metadata you pass with a text completion or embedding call, you can set specific keys to mask the messages and responses for that call.
+
+Setting `mask_input` to `True` will mask the input from being logged for this call.
+
+Setting `mask_output` to `True` will mask the output from being logged for this call.
+
+Be aware that if you are continuing an existing trace and you set `update_trace_keys` to include `input` or `output` while also setting the corresponding `mask_input` or `mask_output`, that trace will have its existing input and/or output replaced with a redacted message.
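+
+A hedged sketch of masking a single request by passing these keys in the call `metadata` (key names as described above):
+
+```python
+import litellm
+
+# Mask both the prompt and the completion for this one request only
+response = litellm.completion(
+    model="gpt-3.5-turbo",
+    messages=[{"role": "user", "content": "this prompt will not appear in the trace"}],
+    metadata={"mask_input": True, "mask_output": True},
+)
+```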
+
+## Support
+
+For any question or issue with the integration you can reach out to the OpenLLMetry maintainers on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
diff --git a/docs/my-website/docs/observability/traceloop_integration.md b/docs/my-website/docs/observability/traceloop_integration.md
deleted file mode 100644
index 6f02aa229d..0000000000
--- a/docs/my-website/docs/observability/traceloop_integration.md
+++ /dev/null
@@ -1,36 +0,0 @@
-import Image from '@theme/IdealImage';
-
-# Traceloop (OpenLLMetry) - Tracing LLMs with OpenTelemetry
-
-[Traceloop](https://traceloop.com) is a platform for monitoring and debugging the quality of your LLM outputs.
-It provides you with a way to track the performance of your LLM application; rollout changes with confidence; and debug issues in production.
-It is based on [OpenTelemetry](https://opentelemetry.io), so it can provide full visibility to your LLM requests, as well vector DB usage, and other infra in your stack.
-
-
-
-## Getting Started
-
-Install the Traceloop SDK:
-
-```
-pip install traceloop-sdk
-```
-
-Use just 2 lines of code, to instantly log your LLM responses with OpenTelemetry:
-
-```python
-Traceloop.init(app_name=, disable_batch=True)
-litellm.success_callback = ["traceloop"]
-```
-
-Make sure to properly set a destination to your traces. See [OpenLLMetry docs](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for options.
-
-To get better visualizations on how your code behaves, you may want to annotate specific parts of your LLM chain. See [Traceloop docs on decorators](https://traceloop.com/docs/python-sdk/decorators) for more information.
-
-## Exporting traces to other systems (e.g. Datadog, New Relic, and others)
-
-Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.
-
-## Support
-
-For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
diff --git a/docs/my-website/docs/proxy/logging.md b/docs/my-website/docs/proxy/logging.md
index f7b650f7a5..d20510ac72 100644
--- a/docs/my-website/docs/proxy/logging.md
+++ b/docs/my-website/docs/proxy/logging.md
@@ -600,6 +600,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
+
+
+#### Quick Start - Log to Traceloop
+
+**Step 1:**
+Add the following to your env
+
+```shell
+OTEL_EXPORTER="otlp_http"
+OTEL_ENDPOINT="https://api.traceloop.com"
+OTEL_HEADERS="Authorization=Bearer%20"
+```
+
+**Step 2:** Add `otel` as a callback
+
+```shell
+litellm_settings:
+ callbacks: ["otel"]
+```
+
+**Step 3:** Start the proxy, make a test request
+
+Start proxy
+
+```shell
+litellm --config config.yaml --detailed_debug
+```
+
+Test Request
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+ --header 'Content-Type: application/json' \
+ --data ' {
+ "model": "gpt-3.5-turbo",
+ "messages": [
+ {
+ "role": "user",
+ "content": "what llm are you"
+ }
+ ]
+ }'
+```
+
+
+
#### Quick Start - Log to OTEL Collector
@@ -694,52 +740,6 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
-
-
-#### Quick Start - Log to Traceloop
-
-**Step 1:** Install the `traceloop-sdk` SDK
-
-```shell
-pip install traceloop-sdk==0.21.2
-```
-
-**Step 2:** Add `traceloop` as a success_callback
-
-```shell
-litellm_settings:
- success_callback: ["traceloop"]
-
-environment_variables:
- TRACELOOP_API_KEY: "XXXXX"
-```
-
-**Step 3**: Start the proxy, make a test request
-
-Start proxy
-
-```shell
-litellm --config config.yaml --detailed_debug
-```
-
-Test Request
-
-```shell
-curl --location 'http://0.0.0.0:4000/chat/completions' \
- --header 'Content-Type: application/json' \
- --data ' {
- "model": "gpt-3.5-turbo",
- "messages": [
- {
- "role": "user",
- "content": "what llm are you"
- }
- ]
- }'
-```
-
-
-
** 🎉 Expect to see this trace logged in your OTEL collector**
diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index 7e2a2050b6..6dafb5478b 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -255,6 +255,7 @@ const sidebars = {
type: "category",
label: "Logging & Observability",
items: [
+ "observability/opentelemetry_integration",
"observability/langfuse_integration",
"observability/logfire_integration",
"observability/gcs_bucket_integration",
@@ -271,7 +272,7 @@ const sidebars = {
"observability/openmeter",
"observability/promptlayer_integration",
"observability/wandb_integration",
- "observability/traceloop_integration",
+ "observability/slack_integration",
"observability/athina_integration",
"observability/lunary_integration",
"observability/greenscale_integration",
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 8c11ff1c4e..222466162b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -948,10 +948,10 @@ from .llms.OpenAI.openai import (
AzureAIStudioConfig,
)
from .llms.mistral.mistral_chat_transformation import MistralConfig
-from .llms.OpenAI.o1_transformation import (
+from .llms.OpenAI.chat.o1_transformation import (
OpenAIO1Config,
)
-from .llms.OpenAI.gpt_transformation import (
+from .llms.OpenAI.chat.gpt_transformation import (
OpenAIGPTConfig,
)
from .llms.nvidia_nim import NvidiaNimConfig
diff --git a/litellm/integrations/SlackAlerting/slack_alerting.py b/litellm/integrations/SlackAlerting/slack_alerting.py
index fa7d4bc90d..bc10871499 100644
--- a/litellm/integrations/SlackAlerting/slack_alerting.py
+++ b/litellm/integrations/SlackAlerting/slack_alerting.py
@@ -1616,6 +1616,9 @@ Model Info:
:param time_range: A string specifying the time range, e.g., "1d", "7d", "30d"
"""
+ if self.alerting is None or "spend_reports" not in self.alert_types:
+ return
+
try:
from litellm.proxy.spend_tracking.spend_management_endpoints import (
_get_spend_report_for_time_range,
diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py
index 25512716cd..39c8f2b1e7 100644
--- a/litellm/integrations/custom_guardrail.py
+++ b/litellm/integrations/custom_guardrail.py
@@ -11,7 +11,7 @@ class CustomGuardrail(CustomLogger):
self,
guardrail_name: Optional[str] = None,
event_hook: Optional[GuardrailEventHooks] = None,
- **kwargs
+ **kwargs,
):
self.guardrail_name = guardrail_name
self.event_hook: Optional[GuardrailEventHooks] = event_hook
@@ -28,10 +28,13 @@ class CustomGuardrail(CustomLogger):
requested_guardrails,
)
- if self.guardrail_name not in requested_guardrails:
+ if (
+ self.guardrail_name not in requested_guardrails
+ and event_type.value != "logging_only"
+ ):
return False
- if self.event_hook != event_type:
+ if self.event_hook != event_type.value:
return False
return True
diff --git a/litellm/integrations/traceloop.py b/litellm/integrations/traceloop.py
index e1c419c6f7..d2168caabb 100644
--- a/litellm/integrations/traceloop.py
+++ b/litellm/integrations/traceloop.py
@@ -4,6 +4,11 @@ import litellm
class TraceloopLogger:
+ """
+ WARNING: DEPRECATED
+ Use the OpenTelemetry standard integration instead
+ """
+
def __init__(self):
try:
from traceloop.sdk.tracing.tracing import TracerWrapper
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 1a624c5f89..91e9274e80 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
+try:
+ from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
+ GenericAPILogger,
+ )
+except Exception as e:
+ verbose_logger.debug(f"Exception import enterprise features {str(e)}")
+
_in_memory_loggers: List[Any] = []
### GLOBAL VARIABLES ###
@@ -145,7 +152,41 @@ class ServiceTraceIDCache:
return None
+import hashlib
+
+
+class DynamicLoggingCache:
+ """
+ Prevent memory leaks caused by initializing new logging clients on each request.
+
+ Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
+ """
+
+ def __init__(self) -> None:
+ self.cache = InMemoryCache()
+
+ def get_cache_key(self, args: dict) -> str:
+ args_str = json.dumps(args, sort_keys=True)
+ cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
+ return cache_key
+
+ def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
+ key_name = self.get_cache_key(
+ args={**credentials, "service_name": service_name}
+ )
+ response = self.cache.get_cache(key=key_name)
+ return response
+
+ def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
+ key_name = self.get_cache_key(
+ args={**credentials, "service_name": service_name}
+ )
+ self.cache.set_cache(key=key_name, value=logging_obj)
+ return None
+
+
in_memory_trace_id_cache = ServiceTraceIDCache()
+in_memory_dynamic_logger_cache = DynamicLoggingCache()
class Logging:
@@ -324,10 +365,10 @@ class Logging:
print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
# log raw request to provider (like LangFuse) -- if opted in.
if log_raw_request_response is True:
+ _litellm_params = self.model_call_details.get("litellm_params", {})
+ _metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
- _litellm_params = self.model_call_details.get("litellm_params", {})
- _metadata = _litellm_params.get("metadata", {}) or {}
if (
turn_off_message_logging is not None
and turn_off_message_logging is True
@@ -362,7 +403,7 @@ class Logging:
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
try:
- if callback == "supabase":
+ if callback == "supabase" and supabaseClient is not None:
verbose_logger.debug("reaches supabase for logging!")
model = self.model_call_details["model"]
messages = self.model_call_details["input"]
@@ -396,7 +437,9 @@ class Logging:
messages=self.messages,
kwargs=self.model_call_details,
)
- elif callable(callback): # custom logger functions
+ elif (
+ callable(callback) and customLogger is not None
+ ): # custom logger functions
customLogger.log_input_event(
model=self.model,
messages=self.messages,
@@ -615,7 +658,7 @@ class Logging:
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
- ] = result._hidden_params
+ ] = getattr(result, "_hidden_params", {})
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
@@ -645,6 +688,7 @@ class Logging:
litellm.max_budget
and self.stream is False
and result is not None
+ and isinstance(result, dict)
and "content" in result
):
time_diff = (end_time - start_time).total_seconds()
@@ -652,7 +696,7 @@ class Logging:
litellm._current_cost += litellm.completion_cost(
model=self.model,
prompt="",
- completion=result["content"],
+ completion=getattr(result, "content", ""),
total_time=float_diff,
)
@@ -758,7 +802,7 @@ class Logging:
):
print_verbose("no-log request, skipping logging")
continue
- if callback == "lite_debugger":
+ if callback == "lite_debugger" and liteDebuggerClient is not None:
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
print_verbose(
@@ -774,7 +818,7 @@ class Logging:
call_type=self.call_type,
stream=self.stream,
)
- if callback == "promptlayer":
+ if callback == "promptlayer" and promptLayerLogger is not None:
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
kwargs=self.model_call_details,
@@ -783,7 +827,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
- if callback == "supabase":
+ if callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
kwargs = self.model_call_details
@@ -811,7 +855,7 @@ class Logging:
),
print_verbose=print_verbose,
)
- if callback == "wandb":
+ if callback == "wandb" and weightsBiasesLogger is not None:
print_verbose("reaches wandb for logging!")
weightsBiasesLogger.log_event(
kwargs=self.model_call_details,
@@ -820,8 +864,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
- if callback == "logfire":
- global logfireLogger
+ if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@@ -844,10 +887,10 @@ class Logging:
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
- level=LogfireLevel.INFO.value,
+ level=LogfireLevel.INFO.value, # type: ignore
)
- if callback == "lunary":
+ if callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging!")
model = self.model
kwargs = self.model_call_details
@@ -882,7 +925,7 @@ class Logging:
run_id=self.litellm_call_id,
print_verbose=print_verbose,
)
- if callback == "helicone":
+ if callback == "helicone" and heliconeLogger is not None:
print_verbose("reaches helicone for logging!")
model = self.model
messages = self.model_call_details["input"]
@@ -924,6 +967,7 @@ class Logging:
else:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
+
temp_langfuse_logger = langFuseLogger
if langFuseLogger is None or (
(
@@ -941,27 +985,45 @@ class Logging:
and self.langfuse_host != langFuseLogger.langfuse_host
)
):
- temp_langfuse_logger = LangFuseLogger(
- langfuse_public_key=self.langfuse_public_key,
- langfuse_secret=self.langfuse_secret,
- langfuse_host=self.langfuse_host,
- )
- _response = temp_langfuse_logger.log_event(
- kwargs=kwargs,
- response_obj=result,
- start_time=start_time,
- end_time=end_time,
- user_id=kwargs.get("user", None),
- print_verbose=print_verbose,
- )
- if _response is not None and isinstance(_response, dict):
- _trace_id = _response.get("trace_id", None)
- if _trace_id is not None:
- in_memory_trace_id_cache.set_cache(
- litellm_call_id=self.litellm_call_id,
- service_name="langfuse",
- trace_id=_trace_id,
+ credentials = {
+ "langfuse_public_key": self.langfuse_public_key,
+ "langfuse_secret": self.langfuse_secret,
+ "langfuse_host": self.langfuse_host,
+ }
+ temp_langfuse_logger = (
+ in_memory_dynamic_logger_cache.get_cache(
+ credentials=credentials, service_name="langfuse"
)
+ )
+ if temp_langfuse_logger is None:
+ temp_langfuse_logger = LangFuseLogger(
+ langfuse_public_key=self.langfuse_public_key,
+ langfuse_secret=self.langfuse_secret,
+ langfuse_host=self.langfuse_host,
+ )
+ in_memory_dynamic_logger_cache.set_cache(
+ credentials=credentials,
+ service_name="langfuse",
+ logging_obj=temp_langfuse_logger,
+ )
+
+ if temp_langfuse_logger is not None:
+ _response = temp_langfuse_logger.log_event(
+ kwargs=kwargs,
+ response_obj=result,
+ start_time=start_time,
+ end_time=end_time,
+ user_id=kwargs.get("user", None),
+ print_verbose=print_verbose,
+ )
+ if _response is not None and isinstance(_response, dict):
+ _trace_id = _response.get("trace_id", None)
+ if _trace_id is not None:
+ in_memory_trace_id_cache.set_cache(
+ litellm_call_id=self.litellm_call_id,
+ service_name="langfuse",
+ trace_id=_trace_id,
+ )
if callback == "generic":
global genericAPILogger
verbose_logger.debug("reaches langfuse for success logging!")
@@ -982,7 +1044,7 @@ class Logging:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
if genericAPILogger is None:
- genericAPILogger = GenericAPILogger()
+ genericAPILogger = GenericAPILogger() # type: ignore
genericAPILogger.log_event(
kwargs=kwargs,
response_obj=result,
@@ -1022,7 +1084,7 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
- if callback == "greenscale":
+ if callback == "greenscale" and greenscaleLogger is not None:
kwargs = {}
for k, v in self.model_call_details.items():
if (
@@ -1066,7 +1128,7 @@ class Logging:
result = kwargs["complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
- if callback == "athina":
+ if callback == "athina" and athinaLogger is not None:
deep_copy = {}
for k, v in self.model_call_details.items():
deep_copy[k] = v
@@ -1224,6 +1286,7 @@ class Logging:
"atranscription", False
)
is not True
+ and customLogger is not None
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
@@ -1423,9 +1486,8 @@ class Logging:
await litellm.cache.async_add_cache(result, **kwargs)
else:
litellm.cache.add_cache(result, **kwargs)
- if callback == "openmeter":
- global openMeterLogger
- if self.stream == True:
+ if callback == "openmeter" and openMeterLogger is not None:
+ if self.stream is True:
if (
"async_complete_streaming_response"
in self.model_call_details
@@ -1645,33 +1707,9 @@ class Logging:
)
for callback in callbacks:
try:
- if callback == "lite_debugger":
- print_verbose("reaches lite_debugger for logging!")
- print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
- result = {
- "model": self.model,
- "created": time.time(),
- "error": traceback_exception,
- "usage": {
- "prompt_tokens": prompt_token_calculator(
- self.model, messages=self.messages
- ),
- "completion_tokens": 0,
- },
- }
- liteDebuggerClient.log_event(
- model=self.model,
- messages=self.messages,
- end_user=self.model_call_details.get("user", "default"),
- response_obj=result,
- start_time=start_time,
- end_time=end_time,
- litellm_call_id=self.litellm_call_id,
- print_verbose=print_verbose,
- call_type=self.call_type,
- stream=self.stream,
- )
- if callback == "lunary":
+ if callback == "lite_debugger" and liteDebuggerClient is not None:
+ pass
+ elif callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging error!")
model = self.model
@@ -1685,6 +1723,7 @@ class Logging:
)
lunaryLogger.log_event(
+ kwargs=self.model_call_details,
type=_type,
event="error",
user_id=self.model_call_details.get("user", "default"),
@@ -1704,22 +1743,11 @@ class Logging:
print_verbose(
f"capture exception not initialized: {capture_exception}"
)
- elif callback == "supabase":
+ elif callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
print_verbose(f"supabaseClient: {supabaseClient}")
- result = {
- "model": model,
- "created": time.time(),
- "error": traceback_exception,
- "usage": {
- "prompt_tokens": prompt_token_calculator(
- model, messages=self.messages
- ),
- "completion_tokens": 0,
- },
- }
supabaseClient.log_event(
- model=self.model,
+ model=self.model if hasattr(self, "model") else "",
messages=self.messages,
end_user=self.model_call_details.get("user", "default"),
response_obj=result,
@@ -1728,7 +1756,9 @@ class Logging:
litellm_call_id=self.model_call_details["litellm_call_id"],
print_verbose=print_verbose,
)
- if callable(callback): # custom logger functions
+ if (
+ callable(callback) and customLogger is not None
+ ): # custom logger functions
customLogger.log_event(
kwargs=self.model_call_details,
response_obj=result,
@@ -1809,13 +1839,13 @@ class Logging:
start_time=start_time,
end_time=end_time,
response_obj=None,
- user_id=kwargs.get("user", None),
+ user_id=self.model_call_details.get("user", None),
print_verbose=print_verbose,
status_message=str(exception),
level="ERROR",
kwargs=self.model_call_details,
)
- if callback == "logfire":
+ if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for failure logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@@ -1830,7 +1860,7 @@ class Logging:
response_obj=result,
start_time=start_time,
end_time=end_time,
- level=LogfireLevel.ERROR.value,
+ level=LogfireLevel.ERROR.value, # type: ignore
print_verbose=print_verbose,
)
@@ -1873,7 +1903,9 @@ class Logging:
start_time=start_time,
end_time=end_time,
) # type: ignore
- if callable(callback): # custom logger functions
+ if (
+ callable(callback) and customLogger is not None
+ ): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
@@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
)
sentry_sdk_instance.init(
dsn=os.environ.get("SENTRY_DSN"),
- traces_sample_rate=float(sentry_trace_rate),
+ traces_sample_rate=float(sentry_trace_rate), # type: ignore
)
capture_exception = sentry_sdk_instance.capture_exception
add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(
saved_cache_cost: Optional[float] = None
if cache_hit is True:
- import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
- result=init_response_obj, cache_hit=False
+ result=init_response_obj, cache_hit=False # type: ignore
)
## Get model cost information ##
@@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
model_id=_model_id,
requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"),
- response=(
+ response=( # type: ignore
response_obj if len(response_obj.keys()) > 0 else init_response_obj
),
model_parameters=kwargs.get("optional_params", None),
diff --git a/litellm/llms/OpenAI/gpt_transformation.py b/litellm/llms/OpenAI/chat/gpt_transformation.py
similarity index 100%
rename from litellm/llms/OpenAI/gpt_transformation.py
rename to litellm/llms/OpenAI/chat/gpt_transformation.py
diff --git a/litellm/llms/OpenAI/chat/o1_handler.py b/litellm/llms/OpenAI/chat/o1_handler.py
new file mode 100644
index 0000000000..55dfe37151
--- /dev/null
+++ b/litellm/llms/OpenAI/chat/o1_handler.py
@@ -0,0 +1,95 @@
+"""
+Handler file for calls to OpenAI's o1 family of models
+
+Written separately to handle faking streaming for o1 models.
+"""
+
+import asyncio
+from typing import Any, Callable, List, Optional, Union
+
+from httpx._config import Timeout
+
+from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
+from litellm.llms.OpenAI.openai import OpenAIChatCompletion
+from litellm.types.utils import ModelResponse
+from litellm.utils import CustomStreamWrapper
+
+
+class OpenAIO1ChatCompletion(OpenAIChatCompletion):
+
+ async def mock_async_streaming(
+ self,
+ response: Any,
+ model: Optional[str],
+ logging_obj: Any,
+ ):
+ model_response = await response
+ completion_stream = MockResponseIterator(model_response=model_response)
+ streaming_response = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="openai",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ def completion(
+ self,
+ model_response: ModelResponse,
+ timeout: Union[float, Timeout],
+ optional_params: dict,
+ logging_obj: Any,
+ model: Optional[str] = None,
+ messages: Optional[list] = None,
+ print_verbose: Optional[Callable[..., Any]] = None,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ acompletion: bool = False,
+ litellm_params=None,
+ logger_fn=None,
+ headers: Optional[dict] = None,
+ custom_prompt_dict: dict = {},
+ client=None,
+ organization: Optional[str] = None,
+ custom_llm_provider: Optional[str] = None,
+ drop_params: Optional[bool] = None,
+ ):
+ stream: Optional[bool] = optional_params.pop("stream", False)
+ response = super().completion(
+ model_response,
+ timeout,
+ optional_params,
+ logging_obj,
+ model,
+ messages,
+ print_verbose,
+ api_key,
+ api_base,
+ acompletion,
+ litellm_params,
+ logger_fn,
+ headers,
+ custom_prompt_dict,
+ client,
+ organization,
+ custom_llm_provider,
+ drop_params,
+ )
+
+ if stream is True:
+ if asyncio.iscoroutine(response):
+ return self.mock_async_streaming(
+ response=response, model=model, logging_obj=logging_obj # type: ignore
+ )
+
+ completion_stream = MockResponseIterator(model_response=response)
+ streaming_response = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider="openai",
+ logging_obj=logging_obj,
+ )
+
+ return streaming_response
+ else:
+ return response
diff --git a/litellm/llms/OpenAI/o1_transformation.py b/litellm/llms/OpenAI/chat/o1_transformation.py
similarity index 100%
rename from litellm/llms/OpenAI/o1_transformation.py
rename to litellm/llms/OpenAI/chat/o1_transformation.py
diff --git a/litellm/llms/databricks/chat.py b/litellm/llms/databricks/chat.py
index 739abb91f7..343cdd3ff1 100644
--- a/litellm/llms/databricks/chat.py
+++ b/litellm/llms/databricks/chat.py
@@ -15,6 +15,8 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.databricks.exceptions import DatabricksError
+from litellm.llms.databricks.streaming_utils import ModelResponseIterator
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
@@ -33,17 +35,6 @@ from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
-class DatabricksError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
- self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
class DatabricksConfig:
"""
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
@@ -367,7 +358,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code,
message=e.response.text,
)
- except httpx.TimeoutException as e:
+ except httpx.TimeoutException:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise DatabricksError(status_code=500, message=str(e))
@@ -380,7 +371,7 @@ class DatabricksChatCompletion(BaseLLM):
)
response = ModelResponse(**response_json)
- response.model = custom_llm_provider + "/" + response.model
+ response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
@@ -529,7 +520,7 @@ class DatabricksChatCompletion(BaseLLM):
response_json = response.json()
except httpx.HTTPStatusError as e:
raise DatabricksError(
- status_code=e.response.status_code, message=response.text
+ status_code=e.response.status_code, message=e.response.text
)
except httpx.TimeoutException as e:
raise DatabricksError(
@@ -540,7 +531,7 @@ class DatabricksChatCompletion(BaseLLM):
response = ModelResponse(**response_json)
- response.model = custom_llm_provider + "/" + response.model
+ response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
@@ -657,7 +648,7 @@ class DatabricksChatCompletion(BaseLLM):
except httpx.HTTPStatusError as e:
raise DatabricksError(
status_code=e.response.status_code,
- message=response.text if response else str(e),
+ message=e.response.text,
)
except httpx.TimeoutException as e:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
@@ -673,136 +664,3 @@ class DatabricksChatCompletion(BaseLLM):
)
return litellm.EmbeddingResponse(**response_json)
-
-
-class ModelResponseIterator:
- def __init__(self, streaming_response, sync_stream: bool):
- self.streaming_response = streaming_response
-
- def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
- try:
- processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
-
- text = ""
- tool_use: Optional[ChatCompletionToolCallChunk] = None
- is_finished = False
- finish_reason = ""
- usage: Optional[ChatCompletionUsageBlock] = None
-
- if processed_chunk.choices[0].delta.content is not None: # type: ignore
- text = processed_chunk.choices[0].delta.content # type: ignore
-
- if (
- processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
- and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
- and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
- and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
- is not None
- ):
- tool_use = ChatCompletionToolCallChunk(
- id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
- type="function",
- function=ChatCompletionToolCallFunctionChunk(
- name=processed_chunk.choices[0]
- .delta.tool_calls[0] # type: ignore
- .function.name,
- arguments=processed_chunk.choices[0]
- .delta.tool_calls[0] # type: ignore
- .function.arguments,
- ),
- index=processed_chunk.choices[0].index,
- )
-
- if processed_chunk.choices[0].finish_reason is not None:
- is_finished = True
- finish_reason = processed_chunk.choices[0].finish_reason
-
- if hasattr(processed_chunk, "usage") and isinstance(
- processed_chunk.usage, litellm.Usage
- ):
- usage_chunk: litellm.Usage = processed_chunk.usage
-
- usage = ChatCompletionUsageBlock(
- prompt_tokens=usage_chunk.prompt_tokens,
- completion_tokens=usage_chunk.completion_tokens,
- total_tokens=usage_chunk.total_tokens,
- )
-
- return GenericStreamingChunk(
- text=text,
- tool_use=tool_use,
- is_finished=is_finished,
- finish_reason=finish_reason,
- usage=usage,
- index=0,
- )
- except json.JSONDecodeError:
- raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
-
- # Sync iterator
- def __iter__(self):
- self.response_iterator = self.streaming_response
- return self
-
- def __next__(self):
- try:
- chunk = self.response_iterator.__next__()
- except StopIteration:
- raise StopIteration
- except ValueError as e:
- raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
- try:
- chunk = chunk.replace("data:", "")
- chunk = chunk.strip()
- if len(chunk) > 0:
- json_chunk = json.loads(chunk)
- return self.chunk_parser(chunk=json_chunk)
- else:
- return GenericStreamingChunk(
- text="",
- is_finished=False,
- finish_reason="",
- usage=None,
- index=0,
- tool_use=None,
- )
- except StopIteration:
- raise StopIteration
- except ValueError as e:
- raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
-
- # Async iterator
- def __aiter__(self):
- self.async_response_iterator = self.streaming_response.__aiter__()
- return self
-
- async def __anext__(self):
- try:
- chunk = await self.async_response_iterator.__anext__()
- except StopAsyncIteration:
- raise StopAsyncIteration
- except ValueError as e:
- raise RuntimeError(f"Error receiving chunk from stream: {e}")
-
- try:
- chunk = chunk.replace("data:", "")
- chunk = chunk.strip()
- if chunk == "[DONE]":
- raise StopAsyncIteration
- if len(chunk) > 0:
- json_chunk = json.loads(chunk)
- return self.chunk_parser(chunk=json_chunk)
- else:
- return GenericStreamingChunk(
- text="",
- is_finished=False,
- finish_reason="",
- usage=None,
- index=0,
- tool_use=None,
- )
- except StopAsyncIteration:
- raise StopAsyncIteration
- except ValueError as e:
- raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
diff --git a/litellm/llms/databricks/exceptions.py b/litellm/llms/databricks/exceptions.py
new file mode 100644
index 0000000000..8bb3d435d0
--- /dev/null
+++ b/litellm/llms/databricks/exceptions.py
@@ -0,0 +1,12 @@
+import httpx
+
+
+class DatabricksError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
diff --git a/litellm/llms/databricks/streaming_utils.py b/litellm/llms/databricks/streaming_utils.py
new file mode 100644
index 0000000000..1b342f3c96
--- /dev/null
+++ b/litellm/llms/databricks/streaming_utils.py
@@ -0,0 +1,145 @@
+import json
+from typing import Optional
+
+import litellm
+from litellm.types.llms.openai import (
+ ChatCompletionDeltaChunk,
+ ChatCompletionResponseMessage,
+ ChatCompletionToolCallChunk,
+ ChatCompletionToolCallFunctionChunk,
+ ChatCompletionUsageBlock,
+)
+from litellm.types.utils import GenericStreamingChunk
+
+
+class ModelResponseIterator:
+ def __init__(self, streaming_response, sync_stream: bool):
+ self.streaming_response = streaming_response
+
+ def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+ try:
+ processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
+
+ text = ""
+ tool_use: Optional[ChatCompletionToolCallChunk] = None
+ is_finished = False
+ finish_reason = ""
+ usage: Optional[ChatCompletionUsageBlock] = None
+
+ if processed_chunk.choices[0].delta.content is not None: # type: ignore
+ text = processed_chunk.choices[0].delta.content # type: ignore
+
+ if (
+ processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
+ and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
+ and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
+ and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
+ is not None
+ ):
+ tool_use = ChatCompletionToolCallChunk(
+ id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
+ type="function",
+ function=ChatCompletionToolCallFunctionChunk(
+ name=processed_chunk.choices[0]
+ .delta.tool_calls[0] # type: ignore
+ .function.name,
+ arguments=processed_chunk.choices[0]
+ .delta.tool_calls[0] # type: ignore
+ .function.arguments,
+ ),
+ index=processed_chunk.choices[0].index,
+ )
+
+ if processed_chunk.choices[0].finish_reason is not None:
+ is_finished = True
+ finish_reason = processed_chunk.choices[0].finish_reason
+
+ if hasattr(processed_chunk, "usage") and isinstance(
+ processed_chunk.usage, litellm.Usage
+ ):
+ usage_chunk: litellm.Usage = processed_chunk.usage
+
+ usage = ChatCompletionUsageBlock(
+ prompt_tokens=usage_chunk.prompt_tokens,
+ completion_tokens=usage_chunk.completion_tokens,
+ total_tokens=usage_chunk.total_tokens,
+ )
+
+ return GenericStreamingChunk(
+ text=text,
+ tool_use=tool_use,
+ is_finished=is_finished,
+ finish_reason=finish_reason,
+ usage=usage,
+ index=0,
+ )
+ except json.JSONDecodeError:
+ raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+ # Sync iterator
+ def __iter__(self):
+ self.response_iterator = self.streaming_response
+ return self
+
+ def __next__(self):
+ try:
+ chunk = self.response_iterator.__next__()
+ except StopIteration:
+ raise StopIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+ try:
+ chunk = chunk.replace("data:", "")
+ chunk = chunk.strip()
+ if len(chunk) > 0:
+ json_chunk = json.loads(chunk)
+ return self.chunk_parser(chunk=json_chunk)
+ else:
+ return GenericStreamingChunk(
+ text="",
+ is_finished=False,
+ finish_reason="",
+ usage=None,
+ index=0,
+ tool_use=None,
+ )
+ except StopIteration:
+ raise StopIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+ # Async iterator
+ def __aiter__(self):
+ self.async_response_iterator = self.streaming_response.__aiter__()
+ return self
+
+ async def __anext__(self):
+ try:
+ chunk = await self.async_response_iterator.__anext__()
+ except StopAsyncIteration:
+ raise StopAsyncIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+ try:
+ chunk = chunk.replace("data:", "")
+ chunk = chunk.strip()
+ if chunk == "[DONE]":
+ raise StopAsyncIteration
+ if len(chunk) > 0:
+ json_chunk = json.loads(chunk)
+ return self.chunk_parser(chunk=json_chunk)
+ else:
+ return GenericStreamingChunk(
+ text="",
+ is_finished=False,
+ finish_reason="",
+ usage=None,
+ index=0,
+ tool_use=None,
+ )
+ except StopAsyncIteration:
+ raise StopAsyncIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
diff --git a/litellm/main.py b/litellm/main.py
index 82d19a976e..ee2ea36263 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -95,6 +95,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
+from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.factory import (
@@ -161,6 +162,7 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
+openai_o1_chat_completions = OpenAIO1ChatCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription()
databricks_chat_completions = DatabricksChatCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
@@ -1366,25 +1368,46 @@ def completion(
## COMPLETION CALL
try:
- response = openai_chat_completions.completion(
- model=model,
- messages=messages,
- headers=headers,
- model_response=model_response,
- print_verbose=print_verbose,
- api_key=api_key,
- api_base=api_base,
- acompletion=acompletion,
- logging_obj=logging,
- optional_params=optional_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- timeout=timeout, # type: ignore
- custom_prompt_dict=custom_prompt_dict,
- client=client, # pass AsyncOpenAI, OpenAI client
- organization=organization,
- custom_llm_provider=custom_llm_provider,
- )
+ if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+ response = openai_o1_chat_completions.completion(
+ model=model,
+ messages=messages,
+ headers=headers,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ api_key=api_key,
+ api_base=api_base,
+ acompletion=acompletion,
+ logging_obj=logging,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout, # type: ignore
+ custom_prompt_dict=custom_prompt_dict,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ organization=organization,
+ custom_llm_provider=custom_llm_provider,
+ )
+ else:
+ response = openai_chat_completions.completion(
+ model=model,
+ messages=messages,
+ headers=headers,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ api_key=api_key,
+ api_base=api_base,
+ acompletion=acompletion,
+ logging_obj=logging,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout, # type: ignore
+ custom_prompt_dict=custom_prompt_dict,
+ client=client, # pass AsyncOpenAI, OpenAI client
+ organization=organization,
+ custom_llm_provider=custom_llm_provider,
+ )
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 4533bf9114..3502a786b6 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,7 +3,6 @@ model_list:
litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0
api_base: https://exampleopenaiendpoint-production.up.railway.app
- # aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
- model_name: gemini-vision
@@ -17,4 +16,34 @@ model_list:
litellm_params:
model: gpt-3.5-turbo
api_base: https://exampleopenaiendpoint-production.up.railway.app
-
\ No newline at end of file
+ - model_name: o1-preview
+ litellm_params:
+ model: o1-preview
+
+litellm_settings:
+ drop_params: True
+ json_logs: True
+ store_audit_logs: True
+ log_raw_request_response: True
+ return_response_headers: True
+ num_retries: 5
+ request_timeout: 200
+ callbacks: ["custom_callbacks.proxy_handler_instance"]
+
+guardrails:
+ - guardrail_name: "presidio-pre-guard"
+ litellm_params:
+ guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio"
+ mode: "logging_only"
+ mock_redacted_text: {
+ "text": "My name is , who are you? Say my name in your response",
+ "items": [
+ {
+ "start": 11,
+ "end": 19,
+ "entity_type": "PERSON",
+ "text": "",
+ "operator": "replace",
+ }
+ ],
+ }
\ No newline at end of file
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index b2e63c2564..36fe8ce733 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -446,7 +446,6 @@ async def user_api_key_auth(
and request.headers.get(key=header_key) is not None # type: ignore
):
api_key = request.headers.get(key=header_key) # type: ignore
-
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(
diff --git a/litellm/proxy/custom_callbacks.py b/litellm/proxy/custom_callbacks.py
index 1516bfd240..445fea1d23 100644
--- a/litellm/proxy/custom_callbacks.py
+++ b/litellm/proxy/custom_callbacks.py
@@ -2,9 +2,19 @@ from litellm.integrations.custom_logger import CustomLogger
class MyCustomHandler(CustomLogger):
- async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
- # print("Call failed")
- pass
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ # input_tokens = response_obj.get("usage", {}).get("prompt_tokens", 0)
+ # output_tokens = response_obj.get("usage", {}).get("completion_tokens", 0)
+ input_tokens = (
+ response_obj.usage.prompt_tokens
+ if hasattr(response_obj.usage, "prompt_tokens")
+ else 0
+ )
+ output_tokens = (
+ response_obj.usage.completion_tokens
+ if hasattr(response_obj.usage, "completion_tokens")
+ else 0
+ )
proxy_handler_instance = MyCustomHandler()
diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
index a18d8db0e9..c28cf5ec9a 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
@@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
response = await self.async_handler.post(
url=prepared_request.url,
json=request_data, # type: ignore
- headers=dict(prepared_request.headers),
+ headers=prepared_request.headers, # type: ignore
)
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
if response.status_code == 200:
@@ -254,7 +254,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
- from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
diff --git a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py
index d84966fd7e..d15a4a7d54 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/lakera_ai.py
@@ -189,6 +189,7 @@ class lakeraAI_Moderation(CustomGuardrail):
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
# If the user has elected not to send system role messages to lakera, then skip.
+
if system_message is not None:
if not litellm.add_function_to_prompt:
content = system_message.get("content")
diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
index eb09e5203a..e1b0d7cad0 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py
@@ -19,6 +19,7 @@ from fastapi import HTTPException
from pydantic import BaseModel
import litellm # noqa: E401
+from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -58,7 +59,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
-
self.mock_redacted_text = mock_redacted_text
self.output_parse_pii = output_parse_pii or False
if mock_testing is True: # for testing purposes only
@@ -92,8 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base: Optional[str] = None,
):
self.presidio_analyzer_api_base: Optional[str] = (
- presidio_analyzer_api_base
- or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None)
+ presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore
)
self.presidio_anonymizer_api_base: Optional[
str
@@ -198,12 +197,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
- verbose_proxy_logger.error(
- "litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
- str(e)
- )
- )
- verbose_proxy_logger.debug(traceback.format_exc())
raise e
async def async_pre_call_hook(
@@ -254,9 +247,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
return data
except Exception as e:
- verbose_proxy_logger.info(
- f"An error occurred -",
- )
raise e
async def async_logging_hook(
@@ -300,9 +290,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
kwargs["messages"] = messages
- return kwargs, responses
+ return kwargs, result
- async def async_post_call_success_hook(
+ async def async_post_call_success_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@@ -314,7 +304,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
- if self.output_parse_pii == False:
+ if self.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(
diff --git a/litellm/proxy/guardrails/init_guardrails.py b/litellm/proxy/guardrails/init_guardrails.py
index cff9fca056..c463009904 100644
--- a/litellm/proxy/guardrails/init_guardrails.py
+++ b/litellm/proxy/guardrails/init_guardrails.py
@@ -1,10 +1,11 @@
import importlib
import traceback
-from typing import Dict, List, Literal
+from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, RootModel
import litellm
+from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
@@ -16,7 +17,6 @@ from litellm.types.guardrails import (
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
- guardrailConfig,
)
all_guardrails: List[GuardrailItem] = []
@@ -98,18 +98,13 @@ def init_guardrails_v2(
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
- litellm_params = LitellmParams(
- guardrail=litellm_params_data["guardrail"],
- mode=litellm_params_data["mode"],
- api_key=litellm_params_data.get("api_key"),
- api_base=litellm_params_data.get("api_base"),
- guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
- guardrailVersion=litellm_params_data.get("guardrailVersion"),
- output_parse_pii=litellm_params_data.get("output_parse_pii"),
- presidio_ad_hoc_recognizers=litellm_params_data.get(
- "presidio_ad_hoc_recognizers"
- ),
- )
+
+ _litellm_params_kwargs = {
+ k: litellm_params_data[k] if k in litellm_params_data else None
+ for k in LitellmParams.__annotations__.keys()
+ }
+
+ litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
if (
"category_thresholds" in litellm_params_data
@@ -122,15 +117,11 @@ def init_guardrails_v2(
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
- litellm_params["api_key"] = litellm.get_secret(
- litellm_params["api_key"]
- )
+ litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
- litellm_params["api_base"] = litellm.get_secret(
- litellm_params["api_base"]
- )
+ litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == "aporia":
@@ -182,6 +173,7 @@ def init_guardrails_v2(
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
+ mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
)
if litellm_params["output_parse_pii"] is True:
diff --git a/litellm/proxy/hooks/prompt_injection_detection.py b/litellm/proxy/hooks/prompt_injection_detection.py
index 9c1f1eb959..73e83ec4a1 100644
--- a/litellm/proxy/hooks/prompt_injection_detection.py
+++ b/litellm/proxy/hooks/prompt_injection_detection.py
@@ -167,11 +167,11 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on
- if self.prompt_injection_params.heuristics_check == True:
+ if self.prompt_injection_params.heuristics_check is True:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
- if is_prompt_attack == True:
+ if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
@@ -179,14 +179,14 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
},
)
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
- if self.prompt_injection_params.vector_db_check == True:
+ if self.prompt_injection_params.vector_db_check is True:
pass
else:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
- if is_prompt_attack == True:
+ if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
@@ -201,19 +201,18 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if (
e.status_code == 400
and isinstance(e.detail, dict)
- and "error" in e.detail
+ and "error" in e.detail # type: ignore
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail.get("error")
raise e
except Exception as e:
- verbose_proxy_logger.error(
+ verbose_proxy_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
- verbose_proxy_logger.debug(traceback.format_exc())
async def async_moderation_hook( # type: ignore
self,
diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py
index 859c8aeb84..3801d74659 100644
--- a/litellm/proxy/management_endpoints/internal_user_endpoints.py
+++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py
@@ -195,7 +195,8 @@ async def user_auth(request: Request):
- os.environ["SMTP_PASSWORD"]
- os.environ["SMTP_SENDER_EMAIL"]
"""
- from litellm.proxy.proxy_server import prisma_client, send_email
+ from litellm.proxy.proxy_server import prisma_client
+ from litellm.proxy.utils import send_email
data = await request.json() # type: ignore
user_email = data["user_email"]
@@ -212,7 +213,7 @@ async def user_auth(request: Request):
)
### if so - generate a 24 hr key with that user id
if response is not None:
- user_id = response.user_id
+ user_id = response.user_id # type: ignore
response = await generate_key_helper_fn(
request_type="key",
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id}, # type: ignore
@@ -345,6 +346,7 @@ async def user_info(
for team in teams_1:
team_id_list.append(team.team_id)
+ teams_2: Optional[Any] = None
if user_info is not None:
# *NEW* get all teams in user 'teams' field
teams_2 = await prisma_client.get_data(
@@ -375,7 +377,7 @@ async def user_info(
),
user_api_key_dict=user_api_key_dict,
)
- else:
+ elif caller_user_info is not None:
teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams,
table_name="team",
@@ -395,7 +397,7 @@ async def user_info(
query_type="find_all",
)
- if user_info is None:
+ if user_info is None and keys is not None:
## make sure we still return a total spend ##
spend = 0
for k in keys:
@@ -404,32 +406,35 @@ async def user_info(
## REMOVE HASHED TOKEN INFO before returning ##
returned_keys = []
- for key in keys:
- if (
- key.token == litellm_master_key_hash
- and general_settings.get("disable_master_key_return", False)
- == True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
- ):
- continue
+ if keys is None:
+ pass
+ else:
+ for key in keys:
+ if (
+ key.token == litellm_master_key_hash
+ and general_settings.get("disable_master_key_return", False)
+ == True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
+ ):
+ continue
- try:
- key = key.model_dump() # noqa
- except:
- # if using pydantic v1
- key = key.dict()
- if (
- "team_id" in key
- and key["team_id"] is not None
- and key["team_id"] != "litellm-dashboard"
- ):
- team_info = await prisma_client.get_data(
- team_id=key["team_id"], table_name="team"
- )
- team_alias = getattr(team_info, "team_alias", None)
- key["team_alias"] = team_alias
- else:
- key["team_alias"] = "None"
- returned_keys.append(key)
+ try:
+ key = key.model_dump() # noqa
+ except:
+ # if using pydantic v1
+ key = key.dict()
+ if (
+ "team_id" in key
+ and key["team_id"] is not None
+ and key["team_id"] != "litellm-dashboard"
+ ):
+ team_info = await prisma_client.get_data(
+ team_id=key["team_id"], table_name="team"
+ )
+ team_alias = getattr(team_info, "team_alias", None)
+ key["team_alias"] = team_alias
+ else:
+ key["team_alias"] = "None"
+ returned_keys.append(key)
response_data = {
"user_id": user_id,
@@ -539,6 +544,7 @@ async def user_update(
## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
+ response: Optional[Any] = None
if data.user_id is not None and len(data.user_id) > 0:
non_default_values["user_id"] = data.user_id # type: ignore
verbose_proxy_logger.debug("In update user, user_id condition block.")
@@ -573,7 +579,7 @@ async def user_update(
data=non_default_values,
table_name="user",
)
- return response
+ return response # type: ignore
# update based on remaining passed in values
except Exception as e:
verbose_proxy_logger.error(
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 581db4e32e..de15300712 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -226,7 +226,6 @@ from litellm.proxy.utils import (
hash_token,
log_to_opentelemetry,
reset_budget,
- send_email,
update_spend,
)
from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py
index 86f9593627..295546a3d4 100644
--- a/litellm/proxy/spend_tracking/spend_management_endpoints.py
+++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py
@@ -1434,7 +1434,7 @@ async def _get_spend_report_for_time_range(
except Exception as e:
verbose_proxy_logger.error(
"Exception in _get_daily_spend_reports {}".format(str(e))
- ) # noqa
+ )
@router.post(
@@ -1703,26 +1703,26 @@ async def view_spend_logs(
result: dict = {}
for record in response:
dt_object = datetime.strptime(
- str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ"
+ str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
) # type: ignore
date = dt_object.date()
if date not in result:
result[date] = {"users": {}, "models": {}}
- api_key = record["api_key"]
- user_id = record["user"]
- model = record["model"]
- result[date]["spend"] = (
- result[date].get("spend", 0) + record["_sum"]["spend"]
- )
- result[date][api_key] = (
- result[date].get(api_key, 0) + record["_sum"]["spend"]
- )
- result[date]["users"][user_id] = (
- result[date]["users"].get(user_id, 0) + record["_sum"]["spend"]
- )
- result[date]["models"][model] = (
- result[date]["models"].get(model, 0) + record["_sum"]["spend"]
- )
+ api_key = record["api_key"] # type: ignore
+ user_id = record["user"] # type: ignore
+ model = record["model"] # type: ignore
+ result[date]["spend"] = result[date].get("spend", 0) + record.get(
+ "_sum", {}
+ ).get("spend", 0)
+ result[date][api_key] = result[date].get(api_key, 0) + record.get(
+ "_sum", {}
+ ).get("spend", 0)
+ result[date]["users"][user_id] = result[date]["users"].get(
+ user_id, 0
+ ) + record.get("_sum", {}).get("spend", 0)
+ result[date]["models"][model] = result[date]["models"].get(
+ model, 0
+ ) + record.get("_sum", {}).get("spend", 0)
return_list = []
final_date = None
for k, v in sorted(result.items()):
@@ -1784,7 +1784,7 @@ async def view_spend_logs(
table_name="spend", query_type="find_all"
)
- return spend_log
+ return spend_logs
return None
@@ -1843,6 +1843,88 @@ async def global_spend_reset():
}
+@router.post(
+ "/global/spend/refresh",
+ tags=["Budget & Spend Tracking"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def global_spend_refresh():
+ """
+ ADMIN ONLY / MASTER KEY only endpoint
+
+ Globally refresh the "MonthlyGlobalSpend" materialized view
+ """
+ from litellm.proxy.proxy_server import prisma_client
+
+ if prisma_client is None:
+ raise ProxyException(
+ message="Prisma Client is not initialized",
+ type="internal_error",
+ param="None",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+ ## REFRESH GLOBAL SPEND VIEW ##
+ async def is_materialized_global_spend_view() -> bool:
+ """
+ Return True if the "MonthlyGlobalSpend" materialized view exists.
+
+ Else return False.
+ """
+ sql_query = """
+ SELECT relname, relkind
+ FROM pg_class
+ WHERE relname = 'MonthlyGlobalSpend';
+ """
+ try:
+ resp = await prisma_client.db.query_raw(sql_query)
+
+ assert resp[0]["relkind"] == "m"
+ return True
+ except Exception:
+ return False
+
+ view_exists = await is_materialized_global_spend_view()
+
+ if view_exists:
+ # refresh materialized view
+ sql_query = """
+ REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";
+ """
+ try:
+ from litellm.proxy._types import CommonProxyErrors
+ from litellm.proxy.proxy_server import proxy_logging_obj
+ from litellm.proxy.utils import PrismaClient
+
+ db_url = os.getenv("DATABASE_URL")
+ if db_url is None:
+ raise Exception(CommonProxyErrors.db_not_connected_error.value)
+ new_client = PrismaClient(
+ database_url=db_url,
+ proxy_logging_obj=proxy_logging_obj,
+ http_client={
+ "timeout": 6000,
+ },
+ )
+ await new_client.db.connect()
+ await new_client.db.query_raw(sql_query)
+ verbose_proxy_logger.info("MonthlyGlobalSpend view refreshed")
+ return {
+ "message": "MonthlyGlobalSpend view refreshed",
+ "status": "success",
+ }
+
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "Failed to refresh materialized view - {}".format(str(e))
+ )
+ return {
+ "message": "Failed to refresh materialized view",
+ "status": "failure",
+ }
+
+
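
A hedged usage sketch for the new `/global/spend/refresh` endpoint; the base URL and key below are placeholders, and the endpoint only runs `REFRESH MATERIALIZED VIEW` when `pg_class` reports `relkind = 'm'` for `MonthlyGlobalSpend`:

```python
import httpx

# Placeholder values - adjust to your deployment.
PROXY_BASE_URL = "http://localhost:4000"
MASTER_KEY = "sk-1234"

# Admin-only call: refresh the MonthlyGlobalSpend materialized view.
resp = httpx.post(
    f"{PROXY_BASE_URL}/global/spend/refresh",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    timeout=120.0,  # refreshes can be slow on large spend tables
)
print(resp.json())  # {"message": "...", "status": "success"} on success
```
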
async def global_spend_for_internal_user(
api_key: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8a9a6d707d..f0ee7ea9f7 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -92,6 +92,7 @@ def safe_deep_copy(data):
if litellm.safe_memory_mode is True:
return data
+ litellm_parent_otel_span: Optional[Any] = None
# Step 1: Remove the litellm_parent_otel_span
litellm_parent_otel_span = None
if isinstance(data, dict):
@@ -101,7 +102,7 @@ def safe_deep_copy(data):
new_data = copy.deepcopy(data)
# Step 2: re-add the litellm_parent_otel_span after doing a deep copy
- if isinstance(data, dict):
+ if isinstance(data, dict) and litellm_parent_otel_span is not None:
if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data
@@ -468,7 +469,7 @@ class ProxyLogging:
# V1 implementation - backwards compatibility
if callback.event_hook is None:
- if callback.moderation_check == "pre_call":
+ if callback.moderation_check == "pre_call": # type: ignore
return
else:
# Main - V2 Guardrails implementation
@@ -881,7 +882,12 @@ class PrismaClient:
org_list_transactons: dict = {}
spend_log_transactions: List = []
- def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
+ def __init__(
+ self,
+ database_url: str,
+ proxy_logging_obj: ProxyLogging,
+ http_client: Optional[Any] = None,
+ ):
verbose_proxy_logger.debug(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
@@ -912,7 +918,10 @@ class PrismaClient:
# Now you can import the Prisma Client
from prisma import Prisma # type: ignore
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
- self.db = Prisma() # Client to connect to Prisma db
+ if http_client is not None:
+ self.db = Prisma(http=http_client)
+ else:
+ self.db = Prisma() # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
def hash_token(self, token: str):
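
The constructor change lets callers forward HTTP options straight to the underlying Prisma client. A minimal sketch, mirroring how `global_spend_refresh` builds a longer-timeout client (it assumes `DATABASE_URL` is set and a `proxy_logging_obj` is available):

```python
import os

from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient

db_url = os.getenv("DATABASE_URL")
if db_url is None:
    raise RuntimeError("DATABASE_URL must be set")

# When http_client is provided it is passed through as Prisma(http=...);
# when omitted, Prisma() is constructed with its defaults.
client = PrismaClient(
    database_url=db_url,
    proxy_logging_obj=proxy_logging_obj,
    http_client={"timeout": 6000},
)
```
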
@@ -987,7 +996,7 @@ class PrismaClient:
return
else:
## check if required view exists ##
- if required_view not in ret[0]["view_names"]:
+ if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db
await self.db.execute_raw(
"""
@@ -1009,7 +1018,9 @@ class PrismaClient:
else:
# don't block execution if these views are missing
# Convert lists to sets for efficient difference calculation
- ret_view_names_set = set(ret[0]["view_names"])
+ ret_view_names_set = (
+ set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
+ )
expected_views_set = set(expected_views)
# Find missing views
missing_views = expected_views_set - ret_view_names_set
@@ -1291,13 +1302,13 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
+ hashed_token: Optional[str] = None
try:
response: Any = None
if (token is not None and table_name is None) or (
table_name is not None and table_name == "key"
):
# check if plain text or hash
- hashed_token = None
if token is not None:
if isinstance(token, str):
hashed_token = token
@@ -1306,7 +1317,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
- if query_type == "find_unique":
+ if query_type == "find_unique" and hashed_token is not None:
if token is None:
raise HTTPException(
status_code=400,
@@ -1706,7 +1717,7 @@ class PrismaClient:
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
- where={"param_name": k},
+ where={"param_name": k}, # type: ignore
data={
"create": {"param_name": k, "param_value": updated_data}, # type: ignore
"update": {"param_value": updated_data},
@@ -2302,7 +2313,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
return instance
except ImportError as e:
# Re-raise the exception with a user-friendly message
- raise ImportError(f"Could not import {instance_name} from {module_name}") from e
+ if instance_name and module_name:
+ raise ImportError(
+ f"Could not import {instance_name} from {module_name}"
+ ) from e
+ else:
+ raise e
except Exception as e:
raise e
@@ -2377,12 +2393,12 @@ async def send_email(receiver_email, subject, html):
try:
# Establish a secure connection with the SMTP server
- with smtplib.SMTP(smtp_host, smtp_port) as server:
+ with smtplib.SMTP(smtp_host, smtp_port) as server: # type: ignore
if os.getenv("SMTP_TLS", "True") != "False":
server.starttls()
# Login to your email account
- server.login(smtp_username, smtp_password)
+ server.login(smtp_username, smtp_password) # type: ignore
# Send the email
server.send_message(email_message)
@@ -2945,7 +2961,7 @@ async def update_spend(
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
- await asyncio.sleep(2**i) # Exponential backoff
+ await asyncio.sleep(2**i) # type: ignore
except Exception as e:
import traceback
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 1debf81717..c6d5ffc342 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -2110,7 +2110,6 @@ async def test_hf_completion_tgi_stream():
def test_openai_chat_completion_call():
litellm.set_verbose = False
litellm.return_response_headers = True
- print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
assert isinstance(
response._hidden_params["additional_headers"][
@@ -2318,6 +2317,57 @@ def test_together_ai_completion_call_mistral():
pass
+# test on openai o1 completion call - streaming
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_openai_o1_completion_call_streaming(sync_mode):
+ try:
+ litellm.set_verbose = False
+ if sync_mode:
+ response = completion(
+ model="o1-preview",
+ messages=messages,
+ stream=True,
+ )
+ complete_response = ""
+ print(f"returned response object: {response}")
+ has_finish_reason = False
+ for idx, chunk in enumerate(response):
+ chunk, finished = streaming_format_tests(idx, chunk)
+ has_finish_reason = finished
+ if finished:
+ break
+ complete_response += chunk
+ if has_finish_reason is False:
+ raise Exception("Finish reason not set for last chunk")
+ if complete_response == "":
+ raise Exception("Empty response received")
+ else:
+ response = await acompletion(
+ model="o1-preview",
+ messages=messages,
+ stream=True,
+ )
+ complete_response = ""
+ print(f"returned response object: {response}")
+ has_finish_reason = False
+ idx = 0
+ async for chunk in response:
+ chunk, finished = streaming_format_tests(idx, chunk)
+ has_finish_reason = finished
+ if finished:
+ break
+ complete_response += chunk
+ idx += 1
+ if has_finish_reason is False:
+ raise Exception("Finish reason not set for last chunk")
+ if complete_response == "":
+ raise Exception("Empty response received")
+ print(f"complete response: {complete_response}")
+ except Exception:
+ pytest.fail(f"error occurred: {traceback.format_exc()}")
+
+
def test_together_ai_completion_call_starcoder_bad_key():
try:
api_key = "bad-key"
diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py
index cb70de5052..57be6b0c42 100644
--- a/litellm/types/guardrails.py
+++ b/litellm/types/guardrails.py
@@ -71,7 +71,7 @@ class LakeraCategoryThresholds(TypedDict, total=False):
jailbreak: float
-class LitellmParams(TypedDict, total=False):
+class LitellmParams(TypedDict):
guardrail: str
mode: str
api_key: str
@@ -87,6 +87,7 @@ class LitellmParams(TypedDict, total=False):
# Presidio params
output_parse_pii: Optional[bool]
presidio_ad_hoc_recognizers: Optional[str]
+ mock_redacted_text: Optional[dict]
class Guardrail(TypedDict):
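
With `mock_redacted_text` added to `LitellmParams`, the presidio hook (first hunk of this patch) can read a canned redaction payload via `litellm_params.get("mock_redacted_text") or None`. A hedged sketch of that read path; the payload shape shown is an assumption, only the key name comes from this patch:

```python
from typing import Optional

# At runtime the guardrail's litellm_params arrive as a plain dict.
litellm_params: dict = {
    "guardrail": "presidio",
    "mode": "pre_call",
    # Hypothetical canned Presidio output used for testing without a live
    # analyzer/anonymizer - the exact shape is not defined by this patch.
    "mock_redacted_text": {"text": "My name is <PERSON>"},
}

mock_redacted_text: Optional[dict] = litellm_params.get("mock_redacted_text") or None
```
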
diff --git a/litellm/utils.py b/litellm/utils.py
index 280691c8a4..58ef4d49f1 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -120,11 +120,26 @@ with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json")
# Convert to str (if necessary)
claude_json_str = json.dumps(json_data)
import importlib.metadata
+from concurrent.futures import ThreadPoolExecutor
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ get_args,
+)
from openai import OpenAIError as OriginalError
from ._logging import verbose_logger
-from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
+from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .exceptions import (
APIConnectionError,
APIError,
@@ -150,31 +165,6 @@ from .types.llms.openai import (
)
from .types.router import LiteLLM_Params
-try:
- from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
- GenericAPILogger,
- )
-except Exception as e:
- verbose_logger.debug(f"Exception import enterprise features {str(e)}")
-
-from concurrent.futures import ThreadPoolExecutor
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- List,
- Literal,
- Optional,
- Tuple,
- Type,
- Union,
- cast,
- get_args,
-)
-
-from .caching import Cache
-
####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100
diff --git a/tests/llm_translation/Readme.md b/tests/llm_translation/Readme.md
new file mode 100644
index 0000000000..174c81b4e6
--- /dev/null
+++ b/tests/llm_translation/Readme.md
@@ -0,0 +1 @@
+More tests under `litellm/litellm/tests/*`.
\ No newline at end of file
diff --git a/tests/llm_translation/test_databricks.py b/tests/llm_translation/test_databricks.py
new file mode 100644
index 0000000000..067b188ed2
--- /dev/null
+++ b/tests/llm_translation/test_databricks.py
@@ -0,0 +1,502 @@
+import asyncio
+import httpx
+import json
+import pytest
+import sys
+from typing import Any, Dict, List
+from unittest.mock import MagicMock, Mock, patch
+
+import litellm
+from litellm.exceptions import BadRequestError, InternalServerError
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import CustomStreamWrapper
+
+
+def mock_chat_response() -> Dict[str, Any]:
+ return {
+ "id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
+ "object": "chat.completion",
+ "created": 1726285449,
+ "model": "dbrx-instruct-071224",
+ "choices": [
+ {
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
+ "function_call": None,
+ "tool_calls": None,
+ },
+ "finish_reason": "stop",
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 230,
+ "completion_tokens": 38,
+ "total_tokens": 268,
+ "completion_tokens_details": None,
+ },
+ "system_fingerprint": None,
+ }
+
+
+def mock_chat_streaming_response_chunks() -> List[str]:
+ return [
+ json.dumps(
+ {
+ "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
+ "object": "chat.completion.chunk",
+ "created": 1726469651,
+ "model": "dbrx-instruct-071224",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"role": "assistant", "content": "Hello"},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 230,
+ "completion_tokens": 1,
+ "total_tokens": 231,
+ },
+ }
+ ),
+ json.dumps(
+ {
+ "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
+ "object": "chat.completion.chunk",
+ "created": 1726469651,
+ "model": "dbrx-instruct-071224",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"content": " world"},
+ "finish_reason": None,
+ "logprobs": None,
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 230,
+ "completion_tokens": 1,
+ "total_tokens": 231,
+ },
+ }
+ ),
+ json.dumps(
+ {
+ "id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
+ "object": "chat.completion.chunk",
+ "created": 1726469651,
+ "model": "dbrx-instruct-071224",
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"content": "!"},
+ "finish_reason": "stop",
+ "logprobs": None,
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 230,
+ "completion_tokens": 1,
+ "total_tokens": 231,
+ },
+ }
+ ),
+ ]
+
+
+def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
+ string_chunks = mock_chat_streaming_response_chunks()
+ bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
+ # Simulate the end of the stream
+ bytes_chunks.append(b"")
+ return bytes_chunks
+
+
+def mock_http_handler_chat_streaming_response() -> MagicMock:
+ mock_stream_chunks = mock_chat_streaming_response_chunks()
+
+ def mock_iter_lines():
+ for chunk in mock_stream_chunks:
+ for line in chunk.splitlines():
+ yield line
+
+ mock_response = MagicMock()
+ mock_response.iter_lines.side_effect = mock_iter_lines
+ mock_response.status_code = 200
+
+ return mock_response
+
+
+def mock_http_handler_chat_async_streaming_response() -> MagicMock:
+ mock_stream_chunks = mock_chat_streaming_response_chunks()
+
+ async def mock_iter_lines():
+ for chunk in mock_stream_chunks:
+ for line in chunk.splitlines():
+ yield line
+
+ mock_response = MagicMock()
+ mock_response.aiter_lines.return_value = mock_iter_lines()
+ mock_response.status_code = 200
+
+ return mock_response
+
+
+def mock_databricks_client_chat_streaming_response() -> MagicMock:
+ mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()
+
+ def mock_read_from_stream(size=-1):
+ if mock_stream_chunks:
+ return mock_stream_chunks.pop(0)
+ return b""
+
+ mock_response = MagicMock()
+ streaming_response_mock = MagicMock()
+ streaming_response_iterator_mock = MagicMock()
+ # Mock the __getitem__("content") method to return the streaming response
+ mock_response.__getitem__.return_value = streaming_response_mock
+ # Mock the streaming response __enter__ method to return the streaming response iterator
+ streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock
+
+ streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
+ streaming_response_iterator_mock.closed = False
+
+ return mock_response
+
+
+def mock_embedding_response() -> Dict[str, Any]:
+ return {
+ "object": "list",
+ "model": "bge-large-en-v1.5",
+ "data": [
+ {
+ "index": 0,
+ "object": "embedding",
+ "embedding": [
+ 0.06768798828125,
+ -0.01291656494140625,
+ -0.0501708984375,
+ 0.0245361328125,
+ -0.030364990234375,
+ ],
+ }
+ ],
+ "usage": {
+ "prompt_tokens": 8,
+ "total_tokens": 8,
+ "completion_tokens": 0,
+ "completion_tokens_details": None,
+ },
+ }
+
+
+@pytest.mark.parametrize("set_base", [True, False])
+def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
+ if set_base:
+ monkeypatch.setenv(
+ "DATABRICKS_API_BASE",
+ "https://my.workspace.cloud.databricks.com/serving-endpoints",
+ )
+ monkeypatch.delenv(
+ "DATABRICKS_API_KEY",
+ )
+ err_msg = "A call is being made to LLM Provider but no key is set"
+ else:
+ monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
+ monkeypatch.delenv("DATABRICKS_API_BASE")
+ err_msg = "A call is being made to LLM Provider but no api base is set"
+
+ with pytest.raises(BadRequestError) as exc:
+ litellm.completion(
+ model="databricks/dbrx-instruct-071224",
+ messages={"role": "user", "content": "How are you?"},
+ )
+ assert err_msg in str(exc)
+
+ with pytest.raises(BadRequestError) as exc:
+ litellm.embedding(
+ model="databricks/bge-12312",
+ input=["Hello", "World"],
+ )
+ assert err_msg in str(exc)
+
+
+def test_completions_with_sync_http_handler(monkeypatch):
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ sync_handler = HTTPHandler()
+ mock_response = Mock(spec=httpx.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = mock_chat_response()
+
+ expected_response_json = {
+ **mock_chat_response(),
+ **{
+ "model": "databricks/dbrx-instruct-071224",
+ },
+ }
+
+ messages = [{"role": "user", "content": "How are you?"}]
+
+ with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
+ response = litellm.completion(
+ model="databricks/dbrx-instruct-071224",
+ messages=messages,
+ client=sync_handler,
+ temperature=0.5,
+ extraparam="testpassingextraparam",
+ )
+ assert response.to_dict() == expected_response_json
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/chat/completions",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "dbrx-instruct-071224",
+ "messages": messages,
+ "temperature": 0.5,
+ "extraparam": "testpassingextraparam",
+ "stream": False,
+ }
+ ),
+ )
+
+
+def test_completions_with_async_http_handler(monkeypatch):
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ async_handler = AsyncHTTPHandler()
+ mock_response = Mock(spec=httpx.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = mock_chat_response()
+
+ expected_response_json = {
+ **mock_chat_response(),
+ **{
+ "model": "databricks/dbrx-instruct-071224",
+ },
+ }
+
+ messages = [{"role": "user", "content": "How are you?"}]
+
+ with patch.object(
+ AsyncHTTPHandler, "post", return_value=mock_response
+ ) as mock_post:
+ response = asyncio.run(
+ litellm.acompletion(
+ model="databricks/dbrx-instruct-071224",
+ messages=messages,
+ client=async_handler,
+ temperature=0.5,
+ extraparam="testpassingextraparam",
+ )
+ )
+ assert response.to_dict() == expected_response_json
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/chat/completions",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "dbrx-instruct-071224",
+ "messages": messages,
+ "temperature": 0.5,
+ "extraparam": "testpassingextraparam",
+ "stream": False,
+ }
+ ),
+ )
+
+
+def test_completions_streaming_with_sync_http_handler(monkeypatch):
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ sync_handler = HTTPHandler()
+
+ messages = [{"role": "user", "content": "How are you?"}]
+ mock_response = mock_http_handler_chat_streaming_response()
+
+ with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
+ response_stream: CustomStreamWrapper = litellm.completion(
+ model="databricks/dbrx-instruct-071224",
+ messages=messages,
+ client=sync_handler,
+ temperature=0.5,
+ extraparam="testpassingextraparam",
+ stream=True,
+ )
+ response = list(response_stream)
+ assert "dbrx-instruct-071224" in str(response)
+ assert "chatcmpl" in str(response)
+ assert len(response) == 4
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/chat/completions",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "dbrx-instruct-071224",
+ "messages": messages,
+ "temperature": 0.5,
+ "stream": True,
+ "extraparam": "testpassingextraparam",
+ }
+ ),
+ stream=True,
+ )
+
+
+def test_completions_streaming_with_async_http_handler(monkeypatch):
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ async_handler = AsyncHTTPHandler()
+
+ messages = [{"role": "user", "content": "How are you?"}]
+ mock_response = mock_http_handler_chat_async_streaming_response()
+
+ with patch.object(
+ AsyncHTTPHandler, "post", return_value=mock_response
+ ) as mock_post:
+ response_stream: CustomStreamWrapper = asyncio.run(
+ litellm.acompletion(
+ model="databricks/dbrx-instruct-071224",
+ messages=messages,
+ client=async_handler,
+ temperature=0.5,
+ extraparam="testpassingextraparam",
+ stream=True,
+ )
+ )
+
+ # Use async list gathering for the response
+ async def gather_responses():
+ return [item async for item in response_stream]
+
+ response = asyncio.run(gather_responses())
+ assert "dbrx-instruct-071224" in str(response)
+ assert "chatcmpl" in str(response)
+ assert len(response) == 4
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/chat/completions",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "dbrx-instruct-071224",
+ "messages": messages,
+ "temperature": 0.5,
+ "stream": True,
+ "extraparam": "testpassingextraparam",
+ }
+ ),
+ stream=True,
+ )
+
+
+def test_embeddings_with_sync_http_handler(monkeypatch):
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ sync_handler = HTTPHandler()
+ mock_response = Mock(spec=httpx.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = mock_embedding_response()
+
+ inputs = ["Hello", "World"]
+
+ with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
+ response = litellm.embedding(
+ model="databricks/bge-large-en-v1.5",
+ input=inputs,
+ client=sync_handler,
+ extraparam="testpassingextraparam",
+ )
+ assert response.to_dict() == mock_embedding_response()
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/embeddings",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "bge-large-en-v1.5",
+ "input": inputs,
+ "extraparam": "testpassingextraparam",
+ }
+ ),
+ )
+
+
+def test_embeddings_with_async_http_handler(monkeypatch):
+ base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
+ api_key = "dapimykey"
+ monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
+ monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
+
+ async_handler = AsyncHTTPHandler()
+ mock_response = Mock(spec=httpx.Response)
+ mock_response.status_code = 200
+ mock_response.json.return_value = mock_embedding_response()
+
+ inputs = ["Hello", "World"]
+
+ with patch.object(
+ AsyncHTTPHandler, "post", return_value=mock_response
+ ) as mock_post:
+ response = asyncio.run(
+ litellm.aembedding(
+ model="databricks/bge-large-en-v1.5",
+ input=inputs,
+ client=async_handler,
+ extraparam="testpassingextraparam",
+ )
+ )
+ assert response.to_dict() == mock_embedding_response()
+
+ mock_post.assert_called_once_with(
+ f"{base_url}/embeddings",
+ headers={
+ "Authorization": f"Bearer {api_key}",
+ "Content-Type": "application/json",
+ },
+ data=json.dumps(
+ {
+ "model": "bge-large-en-v1.5",
+ "input": inputs,
+ "extraparam": "testpassingextraparam",
+ }
+ ),
+ )