LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) (#5731)

* LiteLLM Minor Fixes & Improvements (09/16/2024)  (#5723)

* coverage (#5713)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Move (#5714)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix(litellm_logging.py): fix logging client re-init (#5710)

Fixes https://github.com/BerriAI/litellm/issues/5695
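
A simplified sketch of the pattern behind the fix (the real implementation is the `DynamicLoggingCache` class added to `litellm_logging.py` further down in this diff; the names here are illustrative):

```python
import hashlib
import json
from typing import Any, Dict, Optional


class ClientCache:
    """Reuse one logging client per unique (credentials, service) pair."""

    def __init__(self) -> None:
        self._cache: Dict[str, Any] = {}

    def _key(self, credentials: dict, service_name: str) -> str:
        # Hash the credentials so the cache key is stable and not stored in plain text.
        args = json.dumps({**credentials, "service_name": service_name}, sort_keys=True)
        return hashlib.sha256(args.encode("utf-8")).hexdigest()

    def get(self, credentials: dict, service_name: str) -> Optional[Any]:
        return self._cache.get(self._key(credentials, service_name))

    def set(self, credentials: dict, service_name: str, client: Any) -> None:
        self._cache[self._key(credentials, service_name)] = client
```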

* fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config

Fixes https://github.com/BerriAI/litellm/issues/5682

* feat(o1_handler.py): fake streaming for openai o1 models

Fixes https://github.com/BerriAI/litellm/issues/5694
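
From the caller's side nothing changes; a hedged usage sketch (the model name is illustrative, matching the `o1-preview` entry added to the proxy config in this diff):

```python
import litellm

# o1 models don't support server-side streaming, so LiteLLM returns the full
# response wrapped in a mock chunk iterator when stream=True is requested.
response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "what llm are you"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```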

* docs: deprecated traceloop integration in favor of native otel (#5249)

* fix: fix linting errors

* fix: fix linting errors

* fix(main.py): fix o1 import

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730)

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view

Supports having the `MonthlyGlobalSpend` view be a materialized view, and exposes an endpoint to refresh it
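
For context, a Postgres materialized view caches its query results and must be refreshed explicitly. A minimal sketch of what such a refresh boils down to, reusing the Prisma pattern from the `check_view_exists` script changed below (assumes `DATABASE_URL` is configured; the new `/global/spend/refresh` endpoint is assumed to run the equivalent statement on the proxy's database):

```python
import asyncio

from prisma import Prisma


async def refresh_monthly_global_spend() -> None:
    db = Prisma()
    await db.connect()
    # New rows in "LiteLLM_SpendLogs" only appear in the view after a refresh.
    await db.execute_raw('REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";')
    await db.disconnect()


asyncio.run(refresh_monthly_global_spend())
```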

* fix(custom_logger.py): reset calltype

* fix: fix linting errors

* fix: fix linting error

* fix: fix import

* test(test_databricks.py): fix databricks tests

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
Author: Krish Dholakia
Date: 2024-09-17 08:05:52 -07:00 (committed by GitHub)
Commit: 234185ec13
Parent: 1e59395280
34 changed files with 1387 additions and 502 deletions


@@ -6,10 +6,14 @@ import asyncio
import os
# Enter your DATABASE_URL here
os.environ["DATABASE_URL"] = "postgresql://xxxxxxx"
from prisma import Prisma
-db = Prisma()
+db = Prisma(
+    http={
+        "timeout": 60000,
+    },
+)
async def check_view_exists():
@@ -47,25 +51,22 @@ async def check_view_exists():
print("LiteLLM_VerificationTokenView Created!") # noqa
-try:
-    await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
-    print("MonthlyGlobalSpend Exists!") # noqa
-except Exception as e:
-    sql_query = """
-    CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
+sql_query = """
+CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
SELECT
-DATE("startTime") AS date,
+DATE_TRUNC('day', "startTime") AS date,
SUM("spend") AS spend
FROM
"LiteLLM_SpendLogs"
WHERE
-"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+"startTime" >= CURRENT_DATE - INTERVAL '30 days'
GROUP BY
-DATE("startTime");
+DATE_TRUNC('day', "startTime");
"""
-await db.execute_raw(query=sql_query)
+# Execute the queries
+await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpend Created!") # noqa
try:
    await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")


@@ -0,0 +1,78 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# OpenTelemetry - Tracing LLMs with any observability tool
OpenTelemetry is a CNCF standard for observability. It connects to any observability tool, such as Jaeger, Zipkin, Datadog, New Relic, Traceloop and others.
<Image img={require('../../img/traceloop_dash.png')} />
## Getting Started
Install the OpenTelemetry SDK:
```
pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp
```
Set the environment variables (different providers may require different variables):
<Tabs>
<TabItem value="traceloop" label="Log to Traceloop Cloud">
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="https://api.traceloop.com"
OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```
</TabItem>
<TabItem value="otel-col" label="Log to OTEL HTTP Collector">
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```
</TabItem>
<TabItem value="otel-col-grpc" label="Log to OTEL GRPC Collector">
```shell
OTEL_EXPORTER="otlp_grpc"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```
</TabItem>
</Tabs>
Use just 2 lines of code to instantly log your LLM responses **across all providers** with OpenTelemetry:
```python
import litellm

litellm.callbacks = ["otel"]
```
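
As a slightly fuller sketch (assuming the environment variables from the tabs above are already exported; the model is the same one used in the proxy examples elsewhere in these docs):

```python
import litellm

# Enable the native OpenTelemetry callback.
litellm.callbacks = ["otel"]

# Every LiteLLM call is now traced and exported to the configured endpoint.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```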
## Redacting Messages, Response Content from OpenTelemetry Logging
### Redact Messages and Responses from all OpenTelemetry Logging
Set `litellm.turn_off_message_logging=True`. This will prevent the messages and responses from being logged to OpenTelemetry, but request metadata will still be logged.
### Redact Messages and Responses from specific OpenTelemetry Logging
In the metadata typically passed for text completion or embedding calls, you can set specific keys to mask the messages and responses for that call.
Setting `mask_input` to `True` will mask the input from being logged for this call.
Setting `mask_output` to `True` will mask the output from being logged for this call.
Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.
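
A hedged sketch of both options (the per-call metadata keys follow the description above; treat their exact placement as illustrative):

```python
import litellm

# Option 1: globally redact message and response content from OTEL logging.
litellm.turn_off_message_logging = True

# Option 2: redact only a specific call by passing mask keys in metadata.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "some sensitive prompt"}],
    metadata={"mask_input": True, "mask_output": True},
)
```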
## Support
For any question or issue with the integration you can reach out to the OpenLLMetry maintainers on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).


@@ -1,36 +0,0 @@
import Image from '@theme/IdealImage';
# Traceloop (OpenLLMetry) - Tracing LLMs with OpenTelemetry
[Traceloop](https://traceloop.com) is a platform for monitoring and debugging the quality of your LLM outputs.
It provides you with a way to track the performance of your LLM application; rollout changes with confidence; and debug issues in production.
It is based on [OpenTelemetry](https://opentelemetry.io), so it can provide full visibility to your LLM requests, as well vector DB usage, and other infra in your stack.
<Image img={require('../../img/traceloop_dash.png')} />
## Getting Started
Install the Traceloop SDK:
```
pip install traceloop-sdk
```
Use just 2 lines of code, to instantly log your LLM responses with OpenTelemetry:
```python
Traceloop.init(app_name=<YOUR APP NAME>, disable_batch=True)
litellm.success_callback = ["traceloop"]
```
Make sure to properly set a destination to your traces. See [OpenLLMetry docs](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for options.
To get better visualizations on how your code behaves, you may want to annotate specific parts of your LLM chain. See [Traceloop docs on decorators](https://traceloop.com/docs/python-sdk/decorators) for more information.
## Exporting traces to other systems (e.g. Datadog, New Relic, and others)
Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.
## Support
For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).


@@ -600,6 +600,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</TabItem>
<TabItem value="traceloop" label="Log to Traceloop Cloud">
#### Quick Start - Log to Traceloop
**Step 1:**
Add the following to your env
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="https://api.traceloop.com"
OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```
**Step 2:** Add `otel` as a callback
```shell
litellm_settings:
callbacks: ["otel"]
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --detailed_debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
</TabItem>
<TabItem value="otel-col" label="Log to OTEL HTTP Collector"> <TabItem value="otel-col" label="Log to OTEL HTTP Collector">
#### Quick Start - Log to OTEL Collector #### Quick Start - Log to OTEL Collector
@@ -694,52 +740,6 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</TabItem>
<TabItem value="traceloop" label="Log to Traceloop Cloud">
#### Quick Start - Log to Traceloop
**Step 1:** Install the `traceloop-sdk` SDK
```shell
pip install traceloop-sdk==0.21.2
```
**Step 2:** Add `traceloop` as a success_callback
```shell
litellm_settings:
success_callback: ["traceloop"]
environment_variables:
TRACELOOP_API_KEY: "XXXXX"
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --detailed_debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
</TabItem>
</Tabs>
** 🎉 Expect to see this trace logged in your OTEL collector**


@@ -255,6 +255,7 @@ const sidebars = {
type: "category",
label: "Logging & Observability",
items: [
"observability/opentelemetry_integration",
"observability/langfuse_integration",
"observability/logfire_integration",
"observability/gcs_bucket_integration",
@@ -271,7 +272,7 @@ const sidebars = {
"observability/openmeter",
"observability/promptlayer_integration",
"observability/wandb_integration",
-"observability/traceloop_integration",
+"observability/slack_integration",
"observability/athina_integration",
"observability/lunary_integration",
"observability/greenscale_integration",


@@ -948,10 +948,10 @@ from .llms.OpenAI.openai import (
AzureAIStudioConfig,
)
from .llms.mistral.mistral_chat_transformation import MistralConfig
-from .llms.OpenAI.o1_transformation import (
+from .llms.OpenAI.chat.o1_transformation import (
OpenAIO1Config,
)
-from .llms.OpenAI.gpt_transformation import (
+from .llms.OpenAI.chat.gpt_transformation import (
OpenAIGPTConfig,
)
from .llms.nvidia_nim import NvidiaNimConfig


@@ -1616,6 +1616,9 @@ Model Info:
:param time_range: A string specifying the time range, e.g., "1d", "7d", "30d"
"""
if self.alerting is None or "spend_reports" not in self.alert_types:
    return
try:
    from litellm.proxy.spend_tracking.spend_management_endpoints import (
        _get_spend_report_for_time_range,


@@ -11,7 +11,7 @@ class CustomGuardrail(CustomLogger):
self,
guardrail_name: Optional[str] = None,
event_hook: Optional[GuardrailEventHooks] = None,
-**kwargs
+**kwargs,
):
self.guardrail_name = guardrail_name
self.event_hook: Optional[GuardrailEventHooks] = event_hook
@@ -28,10 +28,13 @@
requested_guardrails,
)
-if self.guardrail_name not in requested_guardrails:
+if (
+self.guardrail_name not in requested_guardrails
+and event_type.value != "logging_only"
+):
return False
-if self.event_hook != event_type:
+if self.event_hook != event_type.value:
return False
return True


@@ -4,6 +4,11 @@ import litellm
class TraceloopLogger:
"""
WARNING: DEPRECATED
Use the OpenTelemetry standard integration instead
"""
def __init__(self):
try:
from traceloop.sdk.tracing.tracing import TracerWrapper


@@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
try:
from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
GenericAPILogger,
)
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
_in_memory_loggers: List[Any] = []
### GLOBAL VARIABLES ###
@@ -145,7 +152,41 @@ class ServiceTraceIDCache:
return None
import hashlib
class DynamicLoggingCache:
"""
Prevent memory leaks caused by initializing new logging clients on each request.
Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
"""
def __init__(self) -> None:
self.cache = InMemoryCache()
def get_cache_key(self, args: dict) -> str:
args_str = json.dumps(args, sort_keys=True)
cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
return cache_key
def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
key_name = self.get_cache_key(
args={**credentials, "service_name": service_name}
)
response = self.cache.get_cache(key=key_name)
return response
def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
key_name = self.get_cache_key(
args={**credentials, "service_name": service_name}
)
self.cache.set_cache(key=key_name, value=logging_obj)
return None
in_memory_trace_id_cache = ServiceTraceIDCache()
in_memory_dynamic_logger_cache = DynamicLoggingCache()
class Logging:
@@ -324,10 +365,10 @@ class Logging:
print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
# log raw request to provider (like LangFuse) -- if opted in.
if log_raw_request_response is True:
+_litellm_params = self.model_call_details.get("litellm_params", {})
+_metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
-_litellm_params = self.model_call_details.get("litellm_params", {})
-_metadata = _litellm_params.get("metadata", {}) or {}
if (
turn_off_message_logging is not None
and turn_off_message_logging is True
@@ -362,7 +403,7 @@ class Logging:
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
try:
-if callback == "supabase":
+if callback == "supabase" and supabaseClient is not None:
verbose_logger.debug("reaches supabase for logging!")
model = self.model_call_details["model"]
messages = self.model_call_details["input"]
@@ -396,7 +437,9 @@
messages=self.messages,
kwargs=self.model_call_details,
)
-elif callable(callback): # custom logger functions
+elif (
+callable(callback) and customLogger is not None
+): # custom logger functions
customLogger.log_input_event(
model=self.model,
messages=self.messages,
@@ -615,7 +658,7 @@
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
-] = result._hidden_params
+] = getattr(result, "_hidden_params", {})
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
@@ -645,6 +688,7 @@
litellm.max_budget
and self.stream is False
and result is not None
and isinstance(result, dict)
and "content" in result
):
time_diff = (end_time - start_time).total_seconds()
@@ -652,7 +696,7 @@
litellm._current_cost += litellm.completion_cost(
model=self.model,
prompt="",
-completion=result["content"],
+completion=getattr(result, "content", ""),
total_time=float_diff,
)
@@ -758,7 +802,7 @@
):
print_verbose("no-log request, skipping logging")
continue
-if callback == "lite_debugger":
+if callback == "lite_debugger" and liteDebuggerClient is not None:
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
print_verbose(
@@ -774,7 +818,7 @@
call_type=self.call_type,
stream=self.stream,
)
-if callback == "promptlayer":
+if callback == "promptlayer" and promptLayerLogger is not None:
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
kwargs=self.model_call_details,
@@ -783,7 +827,7 @@
end_time=end_time,
print_verbose=print_verbose,
)
-if callback == "supabase":
+if callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
kwargs = self.model_call_details
@@ -811,7 +855,7 @@
),
print_verbose=print_verbose,
)
-if callback == "wandb":
+if callback == "wandb" and weightsBiasesLogger is not None:
print_verbose("reaches wandb for logging!")
weightsBiasesLogger.log_event(
kwargs=self.model_call_details,
@@ -820,8 +864,7 @@
end_time=end_time,
print_verbose=print_verbose,
)
-if callback == "logfire":
-global logfireLogger
+if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@@ -844,10 +887,10 @@
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
-level=LogfireLevel.INFO.value,
+level=LogfireLevel.INFO.value, # type: ignore
)
-if callback == "lunary":
+if callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging!")
model = self.model
kwargs = self.model_call_details
@@ -882,7 +925,7 @@
run_id=self.litellm_call_id,
print_verbose=print_verbose,
)
-if callback == "helicone":
+if callback == "helicone" and heliconeLogger is not None:
print_verbose("reaches helicone for logging!")
model = self.model
messages = self.model_call_details["input"]
@@ -924,6 +967,7 @@
else:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
temp_langfuse_logger = langFuseLogger
if langFuseLogger is None or (
(
@@ -941,27 +985,45 @@
and self.langfuse_host != langFuseLogger.langfuse_host
)
):
-temp_langfuse_logger = LangFuseLogger(
-langfuse_public_key=self.langfuse_public_key,
-langfuse_secret=self.langfuse_secret,
-langfuse_host=self.langfuse_host,
-)
-_response = temp_langfuse_logger.log_event(
-kwargs=kwargs,
-response_obj=result,
-start_time=start_time,
-end_time=end_time,
-user_id=kwargs.get("user", None),
-print_verbose=print_verbose,
-)
-if _response is not None and isinstance(_response, dict):
-_trace_id = _response.get("trace_id", None)
-if _trace_id is not None:
-in_memory_trace_id_cache.set_cache(
-litellm_call_id=self.litellm_call_id,
-service_name="langfuse",
-trace_id=_trace_id,
-)
+credentials = {
+"langfuse_public_key": self.langfuse_public_key,
+"langfuse_secret": self.langfuse_secret,
+"langfuse_host": self.langfuse_host,
+}
+temp_langfuse_logger = (
+in_memory_dynamic_logger_cache.get_cache(
+credentials=credentials, service_name="langfuse"
+)
+)
+if temp_langfuse_logger is None:
+temp_langfuse_logger = LangFuseLogger(
+langfuse_public_key=self.langfuse_public_key,
+langfuse_secret=self.langfuse_secret,
+langfuse_host=self.langfuse_host,
+)
+in_memory_dynamic_logger_cache.set_cache(
+credentials=credentials,
+service_name="langfuse",
+logging_obj=temp_langfuse_logger,
+)
+if temp_langfuse_logger is not None:
+_response = temp_langfuse_logger.log_event(
+kwargs=kwargs,
+response_obj=result,
+start_time=start_time,
+end_time=end_time,
+user_id=kwargs.get("user", None),
+print_verbose=print_verbose,
+)
+if _response is not None and isinstance(_response, dict):
+_trace_id = _response.get("trace_id", None)
+if _trace_id is not None:
+in_memory_trace_id_cache.set_cache(
+litellm_call_id=self.litellm_call_id,
+service_name="langfuse",
+trace_id=_trace_id,
+)
if callback == "generic":
global genericAPILogger
verbose_logger.debug("reaches langfuse for success logging!")
@@ -982,7 +1044,7 @@
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
if genericAPILogger is None:
-genericAPILogger = GenericAPILogger()
+genericAPILogger = GenericAPILogger() # type: ignore
genericAPILogger.log_event(
kwargs=kwargs,
response_obj=result,
@@ -1022,7 +1084,7 @@
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
-if callback == "greenscale":
+if callback == "greenscale" and greenscaleLogger is not None:
kwargs = {}
for k, v in self.model_call_details.items():
if (
@@ -1066,7 +1128,7 @@
result = kwargs["complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
-if callback == "athina":
+if callback == "athina" and athinaLogger is not None:
deep_copy = {}
for k, v in self.model_call_details.items():
deep_copy[k] = v
@@ -1224,6 +1286,7 @@
"atranscription", False
)
is not True
and customLogger is not None
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
@@ -1423,9 +1486,8 @@
await litellm.cache.async_add_cache(result, **kwargs)
else:
litellm.cache.add_cache(result, **kwargs)
-if callback == "openmeter":
-global openMeterLogger
-if self.stream == True:
+if callback == "openmeter" and openMeterLogger is not None:
+if self.stream is True:
if (
"async_complete_streaming_response"
in self.model_call_details
@@ -1645,33 +1707,9 @@
)
for callback in callbacks:
try:
-if callback == "lite_debugger":
-print_verbose("reaches lite_debugger for logging!")
-print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
-result = {
-"model": self.model,
-"created": time.time(),
-"error": traceback_exception,
-"usage": {
-"prompt_tokens": prompt_token_calculator(
-self.model, messages=self.messages
-),
-"completion_tokens": 0,
-},
-}
-liteDebuggerClient.log_event(
-model=self.model,
-messages=self.messages,
-end_user=self.model_call_details.get("user", "default"),
-response_obj=result,
-start_time=start_time,
-end_time=end_time,
-litellm_call_id=self.litellm_call_id,
-print_verbose=print_verbose,
-call_type=self.call_type,
-stream=self.stream,
-)
-if callback == "lunary":
+if callback == "lite_debugger" and liteDebuggerClient is not None:
+pass
+elif callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging error!")
model = self.model
@@ -1685,6 +1723,7 @@
)
lunaryLogger.log_event(
kwargs=self.model_call_details,
type=_type,
event="error",
user_id=self.model_call_details.get("user", "default"),
@@ -1704,22 +1743,11 @@
print_verbose(
f"capture exception not initialized: {capture_exception}"
)
-elif callback == "supabase":
+elif callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
print_verbose(f"supabaseClient: {supabaseClient}")
-result = {
-"model": model,
-"created": time.time(),
-"error": traceback_exception,
-"usage": {
-"prompt_tokens": prompt_token_calculator(
-model, messages=self.messages
-),
-"completion_tokens": 0,
-},
-}
supabaseClient.log_event(
-model=self.model,
+model=self.model if hasattr(self, "model") else "",
messages=self.messages,
end_user=self.model_call_details.get("user", "default"),
response_obj=result,
@@ -1728,7 +1756,9 @@
litellm_call_id=self.model_call_details["litellm_call_id"],
print_verbose=print_verbose,
)
-if callable(callback): # custom logger functions
+if (
+callable(callback) and customLogger is not None
+): # custom logger functions
customLogger.log_event(
kwargs=self.model_call_details,
response_obj=result,
@@ -1809,13 +1839,13 @@
start_time=start_time,
end_time=end_time,
response_obj=None,
-user_id=kwargs.get("user", None),
+user_id=self.model_call_details.get("user", None),
print_verbose=print_verbose,
status_message=str(exception),
level="ERROR",
kwargs=self.model_call_details,
)
-if callback == "logfire":
+if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for failure logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@@ -1830,7 +1860,7 @@
response_obj=result,
start_time=start_time,
end_time=end_time,
-level=LogfireLevel.ERROR.value,
+level=LogfireLevel.ERROR.value, # type: ignore
print_verbose=print_verbose,
)
@@ -1873,7 +1903,9 @@
start_time=start_time,
end_time=end_time,
) # type: ignore
-if callable(callback): # custom logger functions
+if (
+callable(callback) and customLogger is not None
+): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
@@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
)
sentry_sdk_instance.init(
dsn=os.environ.get("SENTRY_DSN"),
-traces_sample_rate=float(sentry_trace_rate),
+traces_sample_rate=float(sentry_trace_rate), # type: ignore
)
capture_exception = sentry_sdk_instance.capture_exception
add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(
saved_cache_cost: Optional[float] = None
if cache_hit is True:
-import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
-result=init_response_obj, cache_hit=False
+result=init_response_obj, cache_hit=False # type: ignore
)
## Get model cost information ##
@@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
model_id=_model_id,
requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"),
-response=(
+response=( # type: ignore
response_obj if len(response_obj.keys()) > 0 else init_response_obj
),
model_parameters=kwargs.get("optional_params", None),


@@ -0,0 +1,95 @@
"""
Handler file for calls to OpenAI's o1 family of models
Written separately to handle faking streaming for o1 models.
"""
import asyncio
from typing import Any, Callable, List, Optional, Union
from httpx._config import Timeout
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
class OpenAIO1ChatCompletion(OpenAIChatCompletion):
async def mock_async_streaming(
self,
response: Any,
model: Optional[str],
logging_obj: Any,
):
model_response = await response
completion_stream = MockResponseIterator(model_response=model_response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
stream: Optional[bool] = optional_params.pop("stream", False)
response = super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)
if stream is True:
if asyncio.iscoroutine(response):
return self.mock_async_streaming(
response=response, model=model, logging_obj=logging_obj # type: ignore
)
completion_stream = MockResponseIterator(model_response=response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
else:
return response


@@ -15,6 +15,8 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.databricks.exceptions import DatabricksError
+from litellm.llms.databricks.streaming_utils import ModelResponseIterator
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
@@ -33,17 +35,6 @@ from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
-class DatabricksError(Exception):
-def __init__(self, status_code, message):
-self.status_code = status_code
-self.message = message
-self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
-self.response = httpx.Response(status_code=status_code, request=self.request)
-super().__init__(
-self.message
-) # Call the base class constructor with the parameters it needs
class DatabricksConfig:
"""
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
@@ -367,7 +358,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code,
message=e.response.text,
)
-except httpx.TimeoutException as e:
+except httpx.TimeoutException:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise DatabricksError(status_code=500, message=str(e))
@@ -380,7 +371,7 @@
)
response = ModelResponse(**response_json)
-response.model = custom_llm_provider + "/" + response.model
+response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
@@ -529,7 +520,7 @@
response_json = response.json()
except httpx.HTTPStatusError as e:
raise DatabricksError(
-status_code=e.response.status_code, message=response.text
+status_code=e.response.status_code, message=e.response.text
)
except httpx.TimeoutException as e:
raise DatabricksError(
@@ -540,7 +531,7 @@
response = ModelResponse(**response_json)
-response.model = custom_llm_provider + "/" + response.model
+response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
@@ -657,7 +648,7 @@
except httpx.HTTPStatusError as e:
raise DatabricksError(
status_code=e.response.status_code,
-message=response.text if response else str(e),
+message=e.response.text,
)
except httpx.TimeoutException as e:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
@@ -673,136 +664,3 @@
)
return litellm.EmbeddingResponse(**response_json)
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
if processed_chunk.choices[0].delta.content is not None: # type: ignore
text = processed_chunk.choices[0].delta.content # type: ignore
if (
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
is not None
):
tool_use = ChatCompletionToolCallChunk(
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.name,
arguments=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.arguments,
),
index=processed_chunk.choices[0].index,
)
if processed_chunk.choices[0].finish_reason is not None:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
if hasattr(processed_chunk, "usage") and isinstance(
processed_chunk.usage, litellm.Usage
):
usage_chunk: litellm.Usage = processed_chunk.usage
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,
total_tokens=usage_chunk.total_tokens,
)
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if chunk == "[DONE]":
raise StopAsyncIteration
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


@@ -0,0 +1,12 @@
import httpx
class DatabricksError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs


@@ -0,0 +1,145 @@
import json
from typing import Optional
import litellm
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
if processed_chunk.choices[0].delta.content is not None: # type: ignore
text = processed_chunk.choices[0].delta.content # type: ignore
if (
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
is not None
):
tool_use = ChatCompletionToolCallChunk(
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.name,
arguments=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.arguments,
),
index=processed_chunk.choices[0].index,
)
if processed_chunk.choices[0].finish_reason is not None:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
if hasattr(processed_chunk, "usage") and isinstance(
processed_chunk.usage, litellm.Usage
):
usage_chunk: litellm.Usage = processed_chunk.usage
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,
total_tokens=usage_chunk.total_tokens,
)
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if chunk == "[DONE]":
raise StopAsyncIteration
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")


@@ -95,6 +95,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
+from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.factory import (
@@ -161,6 +162,7 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
+openai_o1_chat_completions = OpenAIO1ChatCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription()
databricks_chat_completions = DatabricksChatCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
@@ -1366,25 +1368,46 @@ def completion(
## COMPLETION CALL
try:
-response = openai_chat_completions.completion(
-    model=model,
-    messages=messages,
-    headers=headers,
-    model_response=model_response,
-    print_verbose=print_verbose,
-    api_key=api_key,
-    api_base=api_base,
-    acompletion=acompletion,
-    logging_obj=logging,
-    optional_params=optional_params,
-    litellm_params=litellm_params,
-    logger_fn=logger_fn,
-    timeout=timeout, # type: ignore
-    custom_prompt_dict=custom_prompt_dict,
-    client=client, # pass AsyncOpenAI, OpenAI client
-    organization=organization,
-    custom_llm_provider=custom_llm_provider,
-)
+if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+    response = openai_o1_chat_completions.completion(
+        model=model,
+        messages=messages,
+        headers=headers,
+        model_response=model_response,
+        print_verbose=print_verbose,
+        api_key=api_key,
+        api_base=api_base,
+        acompletion=acompletion,
+        logging_obj=logging,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        logger_fn=logger_fn,
+        timeout=timeout, # type: ignore
+        custom_prompt_dict=custom_prompt_dict,
+        client=client, # pass AsyncOpenAI, OpenAI client
+        organization=organization,
+        custom_llm_provider=custom_llm_provider,
+    )
+else:
+    response = openai_chat_completions.completion(
+        model=model,
+        messages=messages,
+        headers=headers,
+        model_response=model_response,
+        print_verbose=print_verbose,
+        api_key=api_key,
+        api_base=api_base,
+        acompletion=acompletion,
+        logging_obj=logging,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        logger_fn=logger_fn,
+        timeout=timeout, # type: ignore
+        custom_prompt_dict=custom_prompt_dict,
+        client=client, # pass AsyncOpenAI, OpenAI client
+        organization=organization,
+        custom_llm_provider=custom_llm_provider,
+    )
except Exception as e:
## LOGGING - log the original exception returned
logging.post_call(


@@ -3,7 +3,6 @@ model_list:
litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0
api_base: https://exampleopenaiendpoint-production.up.railway.app
# aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
- model_name: gemini-vision
@@ -17,4 +16,34 @@ model_list:
litellm_params:
model: gpt-3.5-turbo
api_base: https://exampleopenaiendpoint-production.up.railway.app
- model_name: o1-preview
litellm_params:
model: o1-preview
litellm_settings:
drop_params: True
json_logs: True
store_audit_logs: True
log_raw_request_response: True
return_response_headers: True
num_retries: 5
request_timeout: 200
callbacks: ["custom_callbacks.proxy_handler_instance"]
guardrails:
- guardrail_name: "presidio-pre-guard"
litellm_params:
guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio"
mode: "logging_only"
mock_redacted_text: {
"text": "My name is <PERSON>, who are you? Say my name in your response",
"items": [
{
"start": 11,
"end": 19,
"entity_type": "PERSON",
"text": "<PERSON>",
"operator": "replace",
}
],
}


@@ -446,7 +446,6 @@ async def user_api_key_auth(
and request.headers.get(key=header_key) is not None # type: ignore
):
api_key = request.headers.get(key=header_key) # type: ignore
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(


@@ -2,9 +2,19 @@ from litellm.integrations.custom_logger import CustomLogger
class MyCustomHandler(CustomLogger):
-async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-# print("Call failed")
-pass
+async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+# input_tokens = response_obj.get("usage", {}).get("prompt_tokens", 0)
+# output_tokens = response_obj.get("usage", {}).get("completion_tokens", 0)
+input_tokens = (
+response_obj.usage.prompt_tokens
+if hasattr(response_obj.usage, "prompt_tokens")
+else 0
+)
+output_tokens = (
+response_obj.usage.completion_tokens
+if hasattr(response_obj.usage, "completion_tokens")
+else 0
+)
proxy_handler_instance = MyCustomHandler()


@@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
response = await self.async_handler.post(
url=prepared_request.url,
json=request_data, # type: ignore
-headers=dict(prepared_request.headers),
+headers=prepared_request.headers, # type: ignore
)
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
if response.status_code == 200:
@@ -254,7 +254,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
-from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:


@@ -189,6 +189,7 @@ class lakeraAI_Moderation(CustomGuardrail):
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
# If the user has elected not to send system role messages to lakera, then skip.
if system_message is not None:
if not litellm.add_function_to_prompt:
content = system_message.get("content")


@@ -19,6 +19,7 @@ from fastapi import HTTPException
from pydantic import BaseModel
import litellm # noqa: E401
+from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -58,7 +59,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
self.mock_redacted_text = mock_redacted_text
self.output_parse_pii = output_parse_pii or False
if mock_testing is True: # for testing purposes only
@@ -92,8 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base: Optional[str] = None,
):
self.presidio_analyzer_api_base: Optional[str] = (
-presidio_analyzer_api_base
-or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None)
+presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore
)
self.presidio_anonymizer_api_base: Optional[
str
@@ -198,12 +197,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
-verbose_proxy_logger.error(
-"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
-str(e)
-)
-)
-verbose_proxy_logger.debug(traceback.format_exc())
raise e
async def async_pre_call_hook(
@@ -254,9 +247,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
return data
except Exception as e:
-verbose_proxy_logger.info(
-f"An error occurred -",
-)
raise e
async def async_logging_hook(
@@ -300,9 +290,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
kwargs["messages"] = messages
-return kwargs, responses
+return kwargs, result
-async def async_post_call_success_hook(
+async def async_post_call_success_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@@ -314,7 +304,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
-if self.output_parse_pii == False:
+if self.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(

View file

@ -1,10 +1,11 @@
import importlib import importlib
import traceback import traceback
from typing import Dict, List, Literal from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, RootModel from pydantic import BaseModel, RootModel
import litellm import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
@ -16,7 +17,6 @@ from litellm.types.guardrails import (
GuardrailItemSpec, GuardrailItemSpec,
LakeraCategoryThresholds, LakeraCategoryThresholds,
LitellmParams, LitellmParams,
guardrailConfig,
) )
all_guardrails: List[GuardrailItem] = [] all_guardrails: List[GuardrailItem] = []
@ -98,18 +98,13 @@ def init_guardrails_v2(
# Init litellm params for guardrail # Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"] litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data) verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"], _litellm_params_kwargs = {
mode=litellm_params_data["mode"], k: litellm_params_data[k] if k in litellm_params_data else None
api_key=litellm_params_data.get("api_key"), for k in LitellmParams.__annotations__.keys()
api_base=litellm_params_data.get("api_base"), }
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
guardrailVersion=litellm_params_data.get("guardrailVersion"), litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
output_parse_pii=litellm_params_data.get("output_parse_pii"),
presidio_ad_hoc_recognizers=litellm_params_data.get(
"presidio_ad_hoc_recognizers"
),
)
if ( if (
"category_thresholds" in litellm_params_data "category_thresholds" in litellm_params_data
@ -122,15 +117,11 @@ def init_guardrails_v2(
if litellm_params["api_key"]: if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"): if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret( litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
litellm_params["api_key"]
)
if litellm_params["api_base"]: if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"): if litellm_params["api_base"].startswith("os.environ/"):
litellm_params["api_base"] = litellm.get_secret( litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
litellm_params["api_base"]
)
# Init guardrail CustomLoggerClass # Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == "aporia": if litellm_params["guardrail"] == "aporia":
@ -182,6 +173,7 @@ def init_guardrails_v2(
presidio_ad_hoc_recognizers=litellm_params[ presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers" "presidio_ad_hoc_recognizers"
], ],
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
) )
if litellm_params["output_parse_pii"] is True: if litellm_params["output_parse_pii"] is True:
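A minimal sketch, not the proxy's exact code, of the refactor above: build the kwargs for LitellmParams from its TypedDict __annotations__ instead of listing each field by hand, then resolve any "os.environ/..." style values through get_secret. The PRESIDIO_API_KEY variable name is illustrative only.

import os

from litellm import get_secret
from litellm.types.guardrails import LitellmParams

os.environ["PRESIDIO_API_KEY"] = "example-key"  # illustrative env var

litellm_params_data = {
    "guardrail": "presidio",
    "mode": "pre_call",
    "api_key": "os.environ/PRESIDIO_API_KEY",
}

# every annotated field becomes a kwarg; keys missing from the config default to None
_litellm_params_kwargs = {
    k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys()
}
litellm_params = LitellmParams(**_litellm_params_kwargs)  # type: ignore

for field in ("api_key", "api_base"):
    value = litellm_params.get(field)
    if isinstance(value, str) and value.startswith("os.environ/"):
        litellm_params[field] = str(get_secret(value))  # type: ignore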

View file

@ -167,11 +167,11 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if self.prompt_injection_params is not None: if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on # 1. check if heuristics check turned on
if self.prompt_injection_params.heuristics_check == True: if self.prompt_injection_params.heuristics_check is True:
is_prompt_attack = self.check_user_input_similarity( is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt user_input=formatted_prompt
) )
if is_prompt_attack == True: if is_prompt_attack is True:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail={ detail={
@ -179,14 +179,14 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
}, },
) )
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
if self.prompt_injection_params.vector_db_check == True: if self.prompt_injection_params.vector_db_check is True:
pass pass
else: else:
is_prompt_attack = self.check_user_input_similarity( is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt user_input=formatted_prompt
) )
if is_prompt_attack == True: if is_prompt_attack is True:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail={ detail={
@ -201,19 +201,18 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if ( if (
e.status_code == 400 e.status_code == 400
and isinstance(e.detail, dict) and isinstance(e.detail, dict)
and "error" in e.detail and "error" in e.detail # type: ignore
and self.prompt_injection_params is not None and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response and self.prompt_injection_params.reject_as_response
): ):
return e.detail.get("error") return e.detail.get("error")
raise e raise e
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e) str(e)
) )
) )
verbose_proxy_logger.debug(traceback.format_exc())
async def async_moderation_hook( # type: ignore async def async_moderation_hook( # type: ignore
self, self,
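As a side note on the verbose_proxy_logger.exception change in this file, a tiny self-contained sketch: Logger.exception already records the active traceback at ERROR level, so the separate traceback.format_exc() debug call becomes redundant. The logger name and load_params helper are illustrative only.

import logging

logger = logging.getLogger("litellm.proxy")


def load_params(path: str) -> dict:  # illustrative helper that may raise
    raise FileNotFoundError(path)


try:
    load_params("prompt_injection_params.json")
except Exception as e:
    # .exception() appends the current traceback automatically
    logger.exception("async_pre_call_hook(): Exception occurred - %s", str(e))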

View file

@ -195,7 +195,8 @@ async def user_auth(request: Request):
- os.environ["SMTP_PASSWORD"] - os.environ["SMTP_PASSWORD"]
- os.environ["SMTP_SENDER_EMAIL"] - os.environ["SMTP_SENDER_EMAIL"]
""" """
from litellm.proxy.proxy_server import prisma_client, send_email from litellm.proxy.proxy_server import prisma_client
from litellm.proxy.utils import send_email
data = await request.json() # type: ignore data = await request.json() # type: ignore
user_email = data["user_email"] user_email = data["user_email"]
@ -212,7 +213,7 @@ async def user_auth(request: Request):
) )
### if so - generate a 24 hr key with that user id ### if so - generate a 24 hr key with that user id
if response is not None: if response is not None:
user_id = response.user_id user_id = response.user_id # type: ignore
response = await generate_key_helper_fn( response = await generate_key_helper_fn(
request_type="key", request_type="key",
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id}, # type: ignore **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id}, # type: ignore
@ -345,6 +346,7 @@ async def user_info(
for team in teams_1: for team in teams_1:
team_id_list.append(team.team_id) team_id_list.append(team.team_id)
teams_2: Optional[Any] = None
if user_info is not None: if user_info is not None:
# *NEW* get all teams in user 'teams' field # *NEW* get all teams in user 'teams' field
teams_2 = await prisma_client.get_data( teams_2 = await prisma_client.get_data(
@ -375,7 +377,7 @@ async def user_info(
), ),
user_api_key_dict=user_api_key_dict, user_api_key_dict=user_api_key_dict,
) )
else: elif caller_user_info is not None:
teams_2 = await prisma_client.get_data( teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams, team_id_list=caller_user_info.teams,
table_name="team", table_name="team",
@ -395,7 +397,7 @@ async def user_info(
query_type="find_all", query_type="find_all",
) )
if user_info is None: if user_info is None and keys is not None:
## make sure we still return a total spend ## ## make sure we still return a total spend ##
spend = 0 spend = 0
for k in keys: for k in keys:
@ -404,32 +406,35 @@ async def user_info(
## REMOVE HASHED TOKEN INFO before returning ## ## REMOVE HASHED TOKEN INFO before returning ##
returned_keys = [] returned_keys = []
for key in keys: if keys is None:
if ( pass
key.token == litellm_master_key_hash else:
and general_settings.get("disable_master_key_return", False) for key in keys:
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui if (
): key.token == litellm_master_key_hash
continue and general_settings.get("disable_master_key_return", False)
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
):
continue
try: try:
key = key.model_dump() # noqa key = key.model_dump() # noqa
except: except:
# if using pydantic v1 # if using pydantic v1
key = key.dict() key = key.dict()
if ( if (
"team_id" in key "team_id" in key
and key["team_id"] is not None and key["team_id"] is not None
and key["team_id"] != "litellm-dashboard" and key["team_id"] != "litellm-dashboard"
): ):
team_info = await prisma_client.get_data( team_info = await prisma_client.get_data(
team_id=key["team_id"], table_name="team" team_id=key["team_id"], table_name="team"
) )
team_alias = getattr(team_info, "team_alias", None) team_alias = getattr(team_info, "team_alias", None)
key["team_alias"] = team_alias key["team_alias"] = team_alias
else: else:
key["team_alias"] = "None" key["team_alias"] = "None"
returned_keys.append(key) returned_keys.append(key)
response_data = { response_data = {
"user_id": user_id, "user_id": user_id,
@ -539,6 +544,7 @@ async def user_update(
## ADD USER, IF NEW ## ## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data) verbose_proxy_logger.debug("/user/update: Received data = %s", data)
response: Optional[Any] = None
if data.user_id is not None and len(data.user_id) > 0: if data.user_id is not None and len(data.user_id) > 0:
non_default_values["user_id"] = data.user_id # type: ignore non_default_values["user_id"] = data.user_id # type: ignore
verbose_proxy_logger.debug("In update user, user_id condition block.") verbose_proxy_logger.debug("In update user, user_id condition block.")
@ -573,7 +579,7 @@ async def user_update(
data=non_default_values, data=non_default_values,
table_name="user", table_name="user",
) )
return response return response # type: ignore
# update based on remaining passed in values # update based on remaining passed in values
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(

View file

@ -226,7 +226,6 @@ from litellm.proxy.utils import (
hash_token, hash_token,
log_to_opentelemetry, log_to_opentelemetry,
reset_budget, reset_budget,
send_email,
update_spend, update_spend,
) )
from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import ( from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (

View file

@ -1434,7 +1434,7 @@ async def _get_spend_report_for_time_range(
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"Exception in _get_daily_spend_reports {}".format(str(e)) "Exception in _get_daily_spend_reports {}".format(str(e))
) # noqa )
@router.post( @router.post(
@ -1703,26 +1703,26 @@ async def view_spend_logs(
result: dict = {} result: dict = {}
for record in response: for record in response:
dt_object = datetime.strptime( dt_object = datetime.strptime(
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
) # type: ignore ) # type: ignore
date = dt_object.date() date = dt_object.date()
if date not in result: if date not in result:
result[date] = {"users": {}, "models": {}} result[date] = {"users": {}, "models": {}}
api_key = record["api_key"] api_key = record["api_key"] # type: ignore
user_id = record["user"] user_id = record["user"] # type: ignore
model = record["model"] model = record["model"] # type: ignore
result[date]["spend"] = ( result[date]["spend"] = result[date].get("spend", 0) + record.get(
result[date].get("spend", 0) + record["_sum"]["spend"] "_sum", {}
) ).get("spend", 0)
result[date][api_key] = ( result[date][api_key] = result[date].get(api_key, 0) + record.get(
result[date].get(api_key, 0) + record["_sum"]["spend"] "_sum", {}
) ).get("spend", 0)
result[date]["users"][user_id] = ( result[date]["users"][user_id] = result[date]["users"].get(
result[date]["users"].get(user_id, 0) + record["_sum"]["spend"] user_id, 0
) ) + record.get("_sum", {}).get("spend", 0)
result[date]["models"][model] = ( result[date]["models"][model] = result[date]["models"].get(
result[date]["models"].get(model, 0) + record["_sum"]["spend"] model, 0
) ) + record.get("_sum", {}).get("spend", 0)
return_list = [] return_list = []
final_date = None final_date = None
for k, v in sorted(result.items()): for k, v in sorted(result.items()):
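A short sketch of the defensive aggregation above, with an illustrative grouped record: reading the grouped spend through .get(...) means a record without a "_sum" bucket contributes 0 instead of raising a KeyError.

record = {
    "api_key": "hashed-key-1",
    "user": "user-1",
    "model": "gpt-3.5-turbo",
    "_sum": {"spend": 1.25},
}
daily: dict = {"users": {}, "models": {}}

spend = record.get("_sum", {}).get("spend", 0)
daily["spend"] = daily.get("spend", 0) + spend
daily[record["api_key"]] = daily.get(record["api_key"], 0) + spend
daily["users"][record["user"]] = daily["users"].get(record["user"], 0) + spend
daily["models"][record["model"]] = daily["models"].get(record["model"], 0) + spend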
@ -1784,7 +1784,7 @@ async def view_spend_logs(
table_name="spend", query_type="find_all" table_name="spend", query_type="find_all"
) )
return spend_log return spend_logs
return None return None
@ -1843,6 +1843,88 @@ async def global_spend_reset():
} }
@router.post(
"/global/spend/refresh",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def global_spend_refresh():
"""
ADMIN ONLY / MASTER KEY Only Endpoint
Globally refresh spend MonthlyGlobalSpend view
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise ProxyException(
message="Prisma Client is not initialized",
type="internal_error",
param="None",
code=status.HTTP_401_UNAUTHORIZED,
)
## RESET GLOBAL SPEND VIEW ###
async def is_materialized_global_spend_view() -> bool:
"""
Return True if materialized view exists
Else False
"""
sql_query = """
SELECT relname, relkind
FROM pg_class
WHERE relname = 'MonthlyGlobalSpend';
"""
try:
resp = await prisma_client.db.query_raw(sql_query)
assert resp[0]["relkind"] == "m"
return True
except Exception:
return False
view_exists = await is_materialized_global_spend_view()
if view_exists:
# refresh materialized view
sql_query = """
REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";
"""
try:
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient
db_url = os.getenv("DATABASE_URL")
if db_url is None:
raise Exception(CommonProxyErrors.db_not_connected_error.value)
new_client = PrismaClient(
database_url=db_url,
proxy_logging_obj=proxy_logging_obj,
http_client={
"timeout": 6000,
},
)
await new_client.db.connect()
await new_client.db.query_raw(sql_query)
verbose_proxy_logger.info("MonthlyGlobalSpend view refreshed")
return {
"message": "MonthlyGlobalSpend view refreshed",
"status": "success",
}
except Exception as e:
verbose_proxy_logger.exception(
"Failed to refresh materialized view - {}".format(str(e))
)
return {
"message": "Failed to refresh materialized view",
"status": "failure",
}
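A standalone sketch of what this endpoint does, assuming a Postgres backend with DATABASE_URL set and using the prisma Python client directly: confirm that "MonthlyGlobalSpend" is a materialized view (pg_class.relkind equals 'm') before issuing REFRESH MATERIALIZED VIEW. The function name is illustrative.

import asyncio

from prisma import Prisma


async def refresh_monthly_global_spend() -> bool:
    db = Prisma()
    await db.connect()
    try:
        rows = await db.query_raw(
            "SELECT relname, relkind FROM pg_class WHERE relname = 'MonthlyGlobalSpend';"
        )
        if not rows or rows[0]["relkind"] != "m":  # 'm' => materialized view
            return False  # plain view or missing - nothing to refresh
        await db.execute_raw(query='REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";')
        return True
    finally:
        await db.disconnect()


if __name__ == "__main__":
    print(asyncio.run(refresh_monthly_global_spend()))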
async def global_spend_for_internal_user( async def global_spend_for_internal_user(
api_key: Optional[str] = None, api_key: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),

View file

@ -92,6 +92,7 @@ def safe_deep_copy(data):
if litellm.safe_memory_mode is True: if litellm.safe_memory_mode is True:
return data return data
litellm_parent_otel_span: Optional[Any] = None
# Step 1: Remove the litellm_parent_otel_span # Step 1: Remove the litellm_parent_otel_span
litellm_parent_otel_span = None litellm_parent_otel_span = None
if isinstance(data, dict): if isinstance(data, dict):
@ -101,7 +102,7 @@ def safe_deep_copy(data):
new_data = copy.deepcopy(data) new_data = copy.deepcopy(data)
# Step 2: re-add the litellm_parent_otel_span after doing a deep copy # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
if isinstance(data, dict): if isinstance(data, dict) and litellm_parent_otel_span is not None:
if "metadata" in data: if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data return new_data
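A minimal sketch of the safe_deep_copy pattern above: pop the OTEL span (which is not safely deep-copyable) out of metadata, deepcopy the rest, then re-attach the original span object. The span value below is just a stand-in.

import copy
from typing import Any, Optional


def safe_deep_copy_sketch(data: Any) -> Any:
    litellm_parent_otel_span: Optional[Any] = None
    # Step 1: remove the span before copying
    if isinstance(data, dict) and "metadata" in data:
        litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span", None)
    new_data = copy.deepcopy(data)
    # Step 2: re-attach the original span object only if one was present
    if isinstance(data, dict) and litellm_parent_otel_span is not None:
        data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
    return new_data


payload = {"messages": [], "metadata": {"litellm_parent_otel_span": object()}}
copied = safe_deep_copy_sketch(payload)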
@ -468,7 +469,7 @@ class ProxyLogging:
# V1 implementation - backwards compatibility # V1 implementation - backwards compatibility
if callback.event_hook is None: if callback.event_hook is None:
if callback.moderation_check == "pre_call": if callback.moderation_check == "pre_call": # type: ignore
return return
else: else:
# Main - V2 Guardrails implementation # Main - V2 Guardrails implementation
@ -881,7 +882,12 @@ class PrismaClient:
org_list_transactons: dict = {} org_list_transactons: dict = {}
spend_log_transactions: List = [] spend_log_transactions: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging): def __init__(
self,
database_url: str,
proxy_logging_obj: ProxyLogging,
http_client: Optional[Any] = None,
):
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'" "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
) )
@ -912,7 +918,10 @@ class PrismaClient:
# Now you can import the Prisma Client # Now you can import the Prisma Client
from prisma import Prisma # type: ignore from prisma import Prisma # type: ignore
verbose_proxy_logger.debug("Connecting Prisma Client to DB..") verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
self.db = Prisma() # Client to connect to Prisma db if http_client is not None:
self.db = Prisma(http=http_client)
else:
self.db = Prisma() # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB") verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
def hash_token(self, token: str): def hash_token(self, token: str):
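Hedged usage of the new constructor parameter above: the http_client dict is forwarded to Prisma(http=...), which helps when long-running raw queries (such as the materialized-view refresh) need a larger engine timeout. Assumes a configured proxy with DATABASE_URL set; the timeout unit is taken to be milliseconds.

import os

from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient


async def make_long_timeout_client() -> PrismaClient:
    client = PrismaClient(
        database_url=os.environ["DATABASE_URL"],
        proxy_logging_obj=proxy_logging_obj,
        http_client={"timeout": 6000},  # forwarded as Prisma(http={"timeout": 6000})
    )
    await client.db.connect()
    return client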
@ -987,7 +996,7 @@ class PrismaClient:
return return
else: else:
## check if required view exists ## ## check if required view exists ##
if required_view not in ret[0]["view_names"]: if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db await self.health_check() # make sure we can connect to db
await self.db.execute_raw( await self.db.execute_raw(
""" """
@ -1009,7 +1018,9 @@ class PrismaClient:
else: else:
# don't block execution if these views are missing # don't block execution if these views are missing
# Convert lists to sets for efficient difference calculation # Convert lists to sets for efficient difference calculation
ret_view_names_set = set(ret[0]["view_names"]) ret_view_names_set = (
set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
)
expected_views_set = set(expected_views) expected_views_set = set(expected_views)
# Find missing views # Find missing views
missing_views = expected_views_set - ret_view_names_set missing_views = expected_views_set - ret_view_names_set
@ -1291,13 +1302,13 @@ class PrismaClient:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}" f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
) )
hashed_token: Optional[str] = None
try: try:
response: Any = None response: Any = None
if (token is not None and table_name is None) or ( if (token is not None and table_name is None) or (
table_name is not None and table_name == "key" table_name is not None and table_name == "key"
): ):
# check if plain text or hash # check if plain text or hash
hashed_token = None
if token is not None: if token is not None:
if isinstance(token, str): if isinstance(token, str):
hashed_token = token hashed_token = token
@ -1306,7 +1317,7 @@ class PrismaClient:
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}" f"PrismaClient: find_unique for token: {hashed_token}"
) )
if query_type == "find_unique": if query_type == "find_unique" and hashed_token is not None:
if token is None: if token is None:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@ -1706,7 +1717,7 @@ class PrismaClient:
updated_data = v updated_data = v
updated_data = json.dumps(updated_data) updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert( updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k}, where={"param_name": k}, # type: ignore
data={ data={
"create": {"param_name": k, "param_value": updated_data}, # type: ignore "create": {"param_name": k, "param_value": updated_data}, # type: ignore
"update": {"param_value": updated_data}, "update": {"param_value": updated_data},
@ -2302,7 +2313,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
return instance return instance
except ImportError as e: except ImportError as e:
# Re-raise the exception with a user-friendly message # Re-raise the exception with a user-friendly message
raise ImportError(f"Could not import {instance_name} from {module_name}") from e if instance_name and module_name:
raise ImportError(
f"Could not import {instance_name} from {module_name}"
) from e
else:
raise e
except Exception as e: except Exception as e:
raise e raise e
@ -2377,12 +2393,12 @@ async def send_email(receiver_email, subject, html):
try: try:
# Establish a secure connection with the SMTP server # Establish a secure connection with the SMTP server
with smtplib.SMTP(smtp_host, smtp_port) as server: with smtplib.SMTP(smtp_host, smtp_port) as server: # type: ignore
if os.getenv("SMTP_TLS", "True") != "False": if os.getenv("SMTP_TLS", "True") != "False":
server.starttls() server.starttls()
# Login to your email account # Login to your email account
server.login(smtp_username, smtp_password) server.login(smtp_username, smtp_password) # type: ignore
# Send the email # Send the email
server.send_message(email_message) server.send_message(email_message)
@ -2945,7 +2961,7 @@ async def update_spend(
if i >= n_retry_times: # If we've reached the maximum number of retries if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying # Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff await asyncio.sleep(2**i) # type: ignore
except Exception as e: except Exception as e:
import traceback import traceback

View file

@ -2110,7 +2110,6 @@ async def test_hf_completion_tgi_stream():
def test_openai_chat_completion_call(): def test_openai_chat_completion_call():
litellm.set_verbose = False litellm.set_verbose = False
litellm.return_response_headers = True litellm.return_response_headers = True
print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True) response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
assert isinstance( assert isinstance(
response._hidden_params["additional_headers"][ response._hidden_params["additional_headers"][
@ -2318,6 +2317,57 @@ def test_together_ai_completion_call_mistral():
pass pass
# # test on together ai completion call - starcoder
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_o1_completion_call_streaming(sync_mode):
try:
litellm.set_verbose = False
if sync_mode:
response = completion(
model="o1-preview",
messages=messages,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
has_finish_reason = False
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
if has_finish_reason is False:
raise Exception("Finish reason not set for last chunk")
if complete_response == "":
raise Exception("Empty response received")
else:
response = await acompletion(
model="o1-preview",
messages=messages,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
has_finish_reason = False
idx = 0
async for chunk in response:
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
idx += 1
if has_finish_reason is False:
raise Exception("Finish reason not set for last chunk")
if complete_response == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except Exception:
pytest.fail(f"error occurred: {traceback.format_exc()}")
def test_together_ai_completion_call_starcoder_bad_key(): def test_together_ai_completion_call_starcoder_bad_key():
try: try:
api_key = "bad-key" api_key = "bad-key"

View file

@ -71,7 +71,7 @@ class LakeraCategoryThresholds(TypedDict, total=False):
jailbreak: float jailbreak: float
class LitellmParams(TypedDict, total=False): class LitellmParams(TypedDict):
guardrail: str guardrail: str
mode: str mode: str
api_key: str api_key: str
@ -87,6 +87,7 @@ class LitellmParams(TypedDict, total=False):
# Presidio params # Presidio params
output_parse_pii: Optional[bool] output_parse_pii: Optional[bool]
presidio_ad_hoc_recognizers: Optional[str] presidio_ad_hoc_recognizers: Optional[str]
mock_redacted_text: Optional[dict]
class Guardrail(TypedDict): class Guardrail(TypedDict):
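For illustration, a litellm_params dict matching the TypedDict above, using the new mock_redacted_text field so the Presidio guardrail can be exercised without a live analyzer/anonymizer. The payload shape is assumed to mirror a Presidio anonymizer response.

litellm_params = {
    "guardrail": "presidio",
    "mode": "pre_call",
    "output_parse_pii": True,
    "mock_redacted_text": {
        "text": "My name is <PERSON>",
        "items": [
            {
                "start": 11,
                "end": 19,
                "entity_type": "PERSON",
                "text": "<PERSON>",
                "operator": "replace",
            }
        ],
    },
}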

View file

@ -120,11 +120,26 @@ with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json")
# Convert to str (if necessary) # Convert to str (if necessary)
claude_json_str = json.dumps(json_data) claude_json_str = json.dumps(json_data)
import importlib.metadata import importlib.metadata
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
)
from openai import OpenAIError as OriginalError from openai import OpenAIError as OriginalError
from ._logging import verbose_logger from ._logging import verbose_logger
from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .exceptions import ( from .exceptions import (
APIConnectionError, APIConnectionError,
APIError, APIError,
@ -150,31 +165,6 @@ from .types.llms.openai import (
) )
from .types.router import LiteLLM_Params from .types.router import LiteLLM_Params
try:
from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
GenericAPILogger,
)
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
)
from .caching import Cache
####### ENVIRONMENT VARIABLES #################### ####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities. # Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100 MAX_THREADS = 100

View file

@ -0,0 +1 @@
More tests under `litellm/litellm/tests/*`.

View file

@ -0,0 +1,502 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import litellm
from litellm.exceptions import BadRequestError, InternalServerError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
def mock_chat_response() -> Dict[str, Any]:
return {
"id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
"object": "chat.completion",
"created": 1726285449,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 38,
"total_tokens": 268,
"completion_tokens_details": None,
},
"system_fingerprint": None,
}
def mock_chat_streaming_response_chunks() -> List[str]:
return [
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "Hello"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"content": " world"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"content": "!"},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
]
def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
string_chunks = mock_chat_streaming_response_chunks()
bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
# Simulate the end of the stream
bytes_chunks.append(b"")
return bytes_chunks
def mock_http_handler_chat_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks()
def mock_iter_lines():
for chunk in mock_stream_chunks:
for line in chunk.splitlines():
yield line
mock_response = MagicMock()
mock_response.iter_lines.side_effect = mock_iter_lines
mock_response.status_code = 200
return mock_response
def mock_http_handler_chat_async_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks()
async def mock_iter_lines():
for chunk in mock_stream_chunks:
for line in chunk.splitlines():
yield line
mock_response = MagicMock()
mock_response.aiter_lines.return_value = mock_iter_lines()
mock_response.status_code = 200
return mock_response
def mock_databricks_client_chat_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()
def mock_read_from_stream(size=-1):
if mock_stream_chunks:
return mock_stream_chunks.pop(0)
return b""
mock_response = MagicMock()
streaming_response_mock = MagicMock()
streaming_response_iterator_mock = MagicMock()
# Mock the __getitem__("content") method to return the streaming response
mock_response.__getitem__.return_value = streaming_response_mock
# Mock the streaming response __enter__ method to return the streaming response iterator
streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock
streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
streaming_response_iterator_mock.closed = False
return mock_response
def mock_embedding_response() -> Dict[str, Any]:
return {
"object": "list",
"model": "bge-large-en-v1.5",
"data": [
{
"index": 0,
"object": "embedding",
"embedding": [
0.06768798828125,
-0.01291656494140625,
-0.0501708984375,
0.0245361328125,
-0.030364990234375,
],
}
],
"usage": {
"prompt_tokens": 8,
"total_tokens": 8,
"completion_tokens": 0,
"completion_tokens_details": None,
},
}
@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
if set_base:
monkeypatch.setenv(
"DATABRICKS_API_BASE",
"https://my.workspace.cloud.databricks.com/serving-endpoints",
)
monkeypatch.delenv(
"DATABRICKS_API_KEY",
)
err_msg = "A call is being made to LLM Provider but no key is set"
else:
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
monkeypatch.delenv("DATABRICKS_API_BASE")
err_msg = "A call is being made to LLM Provider but no api base is set"
with pytest.raises(BadRequestError) as exc:
litellm.completion(
model="databricks/dbrx-instruct-071224",
messages={"role": "user", "content": "How are you?"},
)
assert err_msg in str(exc)
with pytest.raises(BadRequestError) as exc:
litellm.embedding(
model="databricks/bge-12312",
input=["Hello", "World"],
)
assert err_msg in str(exc)
def test_completions_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()
expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}
messages = [{"role": "user", "content": "How are you?"}]
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
assert response.to_dict() == expected_response_json
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)
def test_completions_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
async_handler = AsyncHTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()
expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}
messages = [{"role": "user", "content": "How are you?"}]
with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response = asyncio.run(
litellm.acompletion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=async_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
)
assert response.to_dict() == expected_response_json
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)
def test_completions_streaming_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
messages = [{"role": "user", "content": "How are you?"}]
mock_response = mock_http_handler_chat_streaming_response()
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response_stream: CustomStreamWrapper = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
stream=True,
)
response = list(response_stream)
assert "dbrx-instruct-071224" in str(response)
assert "chatcmpl" in str(response)
assert len(response) == 4
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
stream=True,
)
def test_completions_streaming_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
async_handler = AsyncHTTPHandler()
messages = [{"role": "user", "content": "How are you?"}]
mock_response = mock_http_handler_chat_async_streaming_response()
with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response_stream: CustomStreamWrapper = asyncio.run(
litellm.acompletion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=async_handler,
temperature=0.5,
extraparam="testpassingextraparam",
stream=True,
)
)
# Use async list gathering for the response
async def gather_responses():
return [item async for item in response_stream]
response = asyncio.run(gather_responses())
assert "dbrx-instruct-071224" in str(response)
assert "chatcmpl" in str(response)
assert len(response) == 4
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
stream=True,
)
def test_embeddings_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()
inputs = ["Hello", "World"]
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.embedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=sync_handler,
extraparam="testpassingextraparam",
)
assert response.to_dict() == mock_embedding_response()
mock_post.assert_called_once_with(
f"{base_url}/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)
def test_embeddings_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
async_handler = AsyncHTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()
inputs = ["Hello", "World"]
with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response = asyncio.run(
litellm.aembedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=async_handler,
extraparam="testpassingextraparam",
)
)
assert response.to_dict() == mock_embedding_response()
mock_post.assert_called_once_with(
f"{base_url}/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)