LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723) (#5731)

* LiteLLM Minor Fixes & Improvements (09/16/2024)  (#5723)

* coverage (#5713)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* Move (#5714)

Signed-off-by: dbczumar <corey.zumar@databricks.com>

* fix(litellm_logging.py): fix logging client re-init (#5710)

Fixes https://github.com/BerriAI/litellm/issues/5695

* fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config

Fixes https://github.com/BerriAI/litellm/issues/5682

* feat(o1_handler.py): fake streaming for openai o1 models

Fixes https://github.com/BerriAI/litellm/issues/5694
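
Roughly, this lets callers keep using `stream=True` with o1 models; the handler internally wraps the full response in a mock chunk iterator. A minimal usage sketch (model name and messages are illustrative):

```python
import litellm

# Illustrative only: o1 models don't stream natively, so with stream=True
# the o1 handler fakes streaming from the complete response.
response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "what llm are you"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```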

* docs: deprecated traceloop integration in favor of native otel (#5249)

* fix: fix linting errors

* fix: fix linting errors

* fix(main.py): fix o1 import

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating materialized view (#5730)

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating materialized view

Supports having the `MonthlyGlobalSpend` view be a materialized view, and exposes an endpoint to refresh it
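
For example, the new admin-only endpoint can be called against a running proxy roughly like this (base URL and key below are placeholders):

```python
import httpx

# Illustrative sketch only: proxy URL and master key are placeholders.
resp = httpx.post(
    "http://0.0.0.0:4000/global/spend/refresh",
    headers={"Authorization": "Bearer sk-1234"},
)
print(resp.json())  # e.g. {"message": "MonthlyGlobalSpend view refreshed", "status": "success"}
```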

* fix(custom_logger.py): reset calltype

* fix: fix linting errors

* fix: fix linting error

* fix: fix import

* test(test_databricks.py): fix databricks tests

---------

Signed-off-by: dbczumar <corey.zumar@databricks.com>
Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
Krish Dholakia 2024-09-17 08:05:52 -07:00 committed by GitHub
parent 1e59395280
commit 234185ec13
34 changed files with 1387 additions and 502 deletions

View file

@ -6,10 +6,14 @@ import asyncio
import os
# Enter your DATABASE_URL here
os.environ["DATABASE_URL"] = "postgresql://xxxxxxx"
from prisma import Prisma
db = Prisma()
db = Prisma(
http={
"timeout": 60000,
},
)
async def check_view_exists():
@ -47,22 +51,19 @@ async def check_view_exists():
print("LiteLLM_VerificationTokenView Created!") # noqa
try:
await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
print("MonthlyGlobalSpend Exists!") # noqa
except Exception as e:
sql_query = """
CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
SELECT
DATE("startTime") AS date,
DATE_TRUNC('day', "startTime") AS date,
SUM("spend") AS spend
FROM
"LiteLLM_SpendLogs"
WHERE
"startTime" >= (CURRENT_DATE - INTERVAL '30 days')
"startTime" >= CURRENT_DATE - INTERVAL '30 days'
GROUP BY
DATE("startTime");
DATE_TRUNC('day', "startTime");
"""
# Execute the queries
await db.execute_raw(query=sql_query)
print("MonthlyGlobalSpend Created!") # noqa

View file

@ -0,0 +1,78 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';
# OpenTelemetry - Tracing LLMs with any observability tool
OpenTelemetry is a CNCF standard for observability. It connects to any observability tool, such as Jaeger, Zipkin, Datadog, New Relic, Traceloop and others.
<Image img={require('../../img/traceloop_dash.png')} />
## Getting Started
Install the OpenTelemetry SDK:
```
pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp
```
Set the environment variables (different providers may require different variables):
<Tabs>
<TabItem value="traceloop" label="Log to Traceloop Cloud">
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="https://api.traceloop.com"
OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```
</TabItem>
<TabItem value="otel-col" label="Log to OTEL HTTP Collector">
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="http:/0.0.0.0:4317"
```
</TabItem>
<TabItem value="otel-col-grpc" label="Log to OTEL GRPC Collector">
```shell
OTEL_EXPORTER="otlp_grpc"
OTEL_ENDPOINT="http:/0.0.0.0:4317"
```
</TabItem>
</Tabs>
Then, with a single line of code, instantly log your LLM responses **across all providers** with OpenTelemetry:
```python
litellm.callbacks = ["otel"]
```
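Putting it together, a minimal end-to-end sketch (endpoint and model values are placeholders) might look like:

```python
import os
import litellm

# Point the OTEL exporter at your collector (placeholder values).
os.environ["OTEL_EXPORTER"] = "otlp_http"
os.environ["OTEL_ENDPOINT"] = "http://0.0.0.0:4317"

# Enable the native OpenTelemetry callback.
litellm.callbacks = ["otel"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```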
## Redacting Messages, Response Content from OpenTelemetry Logging
### Redact Messages and Responses from all OpenTelemetry Logging
Set `litellm.turn_off_message_logging=True`. This prevents messages and responses from being logged to OpenTelemetry; request metadata is still logged.
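For example (a minimal sketch):

```python
import litellm

# Redact message and response content from all OpenTelemetry logs;
# request metadata is still emitted.
litellm.turn_off_message_logging = True
litellm.callbacks = ["otel"]
```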
### Redact Messages and Responses from specific OpenTelemetry Logging
In the metadata passed with a text completion or embedding call, you can set specific keys to mask the messages and responses for that call.
Setting `mask_input` to `True` will mask the input from being logged for this call.
Setting `mask_output` to `True` will mask the output from being logged for this call.
Be aware that if you are continuing an existing trace and you set `update_trace_keys` to include `input` or `output` while also setting the corresponding `mask_input` or `mask_output`, that trace will have its existing input and/or output replaced with a redacted message.
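A per-call sketch, assuming the masking keys are passed via the call's `metadata`:

```python
import litellm

# Mask only this call's input and output in the emitted trace.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "a prompt containing sensitive data"}],
    metadata={"mask_input": True, "mask_output": True},
)
```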
## Support
For any questions or issues with the integration, you can reach out to the OpenLLMetry maintainers on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).

View file

@ -1,36 +0,0 @@
import Image from '@theme/IdealImage';
# Traceloop (OpenLLMetry) - Tracing LLMs with OpenTelemetry
[Traceloop](https://traceloop.com) is a platform for monitoring and debugging the quality of your LLM outputs.
It provides you with a way to track the performance of your LLM application; rollout changes with confidence; and debug issues in production.
It is based on [OpenTelemetry](https://opentelemetry.io), so it can provide full visibility into your LLM requests, as well as vector DB usage and other infra in your stack.
<Image img={require('../../img/traceloop_dash.png')} />
## Getting Started
Install the Traceloop SDK:
```
pip install traceloop-sdk
```
Use just 2 lines of code to instantly log your LLM responses with OpenTelemetry:
```python
Traceloop.init(app_name=<YOUR APP NAME>, disable_batch=True)
litellm.success_callback = ["traceloop"]
```
Make sure to properly set a destination to your traces. See [OpenLLMetry docs](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for options.
To get better visualizations on how your code behaves, you may want to annotate specific parts of your LLM chain. See [Traceloop docs on decorators](https://traceloop.com/docs/python-sdk/decorators) for more information.
## Exporting traces to other systems (e.g. Datadog, New Relic, and others)
Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.
## Support
For any questions or issues with the integration, you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).

View file

@ -600,6 +600,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</TabItem>
<TabItem value="traceloop" label="Log to Traceloop Cloud">
#### Quick Start - Log to Traceloop
**Step 1:**
Add the following to your env
```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="https://api.traceloop.com"
OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```
**Step 2:** Add `otel` as a callback
```shell
litellm_settings:
callbacks: ["otel"]
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --detailed_debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
</TabItem>
<TabItem value="otel-col" label="Log to OTEL HTTP Collector">
#### Quick Start - Log to OTEL Collector
@ -694,52 +740,6 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
</TabItem>
<TabItem value="traceloop" label="Log to Traceloop Cloud">
#### Quick Start - Log to Traceloop
**Step 1:** Install the `traceloop-sdk` SDK
```shell
pip install traceloop-sdk==0.21.2
```
**Step 2:** Add `traceloop` as a success_callback
```shell
litellm_settings:
success_callback: ["traceloop"]
environment_variables:
TRACELOOP_API_KEY: "XXXXX"
```
**Step 3**: Start the proxy, make a test request
Start proxy
```shell
litellm --config config.yaml --detailed_debug
```
Test Request
```shell
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}'
```
</TabItem>
</Tabs>
**🎉 Expect to see this trace logged in your OTEL collector**

View file

@ -255,6 +255,7 @@ const sidebars = {
type: "category",
label: "Logging & Observability",
items: [
"observability/opentelemetry_integration",
"observability/langfuse_integration",
"observability/logfire_integration",
"observability/gcs_bucket_integration",
@ -271,7 +272,7 @@ const sidebars = {
"observability/openmeter",
"observability/promptlayer_integration",
"observability/wandb_integration",
"observability/traceloop_integration",
"observability/slack_integration",
"observability/athina_integration",
"observability/lunary_integration",
"observability/greenscale_integration",

View file

@ -948,10 +948,10 @@ from .llms.OpenAI.openai import (
AzureAIStudioConfig,
)
from .llms.mistral.mistral_chat_transformation import MistralConfig
from .llms.OpenAI.o1_transformation import (
from .llms.OpenAI.chat.o1_transformation import (
OpenAIO1Config,
)
from .llms.OpenAI.gpt_transformation import (
from .llms.OpenAI.chat.gpt_transformation import (
OpenAIGPTConfig,
)
from .llms.nvidia_nim import NvidiaNimConfig

View file

@ -1616,6 +1616,9 @@ Model Info:
:param time_range: A string specifying the time range, e.g., "1d", "7d", "30d"
"""
if self.alerting is None or "spend_reports" not in self.alert_types:
return
try:
from litellm.proxy.spend_tracking.spend_management_endpoints import (
_get_spend_report_for_time_range,

View file

@ -11,7 +11,7 @@ class CustomGuardrail(CustomLogger):
self,
guardrail_name: Optional[str] = None,
event_hook: Optional[GuardrailEventHooks] = None,
**kwargs
**kwargs,
):
self.guardrail_name = guardrail_name
self.event_hook: Optional[GuardrailEventHooks] = event_hook
@ -28,10 +28,13 @@ class CustomGuardrail(CustomLogger):
requested_guardrails,
)
if self.guardrail_name not in requested_guardrails:
if (
self.guardrail_name not in requested_guardrails
and event_type.value != "logging_only"
):
return False
if self.event_hook != event_type:
if self.event_hook != event_type.value:
return False
return True

View file

@ -4,6 +4,11 @@ import litellm
class TraceloopLogger:
"""
WARNING: DEPRECATED
Use the OpenTelemetry standard integration instead
"""
def __init__(self):
try:
from traceloop.sdk.tracing.tracing import TracerWrapper

View file

@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
try:
from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
GenericAPILogger,
)
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
_in_memory_loggers: List[Any] = []
### GLOBAL VARIABLES ###
@ -145,7 +152,41 @@ class ServiceTraceIDCache:
return None
import hashlib
class DynamicLoggingCache:
"""
Prevent memory leaks caused by initializing new logging clients on each request.
Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
"""
def __init__(self) -> None:
self.cache = InMemoryCache()
def get_cache_key(self, args: dict) -> str:
args_str = json.dumps(args, sort_keys=True)
cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
return cache_key
def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
key_name = self.get_cache_key(
args={**credentials, "service_name": service_name}
)
response = self.cache.get_cache(key=key_name)
return response
def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
key_name = self.get_cache_key(
args={**credentials, "service_name": service_name}
)
self.cache.set_cache(key=key_name, value=logging_obj)
return None
in_memory_trace_id_cache = ServiceTraceIDCache()
in_memory_dynamic_logger_cache = DynamicLoggingCache()
class Logging:
@ -324,10 +365,10 @@ class Logging:
print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
# log raw request to provider (like LangFuse) -- if opted in.
if log_raw_request_response is True:
try:
# [Non-blocking Extra Debug Information in metadata]
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
try:
# [Non-blocking Extra Debug Information in metadata]
if (
turn_off_message_logging is not None
and turn_off_message_logging is True
@ -362,7 +403,7 @@ class Logging:
callbacks = litellm.input_callback + self.dynamic_input_callbacks
for callback in callbacks:
try:
if callback == "supabase":
if callback == "supabase" and supabaseClient is not None:
verbose_logger.debug("reaches supabase for logging!")
model = self.model_call_details["model"]
messages = self.model_call_details["input"]
@ -396,7 +437,9 @@ class Logging:
messages=self.messages,
kwargs=self.model_call_details,
)
elif callable(callback): # custom logger functions
elif (
callable(callback) and customLogger is not None
): # custom logger functions
customLogger.log_input_event(
model=self.model,
messages=self.messages,
@ -615,7 +658,7 @@ class Logging:
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
] = result._hidden_params
] = getattr(result, "_hidden_params", {})
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
@ -645,6 +688,7 @@ class Logging:
litellm.max_budget
and self.stream is False
and result is not None
and isinstance(result, dict)
and "content" in result
):
time_diff = (end_time - start_time).total_seconds()
@ -652,7 +696,7 @@ class Logging:
litellm._current_cost += litellm.completion_cost(
model=self.model,
prompt="",
completion=result["content"],
completion=getattr(result, "content", ""),
total_time=float_diff,
)
@ -758,7 +802,7 @@ class Logging:
):
print_verbose("no-log request, skipping logging")
continue
if callback == "lite_debugger":
if callback == "lite_debugger" and liteDebuggerClient is not None:
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
print_verbose(
@ -774,7 +818,7 @@ class Logging:
call_type=self.call_type,
stream=self.stream,
)
if callback == "promptlayer":
if callback == "promptlayer" and promptLayerLogger is not None:
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
kwargs=self.model_call_details,
@ -783,7 +827,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if callback == "supabase":
if callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
kwargs = self.model_call_details
@ -811,7 +855,7 @@ class Logging:
),
print_verbose=print_verbose,
)
if callback == "wandb":
if callback == "wandb" and weightsBiasesLogger is not None:
print_verbose("reaches wandb for logging!")
weightsBiasesLogger.log_event(
kwargs=self.model_call_details,
@ -820,8 +864,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if callback == "logfire":
global logfireLogger
if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for success logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@ -844,10 +887,10 @@ class Logging:
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
level=LogfireLevel.INFO.value,
level=LogfireLevel.INFO.value, # type: ignore
)
if callback == "lunary":
if callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging!")
model = self.model
kwargs = self.model_call_details
@ -882,7 +925,7 @@ class Logging:
run_id=self.litellm_call_id,
print_verbose=print_verbose,
)
if callback == "helicone":
if callback == "helicone" and heliconeLogger is not None:
print_verbose("reaches helicone for logging!")
model = self.model
messages = self.model_call_details["input"]
@ -924,6 +967,7 @@ class Logging:
else:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
temp_langfuse_logger = langFuseLogger
if langFuseLogger is None or (
(
@ -941,11 +985,29 @@ class Logging:
and self.langfuse_host != langFuseLogger.langfuse_host
)
):
credentials = {
"langfuse_public_key": self.langfuse_public_key,
"langfuse_secret": self.langfuse_secret,
"langfuse_host": self.langfuse_host,
}
temp_langfuse_logger = (
in_memory_dynamic_logger_cache.get_cache(
credentials=credentials, service_name="langfuse"
)
)
if temp_langfuse_logger is None:
temp_langfuse_logger = LangFuseLogger(
langfuse_public_key=self.langfuse_public_key,
langfuse_secret=self.langfuse_secret,
langfuse_host=self.langfuse_host,
)
in_memory_dynamic_logger_cache.set_cache(
credentials=credentials,
service_name="langfuse",
logging_obj=temp_langfuse_logger,
)
if temp_langfuse_logger is not None:
_response = temp_langfuse_logger.log_event(
kwargs=kwargs,
response_obj=result,
@ -982,7 +1044,7 @@ class Logging:
print_verbose("reaches langfuse for streaming logging!")
result = kwargs["complete_streaming_response"]
if genericAPILogger is None:
genericAPILogger = GenericAPILogger()
genericAPILogger = GenericAPILogger() # type: ignore
genericAPILogger.log_event(
kwargs=kwargs,
response_obj=result,
@ -1022,7 +1084,7 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "greenscale":
if callback == "greenscale" and greenscaleLogger is not None:
kwargs = {}
for k, v in self.model_call_details.items():
if (
@ -1066,7 +1128,7 @@ class Logging:
result = kwargs["complete_streaming_response"]
# only add to cache once we have a complete streaming response
litellm.cache.add_cache(result, **kwargs)
if callback == "athina":
if callback == "athina" and athinaLogger is not None:
deep_copy = {}
for k, v in self.model_call_details.items():
deep_copy[k] = v
@ -1224,6 +1286,7 @@ class Logging:
"atranscription", False
)
is not True
and customLogger is not None
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
@ -1423,9 +1486,8 @@ class Logging:
await litellm.cache.async_add_cache(result, **kwargs)
else:
litellm.cache.add_cache(result, **kwargs)
if callback == "openmeter":
global openMeterLogger
if self.stream == True:
if callback == "openmeter" and openMeterLogger is not None:
if self.stream is True:
if (
"async_complete_streaming_response"
in self.model_call_details
@ -1645,33 +1707,9 @@ class Logging:
)
for callback in callbacks:
try:
if callback == "lite_debugger":
print_verbose("reaches lite_debugger for logging!")
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
result = {
"model": self.model,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens": prompt_token_calculator(
self.model, messages=self.messages
),
"completion_tokens": 0,
},
}
liteDebuggerClient.log_event(
model=self.model,
messages=self.messages,
end_user=self.model_call_details.get("user", "default"),
response_obj=result,
start_time=start_time,
end_time=end_time,
litellm_call_id=self.litellm_call_id,
print_verbose=print_verbose,
call_type=self.call_type,
stream=self.stream,
)
if callback == "lunary":
if callback == "lite_debugger" and liteDebuggerClient is not None:
pass
elif callback == "lunary" and lunaryLogger is not None:
print_verbose("reaches lunary for logging error!")
model = self.model
@ -1685,6 +1723,7 @@ class Logging:
)
lunaryLogger.log_event(
kwargs=self.model_call_details,
type=_type,
event="error",
user_id=self.model_call_details.get("user", "default"),
@ -1704,22 +1743,11 @@ class Logging:
print_verbose(
f"capture exception not initialized: {capture_exception}"
)
elif callback == "supabase":
elif callback == "supabase" and supabaseClient is not None:
print_verbose("reaches supabase for logging!")
print_verbose(f"supabaseClient: {supabaseClient}")
result = {
"model": model,
"created": time.time(),
"error": traceback_exception,
"usage": {
"prompt_tokens": prompt_token_calculator(
model, messages=self.messages
),
"completion_tokens": 0,
},
}
supabaseClient.log_event(
model=self.model,
model=self.model if hasattr(self, "model") else "",
messages=self.messages,
end_user=self.model_call_details.get("user", "default"),
response_obj=result,
@ -1728,7 +1756,9 @@ class Logging:
litellm_call_id=self.model_call_details["litellm_call_id"],
print_verbose=print_verbose,
)
if callable(callback): # custom logger functions
if (
callable(callback) and customLogger is not None
): # custom logger functions
customLogger.log_event(
kwargs=self.model_call_details,
response_obj=result,
@ -1809,13 +1839,13 @@ class Logging:
start_time=start_time,
end_time=end_time,
response_obj=None,
user_id=kwargs.get("user", None),
user_id=self.model_call_details.get("user", None),
print_verbose=print_verbose,
status_message=str(exception),
level="ERROR",
kwargs=self.model_call_details,
)
if callback == "logfire":
if callback == "logfire" and logfireLogger is not None:
verbose_logger.debug("reaches logfire for failure logging!")
kwargs = {}
for k, v in self.model_call_details.items():
@ -1830,7 +1860,7 @@ class Logging:
response_obj=result,
start_time=start_time,
end_time=end_time,
level=LogfireLevel.ERROR.value,
level=LogfireLevel.ERROR.value, # type: ignore
print_verbose=print_verbose,
)
@ -1873,7 +1903,9 @@ class Logging:
start_time=start_time,
end_time=end_time,
) # type: ignore
if callable(callback): # custom logger functions
if (
callable(callback) and customLogger is not None
): # custom logger functions
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=result,
@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
)
sentry_sdk_instance.init(
dsn=os.environ.get("SENTRY_DSN"),
traces_sample_rate=float(sentry_trace_rate),
traces_sample_rate=float(sentry_trace_rate), # type: ignore
)
capture_exception = sentry_sdk_instance.capture_exception
add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(
saved_cache_cost: Optional[float] = None
if cache_hit is True:
import time
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
saved_cache_cost = logging_obj._response_cost_calculator(
result=init_response_obj, cache_hit=False
result=init_response_obj, cache_hit=False # type: ignore
)
## Get model cost information ##
@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
model_id=_model_id,
requester_ip_address=clean_metadata.get("requester_ip_address", None),
messages=kwargs.get("messages"),
response=(
response=( # type: ignore
response_obj if len(response_obj.keys()) > 0 else init_response_obj
),
model_parameters=kwargs.get("optional_params", None),

View file

@ -0,0 +1,95 @@
"""
Handler file for calls to OpenAI's o1 family of models
Written separately to handle faking streaming for o1 models.
"""
import asyncio
from typing import Any, Callable, List, Optional, Union
from httpx._config import Timeout
from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
class OpenAIO1ChatCompletion(OpenAIChatCompletion):
async def mock_async_streaming(
self,
response: Any,
model: Optional[str],
logging_obj: Any,
):
model_response = await response
completion_stream = MockResponseIterator(model_response=model_response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
def completion(
self,
model_response: ModelResponse,
timeout: Union[float, Timeout],
optional_params: dict,
logging_obj: Any,
model: Optional[str] = None,
messages: Optional[list] = None,
print_verbose: Optional[Callable[..., Any]] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
acompletion: bool = False,
litellm_params=None,
logger_fn=None,
headers: Optional[dict] = None,
custom_prompt_dict: dict = {},
client=None,
organization: Optional[str] = None,
custom_llm_provider: Optional[str] = None,
drop_params: Optional[bool] = None,
):
stream: Optional[bool] = optional_params.pop("stream", False)
response = super().completion(
model_response,
timeout,
optional_params,
logging_obj,
model,
messages,
print_verbose,
api_key,
api_base,
acompletion,
litellm_params,
logger_fn,
headers,
custom_prompt_dict,
client,
organization,
custom_llm_provider,
drop_params,
)
if stream is True:
if asyncio.iscoroutine(response):
return self.mock_async_streaming(
response=response, model=model, logging_obj=logging_obj # type: ignore
)
completion_stream = MockResponseIterator(model_response=response)
streaming_response = CustomStreamWrapper(
completion_stream=completion_stream,
model=model,
custom_llm_provider="openai",
logging_obj=logging_obj,
)
return streaming_response
else:
return response

View file

@ -15,6 +15,8 @@ import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.databricks.exceptions import DatabricksError
from litellm.llms.databricks.streaming_utils import ModelResponseIterator
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
@ -33,17 +35,6 @@ from ..base import BaseLLM
from ..prompt_templates.factory import custom_prompt, prompt_factory
class DatabricksError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class DatabricksConfig:
"""
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
@ -367,7 +358,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code,
message=e.response.text,
)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise DatabricksError(status_code=500, message=str(e))
@ -380,7 +371,7 @@ class DatabricksChatCompletion(BaseLLM):
)
response = ModelResponse(**response_json)
response.model = custom_llm_provider + "/" + response.model
response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
@ -529,7 +520,7 @@ class DatabricksChatCompletion(BaseLLM):
response_json = response.json()
except httpx.HTTPStatusError as e:
raise DatabricksError(
status_code=e.response.status_code, message=response.text
status_code=e.response.status_code, message=e.response.text
)
except httpx.TimeoutException as e:
raise DatabricksError(
@ -540,7 +531,7 @@ class DatabricksChatCompletion(BaseLLM):
response = ModelResponse(**response_json)
response.model = custom_llm_provider + "/" + response.model
response.model = custom_llm_provider + "/" + (response.model or "")
if base_model is not None:
response._hidden_params["model"] = base_model
@ -657,7 +648,7 @@ class DatabricksChatCompletion(BaseLLM):
except httpx.HTTPStatusError as e:
raise DatabricksError(
status_code=e.response.status_code,
message=response.text if response else str(e),
message=e.response.text,
)
except httpx.TimeoutException as e:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
@ -673,136 +664,3 @@ class DatabricksChatCompletion(BaseLLM):
)
return litellm.EmbeddingResponse(**response_json)
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
if processed_chunk.choices[0].delta.content is not None: # type: ignore
text = processed_chunk.choices[0].delta.content # type: ignore
if (
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
is not None
):
tool_use = ChatCompletionToolCallChunk(
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.name,
arguments=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.arguments,
),
index=processed_chunk.choices[0].index,
)
if processed_chunk.choices[0].finish_reason is not None:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
if hasattr(processed_chunk, "usage") and isinstance(
processed_chunk.usage, litellm.Usage
):
usage_chunk: litellm.Usage = processed_chunk.usage
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,
total_tokens=usage_chunk.total_tokens,
)
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if chunk == "[DONE]":
raise StopAsyncIteration
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View file

@ -0,0 +1,12 @@
import httpx
class DatabricksError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs

View file

@ -0,0 +1,145 @@
import json
from typing import Optional
import litellm
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import GenericStreamingChunk
class ModelResponseIterator:
def __init__(self, streaming_response, sync_stream: bool):
self.streaming_response = streaming_response
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
try:
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
if processed_chunk.choices[0].delta.content is not None: # type: ignore
text = processed_chunk.choices[0].delta.content # type: ignore
if (
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
is not None
):
tool_use = ChatCompletionToolCallChunk(
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.name,
arguments=processed_chunk.choices[0]
.delta.tool_calls[0] # type: ignore
.function.arguments,
),
index=processed_chunk.choices[0].index,
)
if processed_chunk.choices[0].finish_reason is not None:
is_finished = True
finish_reason = processed_chunk.choices[0].finish_reason
if hasattr(processed_chunk, "usage") and isinstance(
processed_chunk.usage, litellm.Usage
):
usage_chunk: litellm.Usage = processed_chunk.usage
usage = ChatCompletionUsageBlock(
prompt_tokens=usage_chunk.prompt_tokens,
completion_tokens=usage_chunk.completion_tokens,
total_tokens=usage_chunk.total_tokens,
)
return GenericStreamingChunk(
text=text,
tool_use=tool_use,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
)
except json.JSONDecodeError:
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
# Sync iterator
def __iter__(self):
self.response_iterator = self.streaming_response
return self
def __next__(self):
try:
chunk = self.response_iterator.__next__()
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopIteration:
raise StopIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
# Async iterator
def __aiter__(self):
self.async_response_iterator = self.streaming_response.__aiter__()
return self
async def __anext__(self):
try:
chunk = await self.async_response_iterator.__anext__()
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error receiving chunk from stream: {e}")
try:
chunk = chunk.replace("data:", "")
chunk = chunk.strip()
if chunk == "[DONE]":
raise StopAsyncIteration
if len(chunk) > 0:
json_chunk = json.loads(chunk)
return self.chunk_parser(chunk=json_chunk)
else:
return GenericStreamingChunk(
text="",
is_finished=False,
finish_reason="",
usage=None,
index=0,
tool_use=None,
)
except StopAsyncIteration:
raise StopAsyncIteration
except ValueError as e:
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")

View file

@ -95,6 +95,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
from .llms.databricks.chat import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.prompt_templates.factory import (
@ -161,6 +162,7 @@ from litellm.utils import (
####### ENVIRONMENT VARIABLES ###################
openai_chat_completions = OpenAIChatCompletion()
openai_text_completions = OpenAITextCompletion()
openai_o1_chat_completions = OpenAIO1ChatCompletion()
openai_audio_transcriptions = OpenAIAudioTranscription()
databricks_chat_completions = DatabricksChatCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
@ -1366,6 +1368,27 @@ def completion(
## COMPLETION CALL
try:
if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
response = openai_o1_chat_completions.completion(
model=model,
messages=messages,
headers=headers,
model_response=model_response,
print_verbose=print_verbose,
api_key=api_key,
api_base=api_base,
acompletion=acompletion,
logging_obj=logging,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
timeout=timeout, # type: ignore
custom_prompt_dict=custom_prompt_dict,
client=client, # pass AsyncOpenAI, OpenAI client
organization=organization,
custom_llm_provider=custom_llm_provider,
)
else:
response = openai_chat_completions.completion(
model=model,
messages=messages,

View file

@ -3,7 +3,6 @@ model_list:
litellm_params:
model: anthropic.claude-3-sonnet-20240229-v1:0
api_base: https://exampleopenaiendpoint-production.up.railway.app
# aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
- model_name: gemini-vision
@ -17,4 +16,34 @@ model_list:
litellm_params:
model: gpt-3.5-turbo
api_base: https://exampleopenaiendpoint-production.up.railway.app
- model_name: o1-preview
litellm_params:
model: o1-preview
litellm_settings:
drop_params: True
json_logs: True
store_audit_logs: True
log_raw_request_response: True
return_response_headers: True
num_retries: 5
request_timeout: 200
callbacks: ["custom_callbacks.proxy_handler_instance"]
guardrails:
- guardrail_name: "presidio-pre-guard"
litellm_params:
guardrail: presidio # supported values: "aporia", "bedrock", "lakera", "presidio"
mode: "logging_only"
mock_redacted_text: {
"text": "My name is <PERSON>, who are you? Say my name in your response",
"items": [
{
"start": 11,
"end": 19,
"entity_type": "PERSON",
"text": "<PERSON>",
"operator": "replace",
}
],
}

View file

@ -446,7 +446,6 @@ async def user_api_key_auth(
and request.headers.get(key=header_key) is not None # type: ignore
):
api_key = request.headers.get(key=header_key) # type: ignore
if master_key is None:
if isinstance(api_key, str):
return UserAPIKeyAuth(

View file

@ -2,9 +2,19 @@ from litellm.integrations.custom_logger import CustomLogger
class MyCustomHandler(CustomLogger):
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
# print("Call failed")
pass
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
# input_tokens = response_obj.get("usage", {}).get("prompt_tokens", 0)
# output_tokens = response_obj.get("usage", {}).get("completion_tokens", 0)
input_tokens = (
response_obj.usage.prompt_tokens
if hasattr(response_obj.usage, "prompt_tokens")
else 0
)
output_tokens = (
response_obj.usage.completion_tokens
if hasattr(response_obj.usage, "completion_tokens")
else 0
)
proxy_handler_instance = MyCustomHandler()

View file

@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
response = await self.async_handler.post(
url=prepared_request.url,
json=request_data, # type: ignore
headers=dict(prepared_request.headers),
headers=prepared_request.headers, # type: ignore
)
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
if response.status_code == 200:
@ -254,7 +254,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
if self.should_run_guardrail(data=data, event_type=event_type) is not True:

View file

@ -189,6 +189,7 @@ class lakeraAI_Moderation(CustomGuardrail):
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
# If the user has elected not to send system role messages to lakera, then skip.
if system_message is not None:
if not litellm.add_function_to_prompt:
content = system_message.get("content")

View file

@ -19,6 +19,7 @@ from fastapi import HTTPException
from pydantic import BaseModel
import litellm # noqa: E401
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
@ -58,7 +59,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
self.pii_tokens: dict = (
{}
) # mapping of PII token to original text - only used with Presidio `replace` operation
self.mock_redacted_text = mock_redacted_text
self.output_parse_pii = output_parse_pii or False
if mock_testing is True: # for testing purposes only
@ -92,8 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
presidio_anonymizer_api_base: Optional[str] = None,
):
self.presidio_analyzer_api_base: Optional[str] = (
presidio_analyzer_api_base
or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None)
presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore
)
self.presidio_anonymizer_api_base: Optional[
str
@ -198,12 +197,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e
async def async_pre_call_hook(
@ -254,9 +247,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
return data
except Exception as e:
verbose_proxy_logger.info(
f"An error occurred -",
)
raise e
async def async_logging_hook(
@ -300,9 +290,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
kwargs["messages"] = messages
return kwargs, responses
return kwargs, result
async def async_post_call_success_hook(
async def async_post_call_success_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
@ -314,7 +304,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
if self.output_parse_pii == False:
if self.output_parse_pii is False:
return response
if isinstance(response, ModelResponse) and not isinstance(

View file

@ -1,10 +1,11 @@
import importlib
import traceback
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Optional
from pydantic import BaseModel, RootModel
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
@ -16,7 +17,6 @@ from litellm.types.guardrails import (
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
guardrailConfig,
)
all_guardrails: List[GuardrailItem] = []
@ -98,18 +98,13 @@ def init_guardrails_v2(
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"],
mode=litellm_params_data["mode"],
api_key=litellm_params_data.get("api_key"),
api_base=litellm_params_data.get("api_base"),
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
guardrailVersion=litellm_params_data.get("guardrailVersion"),
output_parse_pii=litellm_params_data.get("output_parse_pii"),
presidio_ad_hoc_recognizers=litellm_params_data.get(
"presidio_ad_hoc_recognizers"
),
)
_litellm_params_kwargs = {
k: litellm_params_data[k] if k in litellm_params_data else None
for k in LitellmParams.__annotations__.keys()
}
litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore
if (
"category_thresholds" in litellm_params_data
@ -122,15 +117,11 @@ def init_guardrails_v2(
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret(
litellm_params["api_key"]
)
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore
if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
litellm_params["api_base"] = litellm.get_secret(
litellm_params["api_base"]
)
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore
# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == "aporia":
@ -182,6 +173,7 @@ def init_guardrails_v2(
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
)
if litellm_params["output_parse_pii"] is True:

View file

@ -167,11 +167,11 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on
if self.prompt_injection_params.heuristics_check == True:
if self.prompt_injection_params.heuristics_check is True:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack == True:
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
@ -179,14 +179,14 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
},
)
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
if self.prompt_injection_params.vector_db_check == True:
if self.prompt_injection_params.vector_db_check is True:
pass
else:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack == True:
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
@ -201,19 +201,18 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if (
e.status_code == 400
and isinstance(e.detail, dict)
and "error" in e.detail
and "error" in e.detail # type: ignore
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail.get("error")
raise e
except Exception as e:
verbose_proxy_logger.error(
verbose_proxy_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
async def async_moderation_hook( # type: ignore
self,

View file

@ -195,7 +195,8 @@ async def user_auth(request: Request):
- os.environ["SMTP_PASSWORD"]
- os.environ["SMTP_SENDER_EMAIL"]
"""
from litellm.proxy.proxy_server import prisma_client, send_email
from litellm.proxy.proxy_server import prisma_client
from litellm.proxy.utils import send_email
data = await request.json() # type: ignore
user_email = data["user_email"]
@ -212,7 +213,7 @@ async def user_auth(request: Request):
)
### if so - generate a 24 hr key with that user id
if response is not None:
user_id = response.user_id
user_id = response.user_id # type: ignore
response = await generate_key_helper_fn(
request_type="key",
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id}, # type: ignore
@ -345,6 +346,7 @@ async def user_info(
for team in teams_1:
team_id_list.append(team.team_id)
teams_2: Optional[Any] = None
if user_info is not None:
# *NEW* get all teams in user 'teams' field
teams_2 = await prisma_client.get_data(
@ -375,7 +377,7 @@ async def user_info(
),
user_api_key_dict=user_api_key_dict,
)
else:
elif caller_user_info is not None:
teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams,
table_name="team",
@ -395,7 +397,7 @@ async def user_info(
query_type="find_all",
)
if user_info is None:
if user_info is None and keys is not None:
## make sure we still return a total spend ##
spend = 0
for k in keys:
@ -404,6 +406,9 @@ async def user_info(
## REMOVE HASHED TOKEN INFO before returning ##
returned_keys = []
if keys is None:
pass
else:
for key in keys:
if (
key.token == litellm_master_key_hash
@ -539,6 +544,7 @@ async def user_update(
## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
response: Optional[Any] = None
if data.user_id is not None and len(data.user_id) > 0:
non_default_values["user_id"] = data.user_id # type: ignore
verbose_proxy_logger.debug("In update user, user_id condition block.")
@ -573,7 +579,7 @@ async def user_update(
data=non_default_values,
table_name="user",
)
return response
return response # type: ignore
# update based on remaining passed in values
except Exception as e:
verbose_proxy_logger.error(

View file

@ -226,7 +226,6 @@ from litellm.proxy.utils import (
hash_token,
log_to_opentelemetry,
reset_budget,
send_email,
update_spend,
)
from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (

View file

@ -1434,7 +1434,7 @@ async def _get_spend_report_for_time_range(
except Exception as e:
verbose_proxy_logger.error(
"Exception in _get_daily_spend_reports {}".format(str(e))
) # noqa
)
@router.post(
@ -1703,26 +1703,26 @@ async def view_spend_logs(
result: dict = {}
for record in response:
dt_object = datetime.strptime(
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ"
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
) # type: ignore
date = dt_object.date()
if date not in result:
result[date] = {"users": {}, "models": {}}
api_key = record["api_key"]
user_id = record["user"]
model = record["model"]
result[date]["spend"] = (
result[date].get("spend", 0) + record["_sum"]["spend"]
)
result[date][api_key] = (
result[date].get(api_key, 0) + record["_sum"]["spend"]
)
result[date]["users"][user_id] = (
result[date]["users"].get(user_id, 0) + record["_sum"]["spend"]
)
result[date]["models"][model] = (
result[date]["models"].get(model, 0) + record["_sum"]["spend"]
)
api_key = record["api_key"] # type: ignore
user_id = record["user"] # type: ignore
model = record["model"] # type: ignore
result[date]["spend"] = result[date].get("spend", 0) + record.get(
"_sum", {}
).get("spend", 0)
result[date][api_key] = result[date].get(api_key, 0) + record.get(
"_sum", {}
).get("spend", 0)
result[date]["users"][user_id] = result[date]["users"].get(
user_id, 0
) + record.get("_sum", {}).get("spend", 0)
result[date]["models"][model] = result[date]["models"].get(
model, 0
) + record.get("_sum", {}).get("spend", 0)
return_list = []
final_date = None
for k, v in sorted(result.items()):
@ -1784,7 +1784,7 @@ async def view_spend_logs(
table_name="spend", query_type="find_all"
)
return spend_log
return spend_logs
return None
@ -1843,6 +1843,88 @@ async def global_spend_reset():
}
@router.post(
"/global/spend/refresh",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def global_spend_refresh():
"""
ADMIN ONLY / MASTER KEY Only Endpoint
Globally refresh spend MonthlyGlobalSpend view
"""
from litellm.proxy.proxy_server import prisma_client
if prisma_client is None:
raise ProxyException(
message="Prisma Client is not initialized",
type="internal_error",
param="None",
code=status.HTTP_401_UNAUTHORIZED,
)
## RESET GLOBAL SPEND VIEW ###
async def is_materialized_global_spend_view() -> bool:
"""
Return True if materialized view exists
Else False
"""
sql_query = """
SELECT relname, relkind
FROM pg_class
WHERE relname = 'MonthlyGlobalSpend';
"""
try:
resp = await prisma_client.db.query_raw(sql_query)
assert resp[0]["relkind"] == "m"
return True
except Exception:
return False
view_exists = await is_materialized_global_spend_view()
if view_exists:
# refresh materialized view
sql_query = """
REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";
"""
try:
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient
db_url = os.getenv("DATABASE_URL")
if db_url is None:
raise Exception(CommonProxyErrors.db_not_connected_error.value)
new_client = PrismaClient(
database_url=db_url,
proxy_logging_obj=proxy_logging_obj,
http_client={
"timeout": 6000,
},
)
await new_client.db.connect()
await new_client.db.query_raw(sql_query)
verbose_proxy_logger.info("MonthlyGlobalSpend view refreshed")
return {
"message": "MonthlyGlobalSpend view refreshed",
"status": "success",
}
except Exception as e:
verbose_proxy_logger.exception(
"Failed to refresh materialized view - {}".format(str(e))
)
return {
"message": "Failed to refresh materialized view",
"status": "failure",
}
async def global_spend_for_internal_user(
api_key: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),

View file

@ -92,6 +92,7 @@ def safe_deep_copy(data):
if litellm.safe_memory_mode is True:
return data
litellm_parent_otel_span: Optional[Any] = None
# Step 1: Remove the litellm_parent_otel_span
litellm_parent_otel_span = None
if isinstance(data, dict):
@ -101,7 +102,7 @@ def safe_deep_copy(data):
new_data = copy.deepcopy(data)
# Step 2: re-add the litellm_parent_otel_span after doing a deep copy
if isinstance(data, dict):
if isinstance(data, dict) and litellm_parent_otel_span is not None:
if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data
@ -468,7 +469,7 @@ class ProxyLogging:
# V1 implementation - backwards compatibility
if callback.event_hook is None:
if callback.moderation_check == "pre_call":
if callback.moderation_check == "pre_call": # type: ignore
return
else:
# Main - V2 Guardrails implementation
@ -881,7 +882,12 @@ class PrismaClient:
org_list_transactons: dict = {}
spend_log_transactions: List = []
def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
def __init__(
self,
database_url: str,
proxy_logging_obj: ProxyLogging,
http_client: Optional[Any] = None,
):
verbose_proxy_logger.debug(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)
@ -912,6 +918,9 @@ class PrismaClient:
# Now you can import the Prisma Client
from prisma import Prisma # type: ignore
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
if http_client is not None:
self.db = Prisma(http=http_client)
else:
self.db = Prisma() # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")
@ -987,7 +996,7 @@ class PrismaClient:
return
else:
## check if required view exists ##
if required_view not in ret[0]["view_names"]:
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db
await self.db.execute_raw(
"""
@ -1009,7 +1018,9 @@ class PrismaClient:
else:
# don't block execution if these views are missing
# Convert lists to sets for efficient difference calculation
ret_view_names_set = set(ret[0]["view_names"])
ret_view_names_set = (
set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
)
expected_views_set = set(expected_views)
# Find missing views
missing_views = expected_views_set - ret_view_names_set
@ -1291,13 +1302,13 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
hashed_token: Optional[str] = None
try:
response: Any = None
if (token is not None and table_name is None) or (
table_name is not None and table_name == "key"
):
# check if plain text or hash
hashed_token = None
if token is not None:
if isinstance(token, str):
hashed_token = token
@ -1306,7 +1317,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
if query_type == "find_unique":
if query_type == "find_unique" and hashed_token is not None:
if token is None:
raise HTTPException(
status_code=400,
@ -1706,7 +1717,7 @@ class PrismaClient:
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
where={"param_name": k}, # type: ignore
data={
"create": {"param_name": k, "param_value": updated_data}, # type: ignore
"update": {"param_value": updated_data},
@ -2302,7 +2313,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
return instance
except ImportError as e:
# Re-raise the exception with a user-friendly message
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
if instance_name and module_name:
raise ImportError(
f"Could not import {instance_name} from {module_name}"
) from e
else:
raise e
except Exception as e:
raise e
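`get_instance_fn` turns dotted strings from the proxy config (e.g. `custom_callbacks.proxy_handler_instance`) into Python objects, and the change above only re-wraps the `ImportError` when both a module and an attribute name were actually parsed. A simplified sketch of that resolution logic — not the full implementation, and the example path is hypothetical:

```python
import importlib


def resolve_dotted_path(value: str):
    """Resolve a string like 'my_module.my_attr' into the object it names."""
    module_name, _, instance_name = value.rpartition(".")
    if not module_name:  # bare name, e.g. "my_handler" - nothing to import
        raise ImportError(f"Expected a dotted path, got {value!r}")
    try:
        module = importlib.import_module(module_name)
        return getattr(module, instance_name)
    except ImportError as e:
        if instance_name and module_name:
            raise ImportError(
                f"Could not import {instance_name} from {module_name}"
            ) from e
        raise  # nothing useful to add - surface the original error
```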
@ -2377,12 +2393,12 @@ async def send_email(receiver_email, subject, html):
try:
# Establish a secure connection with the SMTP server
with smtplib.SMTP(smtp_host, smtp_port) as server:
with smtplib.SMTP(smtp_host, smtp_port) as server: # type: ignore
if os.getenv("SMTP_TLS", "True") != "False":
server.starttls()
# Login to your email account
server.login(smtp_username, smtp_password)
server.login(smtp_username, smtp_password) # type: ignore
# Send the email
server.send_message(email_message)
@ -2945,7 +2961,7 @@ async def update_spend(
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
await asyncio.sleep(2**i) # type: ignore
except Exception as e:
import traceback
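The `# type: ignore` on the sleep only quiets the type checker; behaviourally `update_spend` still retries with exponential backoff. A minimal generic version of that retry pattern, with illustrative names:

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def retry_with_backoff(
    op: Callable[[], Awaitable[T]], n_retry_times: int = 3
) -> T:
    """Retry an async operation, sleeping 2**attempt seconds between failures."""
    for i in range(n_retry_times + 1):
        try:
            return await op()
        except Exception:
            if i >= n_retry_times:
                raise  # out of retries - re-raise the last exception
            await asyncio.sleep(2**i)  # 1s, 2s, 4s, ...
    raise RuntimeError("unreachable")  # keeps type checkers satisfied
```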

View file

@ -2110,7 +2110,6 @@ async def test_hf_completion_tgi_stream():
def test_openai_chat_completion_call():
litellm.set_verbose = False
litellm.return_response_headers = True
print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
assert isinstance(
response._hidden_params["additional_headers"][
@ -2318,6 +2317,57 @@ def test_together_ai_completion_call_mistral():
pass
# # test on together ai completion call - starcoder
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_o1_completion_call_streaming(sync_mode):
try:
litellm.set_verbose = False
if sync_mode:
response = completion(
model="o1-preview",
messages=messages,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
has_finish_reason = False
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
if has_finish_reason is False:
raise Exception("Finish reason not set for last chunk")
if complete_response == "":
raise Exception("Empty response received")
else:
response = await acompletion(
model="o1-preview",
messages=messages,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
has_finish_reason = False
idx = 0
async for chunk in response:
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
idx += 1
if has_finish_reason is False:
raise Exception("Finish reason not set for last chunk")
if complete_response == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except Exception:
pytest.fail(f"error occurred: {traceback.format_exc()}")
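Because o1 models do not support native streaming, LiteLLM fakes it by chunking a complete response, so the usual streaming interface keeps working. A minimal consumer-side sketch (assumes `OPENAI_API_KEY` is set; the prompt is illustrative):

```python
import litellm


def stream_o1_preview() -> str:
    response = litellm.completion(
        model="o1-preview",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        stream=True,  # faked for o1 - chunks are cut from one full response
    )
    text = ""
    for chunk in response:
        text += chunk.choices[0].delta.content or ""
    return text
```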
def test_together_ai_completion_call_starcoder_bad_key():
try:
api_key = "bad-key"

View file

@ -71,7 +71,7 @@ class LakeraCategoryThresholds(TypedDict, total=False):
jailbreak: float
class LitellmParams(TypedDict, total=False):
class LitellmParams(TypedDict):
guardrail: str
mode: str
api_key: str
@ -87,6 +87,7 @@ class LitellmParams(TypedDict, total=False):
# Presidio params
output_parse_pii: Optional[bool]
presidio_ad_hoc_recognizers: Optional[str]
mock_redacted_text: Optional[dict]
class Guardrail(TypedDict):

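The new `mock_redacted_text` param lets a Presidio guardrail run against a canned response instead of a live Presidio service, which is what the guardrail tests rely on. A hedged sketch of such a params dict — the exact payload shape shown here is an assumption, not taken from this diff:

```python
# Hypothetical example of guardrail litellm_params using the new field.
presidio_guardrail_params = {
    "guardrail": "presidio",
    "mode": "pre_call",
    # Canned Presidio output used in place of a real analyzer/anonymizer call
    # (field names inside this dict are assumed for illustration).
    "mock_redacted_text": {
        "text": "My name is <PERSON>, my number is <PHONE_NUMBER>",
        "items": [
            {"start": 11, "end": 19, "entity_type": "PERSON"},
            {"start": 34, "end": 48, "entity_type": "PHONE_NUMBER"},
        ],
    },
}
```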
View file

@ -120,11 +120,26 @@ with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json")
# Convert to str (if necessary)
claude_json_str = json.dumps(json_data)
import importlib.metadata
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
)
from openai import OpenAIError as OriginalError
from ._logging import verbose_logger
from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .exceptions import (
APIConnectionError,
APIError,
@ -150,31 +165,6 @@ from .types.llms.openai import (
)
from .types.router import LiteLLM_Params
try:
from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
GenericAPILogger,
)
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
)
from .caching import Cache
####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100

View file

@ -0,0 +1 @@
More tests under `litellm/litellm/tests/*`.

View file

@ -0,0 +1,502 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch
import litellm
from litellm.exceptions import BadRequestError, InternalServerError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper
def mock_chat_response() -> Dict[str, Any]:
return {
"id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
"object": "chat.completion",
"created": 1726285449,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 38,
"total_tokens": 268,
"completion_tokens_details": None,
},
"system_fingerprint": None,
}
def mock_chat_streaming_response_chunks() -> List[str]:
return [
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "Hello"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"content": " world"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"content": "!"},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
]
def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
string_chunks = mock_chat_streaming_response_chunks()
bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
# Simulate the end of the stream
bytes_chunks.append(b"")
return bytes_chunks
def mock_http_handler_chat_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks()
def mock_iter_lines():
for chunk in mock_stream_chunks:
for line in chunk.splitlines():
yield line
mock_response = MagicMock()
mock_response.iter_lines.side_effect = mock_iter_lines
mock_response.status_code = 200
return mock_response
def mock_http_handler_chat_async_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks()
async def mock_iter_lines():
for chunk in mock_stream_chunks:
for line in chunk.splitlines():
yield line
mock_response = MagicMock()
mock_response.aiter_lines.return_value = mock_iter_lines()
mock_response.status_code = 200
return mock_response
def mock_databricks_client_chat_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()
def mock_read_from_stream(size=-1):
if mock_stream_chunks:
return mock_stream_chunks.pop(0)
return b""
mock_response = MagicMock()
streaming_response_mock = MagicMock()
streaming_response_iterator_mock = MagicMock()
# Mock the __getitem__("content") method to return the streaming response
mock_response.__getitem__.return_value = streaming_response_mock
# Mock the streaming response __enter__ method to return the streaming response iterator
streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock
streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
streaming_response_iterator_mock.closed = False
return mock_response
def mock_embedding_response() -> Dict[str, Any]:
return {
"object": "list",
"model": "bge-large-en-v1.5",
"data": [
{
"index": 0,
"object": "embedding",
"embedding": [
0.06768798828125,
-0.01291656494140625,
-0.0501708984375,
0.0245361328125,
-0.030364990234375,
],
}
],
"usage": {
"prompt_tokens": 8,
"total_tokens": 8,
"completion_tokens": 0,
"completion_tokens_details": None,
},
}
@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
if set_base:
monkeypatch.setenv(
"DATABRICKS_API_BASE",
"https://my.workspace.cloud.databricks.com/serving-endpoints",
)
monkeypatch.delenv(
"DATABRICKS_API_KEY",
)
err_msg = "A call is being made to LLM Provider but no key is set"
else:
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
monkeypatch.delenv("DATABRICKS_API_BASE")
err_msg = "A call is being made to LLM Provider but no api base is set"
with pytest.raises(BadRequestError) as exc:
litellm.completion(
model="databricks/dbrx-instruct-071224",
messages={"role": "user", "content": "How are you?"},
)
assert err_msg in str(exc)
with pytest.raises(BadRequestError) as exc:
litellm.embedding(
model="databricks/bge-12312",
input=["Hello", "World"],
)
assert err_msg in str(exc)
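For completeness, the happy path these tests guard against requires both variables to be present. A minimal sketch of a working setup — the workspace URL and key are placeholders:

```python
import os

import litellm

os.environ["DATABRICKS_API_BASE"] = (
    "https://my.workspace.cloud.databricks.com/serving-endpoints"  # placeholder
)
os.environ["DATABRICKS_API_KEY"] = "dapimykey"  # placeholder

response = litellm.completion(
    model="databricks/dbrx-instruct-071224",
    messages=[{"role": "user", "content": "How are you?"}],
)
print(response.choices[0].message.content)
```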
def test_completions_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()
expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}
messages = [{"role": "user", "content": "How are you?"}]
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
assert response.to_dict() == expected_response_json
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)
def test_completions_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
async_handler = AsyncHTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()
expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}
messages = [{"role": "user", "content": "How are you?"}]
with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response = asyncio.run(
litellm.acompletion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=async_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
)
assert response.to_dict() == expected_response_json
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)
def test_completions_streaming_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
messages = [{"role": "user", "content": "How are you?"}]
mock_response = mock_http_handler_chat_streaming_response()
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response_stream: CustomStreamWrapper = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
stream=True,
)
response = list(response_stream)
assert "dbrx-instruct-071224" in str(response)
assert "chatcmpl" in str(response)
assert len(response) == 4
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
stream=True,
)
def test_completions_streaming_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
async_handler = AsyncHTTPHandler()
messages = [{"role": "user", "content": "How are you?"}]
mock_response = mock_http_handler_chat_async_streaming_response()
with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response_stream: CustomStreamWrapper = asyncio.run(
litellm.acompletion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=async_handler,
temperature=0.5,
extraparam="testpassingextraparam",
stream=True,
)
)
# Use async list gathering for the response
async def gather_responses():
return [item async for item in response_stream]
response = asyncio.run(gather_responses())
assert "dbrx-instruct-071224" in str(response)
assert "chatcmpl" in str(response)
assert len(response) == 4
mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
stream=True,
)
def test_embeddings_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()
inputs = ["Hello", "World"]
with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.embedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=sync_handler,
extraparam="testpassingextraparam",
)
assert response.to_dict() == mock_embedding_response()
mock_post.assert_called_once_with(
f"{base_url}/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)
def test_embeddings_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
async_handler = AsyncHTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()
inputs = ["Hello", "World"]
with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response = asyncio.run(
litellm.aembedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=async_handler,
extraparam="testpassingextraparam",
)
)
assert response.to_dict() == mock_embedding_response()
mock_post.assert_called_once_with(
f"{base_url}/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)