* LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723)
  * coverage (#5713)
    Signed-off-by: dbczumar <corey.zumar@databricks.com>
  * Move (#5714)
    Signed-off-by: dbczumar <corey.zumar@databricks.com>
  * fix(litellm_logging.py): fix logging client re-init (#5710)
    Fixes https://github.com/BerriAI/litellm/issues/5695
  * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config
    Fixes https://github.com/BerriAI/litellm/issues/5682
  * feat(o1_handler.py): fake streaming for openai o1 models
    Fixes https://github.com/BerriAI/litellm/issues/5694
  * docs: deprecated traceloop integration in favor of native otel (#5249)
  * fix: fix linting errors
  * fix: fix linting errors
  * fix(main.py): fix o1 import
  ---------
  Signed-off-by: dbczumar <corey.zumar@databricks.com>
  Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
  Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730)
  * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view
    Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it
  * fix(custom_logger.py): reset calltype
  * fix: fix linting errors
  * fix: fix linting error
  * fix: fix import
  * test(test_databricks.py): fix databricks tests
  ---------
  Signed-off-by: dbczumar <corey.zumar@databricks.com>
  Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
  Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
Parent: 1e59395280 · Commit: 234185ec13
34 changed files with 1387 additions and 502 deletions
@@ -6,10 +6,14 @@ import asyncio
 import os

 # Enter your DATABASE_URL here
 os.environ["DATABASE_URL"] = "postgresql://xxxxxxx"

 from prisma import Prisma

-db = Prisma()
+db = Prisma(
+    http={
+        "timeout": 60000,
+    },
+)


 async def check_view_exists():
@@ -47,25 +51,22 @@ async def check_view_exists()
         print("LiteLLM_VerificationTokenView Created!")  # noqa

     try:
         await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
         print("MonthlyGlobalSpend Exists!")  # noqa
     except Exception as e:
-        sql_query = """
-        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
+        sql_query = """
+        CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
         SELECT
-        DATE("startTime") AS date,
-        SUM("spend") AS spend
+        DATE_TRUNC('day', "startTime") AS date,
+        SUM("spend") AS spend
         FROM
-        "LiteLLM_SpendLogs"
+        "LiteLLM_SpendLogs"
         WHERE
-        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+        "startTime" >= CURRENT_DATE - INTERVAL '30 days'
         GROUP BY
-        DATE("startTime");
-        """
-        await db.execute_raw(query=sql_query)
+        DATE_TRUNC('day', "startTime");
+        """
+        # Execute the queries
+        await db.execute_raw(query=sql_query)

-        print("MonthlyGlobalSpend Created!")  # noqa
+        print("MonthlyGlobalSpend Created!")  # noqa

     try:
         await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
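Because the view is now materialized, it is a snapshot and has to be refreshed to reflect new spend rows; the `/global/spend/refresh` endpoint mentioned in the commit message does exactly that on the proxy side. A minimal sketch of the same refresh issued from this script (the helper name is illustrative, not part of the diff):

```python
async def refresh_monthly_global_spend():
    # Re-run the stored query so the materialized view picks up new "LiteLLM_SpendLogs" rows.
    await db.execute_raw('REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";')
```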
@@ -0,0 +1,78 @@ (new file)
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# OpenTelemetry - Tracing LLMs with any observability tool

OpenTelemetry is a CNCF standard for observability. It connects to any observability tool, such as Jaeger, Zipkin, Datadog, New Relic, Traceloop, and others.

<Image img={require('../../img/traceloop_dash.png')} />

## Getting Started

Install the OpenTelemetry SDK:

```
pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp
```

Set the environment variables (different providers may require different variables):

<Tabs>

<TabItem value="traceloop" label="Log to Traceloop Cloud">

```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="https://api.traceloop.com"
OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```

</TabItem>

<TabItem value="otel-col" label="Log to OTEL HTTP Collector">

```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```

</TabItem>

<TabItem value="otel-col-grpc" label="Log to OTEL GRPC Collector">

```shell
OTEL_EXPORTER="otlp_grpc"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```

</TabItem>

</Tabs>

Use just 2 lines of code to instantly log your LLM responses **across all providers** with OpenTelemetry:

```python
litellm.callbacks = ["otel"]
```
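For reference, a runnable end-to-end sketch (assuming the environment variables above are exported and a provider key such as `OPENAI_API_KEY` is set; the model name is only an example):

```python
import litellm

# Spans are exported to whatever OTEL_EXPORTER / OTEL_ENDPOINT / OTEL_HEADERS point at.
litellm.callbacks = ["otel"]

# Every completion call is now traced, regardless of provider.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```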

## Redacting Messages, Response Content from OpenTelemetry Logging

### Redact Messages and Responses from all OpenTelemetry Logging

Set `litellm.turn_off_message_logging=True`. This will prevent the messages and responses from being logged to OpenTelemetry, but request metadata will still be logged.

### Redact Messages and Responses from specific OpenTelemetry Logging

In the metadata typically passed for text completion or embedding calls, you can set specific keys to mask the messages and responses for this call.

Setting `mask_input` to `True` will mask the input from being logged for this call.

Setting `mask_output` to `True` will mask the output from being logged for this call.

Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.
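As an illustration, per-call masking can be combined with the callback like the sketch below (the `metadata` keys follow the description above; treat the exact call shape as a sketch rather than the canonical API):

```python
import litellm

litellm.callbacks = ["otel"]

# The input and output of this specific call are masked in the emitted spans;
# request metadata (model, latency, token counts) is still logged.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "this prompt should not appear in traces"}],
    metadata={"mask_input": True, "mask_output": True},
)
```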

## Support

For any question or issue with the integration you can reach out to the OpenLLMetry maintainers on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
@@ -1,36 +0,0 @@ (deleted file)
import Image from '@theme/IdealImage';

# Traceloop (OpenLLMetry) - Tracing LLMs with OpenTelemetry

[Traceloop](https://traceloop.com) is a platform for monitoring and debugging the quality of your LLM outputs.
It provides you with a way to track the performance of your LLM application; rollout changes with confidence; and debug issues in production.
It is based on [OpenTelemetry](https://opentelemetry.io), so it can provide full visibility to your LLM requests, as well vector DB usage, and other infra in your stack.

<Image img={require('../../img/traceloop_dash.png')} />

## Getting Started

Install the Traceloop SDK:

```
pip install traceloop-sdk
```

Use just 2 lines of code, to instantly log your LLM responses with OpenTelemetry:

```python
Traceloop.init(app_name=<YOUR APP NAME>, disable_batch=True)
litellm.success_callback = ["traceloop"]
```

Make sure to properly set a destination to your traces. See [OpenLLMetry docs](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for options.

To get better visualizations on how your code behaves, you may want to annotate specific parts of your LLM chain. See [Traceloop docs on decorators](https://traceloop.com/docs/python-sdk/decorators) for more information.

## Exporting traces to other systems (e.g. Datadog, New Relic, and others)

Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.

## Support

For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
@@ -600,6 +600,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \

 </TabItem>

+<TabItem value="traceloop" label="Log to Traceloop Cloud">
+
+#### Quick Start - Log to Traceloop
+
+**Step 1:**
+Add the following to your env
+
+```shell
+OTEL_EXPORTER="otlp_http"
+OTEL_ENDPOINT="https://api.traceloop.com"
+OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
+```
+
+**Step 2:** Add `otel` as a callback
+
+```shell
+litellm_settings:
+  callbacks: ["otel"]
+```
+
+**Step 3:** Start the proxy, make a test request
+
+Start proxy
+
+```shell
+litellm --config config.yaml --detailed_debug
+```
+
+Test Request
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Content-Type: application/json' \
+    --data ' {
+    "model": "gpt-3.5-turbo",
+    "messages": [
+      {
+        "role": "user",
+        "content": "what llm are you"
+      }
+    ]
+    }'
+```
+
+</TabItem>
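For completeness, the same test request from the quick start above can also be sent through the OpenAI Python SDK pointed at the proxy (a sketch; `sk-1234` stands in for whatever key your proxy expects):

```python
import openai

# Point the standard OpenAI client at the LiteLLM proxy started above.
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```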

 <TabItem value="otel-col" label="Log to OTEL HTTP Collector">

 #### Quick Start - Log to OTEL Collector
@@ -694,52 +740,6 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \

 </TabItem>

-<TabItem value="traceloop" label="Log to Traceloop Cloud">
-
-#### Quick Start - Log to Traceloop
-
-**Step 1:** Install the `traceloop-sdk` SDK
-
-```shell
-pip install traceloop-sdk==0.21.2
-```
-
-**Step 2:** Add `traceloop` as a success_callback
-
-```shell
-litellm_settings:
-  success_callback: ["traceloop"]
-
-environment_variables:
-  TRACELOOP_API_KEY: "XXXXX"
-```
-
-**Step 3**: Start the proxy, make a test request
-
-Start proxy
-
-```shell
-litellm --config config.yaml --detailed_debug
-```
-
-Test Request
-
-```shell
-curl --location 'http://0.0.0.0:4000/chat/completions' \
-    --header 'Content-Type: application/json' \
-    --data ' {
-    "model": "gpt-3.5-turbo",
-    "messages": [
-      {
-        "role": "user",
-        "content": "what llm are you"
-      }
-    ]
-    }'
-```
-
-</TabItem>
-
 </Tabs>

 ** 🎉 Expect to see this trace logged in your OTEL collector**
@@ -255,6 +255,7 @@ const sidebars = {
       type: "category",
       label: "Logging & Observability",
       items: [
+        "observability/opentelemetry_integration",
         "observability/langfuse_integration",
         "observability/logfire_integration",
         "observability/gcs_bucket_integration",
@@ -271,7 +272,7 @@ const sidebars = {
         "observability/openmeter",
         "observability/promptlayer_integration",
         "observability/wandb_integration",
-        "observability/traceloop_integration",
         "observability/slack_integration",
         "observability/athina_integration",
         "observability/lunary_integration",
         "observability/greenscale_integration",
@@ -948,10 +948,10 @@ from .llms.OpenAI.openai import (
     AzureAIStudioConfig,
 )
 from .llms.mistral.mistral_chat_transformation import MistralConfig
-from .llms.OpenAI.o1_transformation import (
+from .llms.OpenAI.chat.o1_transformation import (
     OpenAIO1Config,
 )
-from .llms.OpenAI.gpt_transformation import (
+from .llms.OpenAI.chat.gpt_transformation import (
     OpenAIGPTConfig,
 )
 from .llms.nvidia_nim import NvidiaNimConfig
@@ -1616,6 +1616,9 @@ Model Info:

         :param time_range: A string specifying the time range, e.g., "1d", "7d", "30d"
         """
+        if self.alerting is None or "spend_reports" not in self.alert_types:
+            return
+
         try:
             from litellm.proxy.spend_tracking.spend_management_endpoints import (
                 _get_spend_report_for_time_range,
@@ -11,7 +11,7 @@ class CustomGuardrail(CustomLogger):
         self,
         guardrail_name: Optional[str] = None,
         event_hook: Optional[GuardrailEventHooks] = None,
-        **kwargs
+        **kwargs,
     ):
         self.guardrail_name = guardrail_name
         self.event_hook: Optional[GuardrailEventHooks] = event_hook
@@ -28,10 +28,13 @@ class CustomGuardrail(CustomLogger):
             requested_guardrails,
         )

-        if self.guardrail_name not in requested_guardrails:
+        if (
+            self.guardrail_name not in requested_guardrails
+            and event_type.value != "logging_only"
+        ):
             return False

-        if self.event_hook != event_type:
+        if self.event_hook != event_type.value:
             return False

         return True
@@ -4,6 +4,11 @@ import litellm


 class TraceloopLogger:
+    """
+    WARNING: DEPRECATED
+    Use the OpenTelemetry standard integration instead
+    """
+
     def __init__(self):
         try:
             from traceloop.sdk.tracing.tracing import TracerWrapper
@@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
 from ..integrations.traceloop import TraceloopLogger
 from ..integrations.weights_biases import WeightsBiasesLogger

+try:
+    from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
+        GenericAPILogger,
+    )
+except Exception as e:
+    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
+
 _in_memory_loggers: List[Any] = []

 ### GLOBAL VARIABLES ###
@@ -145,7 +152,41 @@ class ServiceTraceIDCache:
         return None


+import hashlib
+
+
+class DynamicLoggingCache:
+    """
+    Prevent memory leaks caused by initializing new logging clients on each request.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
+    """
+
+    def __init__(self) -> None:
+        self.cache = InMemoryCache()
+
+    def get_cache_key(self, args: dict) -> str:
+        args_str = json.dumps(args, sort_keys=True)
+        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
+        return cache_key
+
+    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        response = self.cache.get_cache(key=key_name)
+        return response
+
+    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        self.cache.set_cache(key=key_name, value=logging_obj)
+        return None
+
+
 in_memory_trace_id_cache = ServiceTraceIDCache()
+in_memory_dynamic_logger_cache = DynamicLoggingCache()


 class Logging:
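To make the intent of `DynamicLoggingCache` concrete, here is a usage sketch mirroring how the Langfuse branch later in this diff consumes it (the credential values are placeholders):

```python
credentials = {
    "langfuse_public_key": "pk-...",   # placeholder values
    "langfuse_secret": "sk-...",
    "langfuse_host": "https://cloud.langfuse.com",
}

# Identical credentials hash to the same SHA-256 key, so one LangFuseLogger is
# reused per credential set instead of constructing a new client on every request.
logger = in_memory_dynamic_logger_cache.get_cache(
    credentials=credentials, service_name="langfuse"
)
if logger is None:
    logger = LangFuseLogger(
        langfuse_public_key=credentials["langfuse_public_key"],
        langfuse_secret=credentials["langfuse_secret"],
        langfuse_host=credentials["langfuse_host"],
    )
    in_memory_dynamic_logger_cache.set_cache(
        credentials=credentials, service_name="langfuse", logging_obj=logger
    )
```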
|
@ -324,10 +365,10 @@ class Logging:
|
|||
print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
|
||||
# log raw request to provider (like LangFuse) -- if opted in.
|
||||
if log_raw_request_response is True:
|
||||
_litellm_params = self.model_call_details.get("litellm_params", {})
|
||||
_metadata = _litellm_params.get("metadata", {}) or {}
|
||||
try:
|
||||
# [Non-blocking Extra Debug Information in metadata]
|
||||
_litellm_params = self.model_call_details.get("litellm_params", {})
|
||||
_metadata = _litellm_params.get("metadata", {}) or {}
|
||||
if (
|
||||
turn_off_message_logging is not None
|
||||
and turn_off_message_logging is True
|
||||
|
@ -362,7 +403,7 @@ class Logging:
|
|||
callbacks = litellm.input_callback + self.dynamic_input_callbacks
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if callback == "supabase":
|
||||
if callback == "supabase" and supabaseClient is not None:
|
||||
verbose_logger.debug("reaches supabase for logging!")
|
||||
model = self.model_call_details["model"]
|
||||
messages = self.model_call_details["input"]
|
||||
|
@ -396,7 +437,9 @@ class Logging:
|
|||
messages=self.messages,
|
||||
kwargs=self.model_call_details,
|
||||
)
|
||||
elif callable(callback): # custom logger functions
|
||||
elif (
|
||||
callable(callback) and customLogger is not None
|
||||
): # custom logger functions
|
||||
customLogger.log_input_event(
|
||||
model=self.model,
|
||||
messages=self.messages,
|
||||
|
@ -615,7 +658,7 @@ class Logging:
|
|||
|
||||
self.model_call_details["litellm_params"]["metadata"][
|
||||
"hidden_params"
|
||||
] = result._hidden_params
|
||||
] = getattr(result, "_hidden_params", {})
|
||||
## STANDARDIZED LOGGING PAYLOAD
|
||||
|
||||
self.model_call_details["standard_logging_object"] = (
|
||||
|
@ -645,6 +688,7 @@ class Logging:
|
|||
litellm.max_budget
|
||||
and self.stream is False
|
||||
and result is not None
|
||||
and isinstance(result, dict)
|
||||
and "content" in result
|
||||
):
|
||||
time_diff = (end_time - start_time).total_seconds()
|
||||
|
@ -652,7 +696,7 @@ class Logging:
|
|||
litellm._current_cost += litellm.completion_cost(
|
||||
model=self.model,
|
||||
prompt="",
|
||||
completion=result["content"],
|
||||
completion=getattr(result, "content", ""),
|
||||
total_time=float_diff,
|
||||
)
|
||||
|
||||
|
@ -758,7 +802,7 @@ class Logging:
|
|||
):
|
||||
print_verbose("no-log request, skipping logging")
|
||||
continue
|
||||
if callback == "lite_debugger":
|
||||
if callback == "lite_debugger" and liteDebuggerClient is not None:
|
||||
print_verbose("reaches lite_debugger for logging!")
|
||||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||
print_verbose(
|
||||
|
@ -774,7 +818,7 @@ class Logging:
|
|||
call_type=self.call_type,
|
||||
stream=self.stream,
|
||||
)
|
||||
if callback == "promptlayer":
|
||||
if callback == "promptlayer" and promptLayerLogger is not None:
|
||||
print_verbose("reaches promptlayer for logging!")
|
||||
promptLayerLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -783,7 +827,7 @@ class Logging:
|
|||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callback == "supabase":
|
||||
if callback == "supabase" and supabaseClient is not None:
|
||||
print_verbose("reaches supabase for logging!")
|
||||
kwargs = self.model_call_details
|
||||
|
||||
|
@ -811,7 +855,7 @@ class Logging:
|
|||
),
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callback == "wandb":
|
||||
if callback == "wandb" and weightsBiasesLogger is not None:
|
||||
print_verbose("reaches wandb for logging!")
|
||||
weightsBiasesLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -820,8 +864,7 @@ class Logging:
|
|||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callback == "logfire":
|
||||
global logfireLogger
|
||||
if callback == "logfire" and logfireLogger is not None:
|
||||
verbose_logger.debug("reaches logfire for success logging!")
|
||||
kwargs = {}
|
||||
for k, v in self.model_call_details.items():
|
||||
|
@ -844,10 +887,10 @@ class Logging:
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
level=LogfireLevel.INFO.value,
|
||||
level=LogfireLevel.INFO.value, # type: ignore
|
||||
)
|
||||
|
||||
if callback == "lunary":
|
||||
if callback == "lunary" and lunaryLogger is not None:
|
||||
print_verbose("reaches lunary for logging!")
|
||||
model = self.model
|
||||
kwargs = self.model_call_details
|
||||
|
@ -882,7 +925,7 @@ class Logging:
|
|||
run_id=self.litellm_call_id,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callback == "helicone":
|
||||
if callback == "helicone" and heliconeLogger is not None:
|
||||
print_verbose("reaches helicone for logging!")
|
||||
model = self.model
|
||||
messages = self.model_call_details["input"]
|
||||
|
@ -924,6 +967,7 @@ class Logging:
|
|||
else:
|
||||
print_verbose("reaches langfuse for streaming logging!")
|
||||
result = kwargs["complete_streaming_response"]
|
||||
|
||||
temp_langfuse_logger = langFuseLogger
|
||||
if langFuseLogger is None or (
|
||||
(
|
||||
|
@ -941,27 +985,45 @@ class Logging:
|
|||
and self.langfuse_host != langFuseLogger.langfuse_host
|
||||
)
|
||||
):
|
||||
temp_langfuse_logger = LangFuseLogger(
|
||||
langfuse_public_key=self.langfuse_public_key,
|
||||
langfuse_secret=self.langfuse_secret,
|
||||
langfuse_host=self.langfuse_host,
|
||||
)
|
||||
_response = temp_langfuse_logger.log_event(
|
||||
kwargs=kwargs,
|
||||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
user_id=kwargs.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if _response is not None and isinstance(_response, dict):
|
||||
_trace_id = _response.get("trace_id", None)
|
||||
if _trace_id is not None:
|
||||
in_memory_trace_id_cache.set_cache(
|
||||
litellm_call_id=self.litellm_call_id,
|
||||
service_name="langfuse",
|
||||
trace_id=_trace_id,
|
||||
credentials = {
|
||||
"langfuse_public_key": self.langfuse_public_key,
|
||||
"langfuse_secret": self.langfuse_secret,
|
||||
"langfuse_host": self.langfuse_host,
|
||||
}
|
||||
temp_langfuse_logger = (
|
||||
in_memory_dynamic_logger_cache.get_cache(
|
||||
credentials=credentials, service_name="langfuse"
|
||||
)
|
||||
)
|
||||
if temp_langfuse_logger is None:
|
||||
temp_langfuse_logger = LangFuseLogger(
|
||||
langfuse_public_key=self.langfuse_public_key,
|
||||
langfuse_secret=self.langfuse_secret,
|
||||
langfuse_host=self.langfuse_host,
|
||||
)
|
||||
in_memory_dynamic_logger_cache.set_cache(
|
||||
credentials=credentials,
|
||||
service_name="langfuse",
|
||||
logging_obj=temp_langfuse_logger,
|
||||
)
|
||||
|
||||
if temp_langfuse_logger is not None:
|
||||
_response = temp_langfuse_logger.log_event(
|
||||
kwargs=kwargs,
|
||||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
user_id=kwargs.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if _response is not None and isinstance(_response, dict):
|
||||
_trace_id = _response.get("trace_id", None)
|
||||
if _trace_id is not None:
|
||||
in_memory_trace_id_cache.set_cache(
|
||||
litellm_call_id=self.litellm_call_id,
|
||||
service_name="langfuse",
|
||||
trace_id=_trace_id,
|
||||
)
|
||||
if callback == "generic":
|
||||
global genericAPILogger
|
||||
verbose_logger.debug("reaches langfuse for success logging!")
|
||||
|
@ -982,7 +1044,7 @@ class Logging:
|
|||
print_verbose("reaches langfuse for streaming logging!")
|
||||
result = kwargs["complete_streaming_response"]
|
||||
if genericAPILogger is None:
|
||||
genericAPILogger = GenericAPILogger()
|
||||
genericAPILogger = GenericAPILogger() # type: ignore
|
||||
genericAPILogger.log_event(
|
||||
kwargs=kwargs,
|
||||
response_obj=result,
|
||||
|
@ -1022,7 +1084,7 @@ class Logging:
|
|||
user_id=kwargs.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callback == "greenscale":
|
||||
if callback == "greenscale" and greenscaleLogger is not None:
|
||||
kwargs = {}
|
||||
for k, v in self.model_call_details.items():
|
||||
if (
|
||||
|
@ -1066,7 +1128,7 @@ class Logging:
|
|||
result = kwargs["complete_streaming_response"]
|
||||
# only add to cache once we have a complete streaming response
|
||||
litellm.cache.add_cache(result, **kwargs)
|
||||
if callback == "athina":
|
||||
if callback == "athina" and athinaLogger is not None:
|
||||
deep_copy = {}
|
||||
for k, v in self.model_call_details.items():
|
||||
deep_copy[k] = v
|
||||
|
@ -1224,6 +1286,7 @@ class Logging:
|
|||
"atranscription", False
|
||||
)
|
||||
is not True
|
||||
and customLogger is not None
|
||||
): # custom logger functions
|
||||
print_verbose(
|
||||
f"success callbacks: Running Custom Callback Function"
|
||||
|
@ -1423,9 +1486,8 @@ class Logging:
|
|||
await litellm.cache.async_add_cache(result, **kwargs)
|
||||
else:
|
||||
litellm.cache.add_cache(result, **kwargs)
|
||||
if callback == "openmeter":
|
||||
global openMeterLogger
|
||||
if self.stream == True:
|
||||
if callback == "openmeter" and openMeterLogger is not None:
|
||||
if self.stream is True:
|
||||
if (
|
||||
"async_complete_streaming_response"
|
||||
in self.model_call_details
|
||||
|
@ -1645,33 +1707,9 @@ class Logging:
|
|||
)
|
||||
for callback in callbacks:
|
||||
try:
|
||||
if callback == "lite_debugger":
|
||||
print_verbose("reaches lite_debugger for logging!")
|
||||
print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
|
||||
result = {
|
||||
"model": self.model,
|
||||
"created": time.time(),
|
||||
"error": traceback_exception,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_token_calculator(
|
||||
self.model, messages=self.messages
|
||||
),
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
liteDebuggerClient.log_event(
|
||||
model=self.model,
|
||||
messages=self.messages,
|
||||
end_user=self.model_call_details.get("user", "default"),
|
||||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
litellm_call_id=self.litellm_call_id,
|
||||
print_verbose=print_verbose,
|
||||
call_type=self.call_type,
|
||||
stream=self.stream,
|
||||
)
|
||||
if callback == "lunary":
|
||||
if callback == "lite_debugger" and liteDebuggerClient is not None:
|
||||
pass
|
||||
elif callback == "lunary" and lunaryLogger is not None:
|
||||
print_verbose("reaches lunary for logging error!")
|
||||
|
||||
model = self.model
|
||||
|
@ -1685,6 +1723,7 @@ class Logging:
|
|||
)
|
||||
|
||||
lunaryLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
type=_type,
|
||||
event="error",
|
||||
user_id=self.model_call_details.get("user", "default"),
|
||||
|
@ -1704,22 +1743,11 @@ class Logging:
|
|||
print_verbose(
|
||||
f"capture exception not initialized: {capture_exception}"
|
||||
)
|
||||
elif callback == "supabase":
|
||||
elif callback == "supabase" and supabaseClient is not None:
|
||||
print_verbose("reaches supabase for logging!")
|
||||
print_verbose(f"supabaseClient: {supabaseClient}")
|
||||
result = {
|
||||
"model": model,
|
||||
"created": time.time(),
|
||||
"error": traceback_exception,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_token_calculator(
|
||||
model, messages=self.messages
|
||||
),
|
||||
"completion_tokens": 0,
|
||||
},
|
||||
}
|
||||
supabaseClient.log_event(
|
||||
model=self.model,
|
||||
model=self.model if hasattr(self, "model") else "",
|
||||
messages=self.messages,
|
||||
end_user=self.model_call_details.get("user", "default"),
|
||||
response_obj=result,
|
||||
|
@ -1728,7 +1756,9 @@ class Logging:
|
|||
litellm_call_id=self.model_call_details["litellm_call_id"],
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callable(callback): # custom logger functions
|
||||
if (
|
||||
callable(callback) and customLogger is not None
|
||||
): # custom logger functions
|
||||
customLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
response_obj=result,
|
||||
|
@ -1809,13 +1839,13 @@ class Logging:
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
response_obj=None,
|
||||
user_id=kwargs.get("user", None),
|
||||
user_id=self.model_call_details.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
status_message=str(exception),
|
||||
level="ERROR",
|
||||
kwargs=self.model_call_details,
|
||||
)
|
||||
if callback == "logfire":
|
||||
if callback == "logfire" and logfireLogger is not None:
|
||||
verbose_logger.debug("reaches logfire for failure logging!")
|
||||
kwargs = {}
|
||||
for k, v in self.model_call_details.items():
|
||||
|
@ -1830,7 +1860,7 @@ class Logging:
|
|||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
level=LogfireLevel.ERROR.value,
|
||||
level=LogfireLevel.ERROR.value, # type: ignore
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
|
||||
|
@ -1873,7 +1903,9 @@ class Logging:
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
) # type: ignore
|
||||
if callable(callback): # custom logger functions
|
||||
if (
|
||||
callable(callback) and customLogger is not None
|
||||
): # custom logger functions
|
||||
await customLogger.async_log_event(
|
||||
kwargs=self.model_call_details,
|
||||
response_obj=result,
|
||||
|
@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
|
|||
)
|
||||
sentry_sdk_instance.init(
|
||||
dsn=os.environ.get("SENTRY_DSN"),
|
||||
traces_sample_rate=float(sentry_trace_rate),
|
||||
traces_sample_rate=float(sentry_trace_rate), # type: ignore
|
||||
)
|
||||
capture_exception = sentry_sdk_instance.capture_exception
|
||||
add_breadcrumb = sentry_sdk_instance.add_breadcrumb
|
||||
|
@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(
|
|||
|
||||
saved_cache_cost: Optional[float] = None
|
||||
if cache_hit is True:
|
||||
import time
|
||||
|
||||
id = f"{id}_cache_hit{time.time()}" # do not duplicate the request id
|
||||
|
||||
saved_cache_cost = logging_obj._response_cost_calculator(
|
||||
result=init_response_obj, cache_hit=False
|
||||
result=init_response_obj, cache_hit=False # type: ignore
|
||||
)
|
||||
|
||||
## Get model cost information ##
|
||||
|
@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
|
|||
model_id=_model_id,
|
||||
requester_ip_address=clean_metadata.get("requester_ip_address", None),
|
||||
messages=kwargs.get("messages"),
|
||||
response=(
|
||||
response=( # type: ignore
|
||||
response_obj if len(response_obj.keys()) > 0 else init_response_obj
|
||||
),
|
||||
model_parameters=kwargs.get("optional_params", None),
|
||||
|
|
litellm/llms/OpenAI/chat/o1_handler.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""
Handler file for calls to OpenAI's o1 family of models

Written separately to handle faking streaming for o1 models.
"""

import asyncio
from typing import Any, Callable, List, Optional, Union

from httpx._config import Timeout

from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper


class OpenAIO1ChatCompletion(OpenAIChatCompletion):

    async def mock_async_streaming(
        self,
        response: Any,
        model: Optional[str],
        logging_obj: Any,
    ):
        model_response = await response
        completion_stream = MockResponseIterator(model_response=model_response)
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="openai",
            logging_obj=logging_obj,
        )
        return streaming_response

    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, Timeout],
        optional_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable[..., Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):
        stream: Optional[bool] = optional_params.pop("stream", False)
        response = super().completion(
            model_response,
            timeout,
            optional_params,
            logging_obj,
            model,
            messages,
            print_verbose,
            api_key,
            api_base,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            custom_prompt_dict,
            client,
            organization,
            custom_llm_provider,
            drop_params,
        )

        if stream is True:
            if asyncio.iscoroutine(response):
                return self.mock_async_streaming(
                    response=response, model=model, logging_obj=logging_obj  # type: ignore
                )

            completion_stream = MockResponseIterator(model_response=response)
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="openai",
                logging_obj=logging_obj,
            )

            return streaming_response
        else:
            return response
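In practice, a streaming request against an o1 model now behaves like the sketch below: the full response is fetched first and then replayed as a mock stream (assumes an `OPENAI_API_KEY`; chunk access follows the usual OpenAI-style delta format):

```python
import litellm

response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "what llm are you"}],
    stream=True,  # faked: the handler wraps the finished response in a stream
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```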
|
@ -15,6 +15,8 @@ import requests # type: ignore
|
|||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.databricks.exceptions import DatabricksError
|
||||
from litellm.llms.databricks.streaming_utils import ModelResponseIterator
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionDeltaChunk,
|
||||
ChatCompletionResponseMessage,
|
||||
|
@ -33,17 +35,6 @@ from ..base import BaseLLM
|
|||
from ..prompt_templates.factory import custom_prompt, prompt_factory
|
||||
|
||||
|
||||
class DatabricksError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class DatabricksConfig:
|
||||
"""
|
||||
Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
|
||||
|
@ -367,7 +358,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise DatabricksError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise DatabricksError(status_code=500, message=str(e))
|
||||
|
@ -380,7 +371,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
)
|
||||
response = ModelResponse(**response_json)
|
||||
|
||||
response.model = custom_llm_provider + "/" + response.model
|
||||
response.model = custom_llm_provider + "/" + (response.model or "")
|
||||
|
||||
if base_model is not None:
|
||||
response._hidden_params["model"] = base_model
|
||||
|
@ -529,7 +520,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
response_json = response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise DatabricksError(
|
||||
status_code=e.response.status_code, message=response.text
|
||||
status_code=e.response.status_code, message=e.response.text
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
raise DatabricksError(
|
||||
|
@ -540,7 +531,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
|
||||
response = ModelResponse(**response_json)
|
||||
|
||||
response.model = custom_llm_provider + "/" + response.model
|
||||
response.model = custom_llm_provider + "/" + (response.model or "")
|
||||
|
||||
if base_model is not None:
|
||||
response._hidden_params["model"] = base_model
|
||||
|
@ -657,7 +648,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
except httpx.HTTPStatusError as e:
|
||||
raise DatabricksError(
|
||||
status_code=e.response.status_code,
|
||||
message=response.text if response else str(e),
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
raise DatabricksError(status_code=408, message="Timeout error occurred.")
|
||||
|
@ -673,136 +664,3 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
return litellm.EmbeddingResponse(**response_json)
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
self.streaming_response = streaming_response
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
|
||||
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
|
||||
if processed_chunk.choices[0].delta.content is not None: # type: ignore
|
||||
text = processed_chunk.choices[0].delta.content # type: ignore
|
||||
|
||||
if (
|
||||
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
|
||||
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
|
||||
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
|
||||
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
|
||||
is not None
|
||||
):
|
||||
tool_use = ChatCompletionToolCallChunk(
|
||||
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=processed_chunk.choices[0]
|
||||
.delta.tool_calls[0] # type: ignore
|
||||
.function.name,
|
||||
arguments=processed_chunk.choices[0]
|
||||
.delta.tool_calls[0] # type: ignore
|
||||
.function.arguments,
|
||||
),
|
||||
index=processed_chunk.choices[0].index,
|
||||
)
|
||||
|
||||
if processed_chunk.choices[0].finish_reason is not None:
|
||||
is_finished = True
|
||||
finish_reason = processed_chunk.choices[0].finish_reason
|
||||
|
||||
if hasattr(processed_chunk, "usage") and isinstance(
|
||||
processed_chunk.usage, litellm.Usage
|
||||
):
|
||||
usage_chunk: litellm.Usage = processed_chunk.usage
|
||||
|
||||
usage = ChatCompletionUsageBlock(
|
||||
prompt_tokens=usage_chunk.prompt_tokens,
|
||||
completion_tokens=usage_chunk.completion_tokens,
|
||||
total_tokens=usage_chunk.total_tokens,
|
||||
)
|
||||
|
||||
return GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=0,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
self.response_iterator = self.streaming_response
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
chunk = chunk.replace("data:", "")
|
||||
chunk = chunk.strip()
|
||||
if len(chunk) > 0:
|
||||
json_chunk = json.loads(chunk)
|
||||
return self.chunk_parser(chunk=json_chunk)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
chunk = chunk.replace("data:", "")
|
||||
chunk = chunk.strip()
|
||||
if chunk == "[DONE]":
|
||||
raise StopAsyncIteration
|
||||
if len(chunk) > 0:
|
||||
json_chunk = json.loads(chunk)
|
||||
return self.chunk_parser(chunk=json_chunk)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
|
litellm/llms/databricks/exceptions.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import httpx


class DatabricksError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
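For context, a hypothetical caller surfaces this error type roughly as follows (the status code and message mirror the timeout branch used elsewhere in this diff):

```python
try:
    raise DatabricksError(status_code=408, message="Timeout error occurred.")
except DatabricksError as e:
    print(e.status_code, e.message)  # 408 Timeout error occurred.
```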
145
litellm/llms/databricks/streaming_utils.py
Normal file
145
litellm/llms/databricks/streaming_utils.py
Normal file
|
@ -0,0 +1,145 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionDeltaChunk,
|
||||
ChatCompletionResponseMessage,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
)
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, streaming_response, sync_stream: bool):
|
||||
self.streaming_response = streaming_response
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
|
||||
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
|
||||
if processed_chunk.choices[0].delta.content is not None: # type: ignore
|
||||
text = processed_chunk.choices[0].delta.content # type: ignore
|
||||
|
||||
if (
|
||||
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
|
||||
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
|
||||
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
|
||||
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
|
||||
is not None
|
||||
):
|
||||
tool_use = ChatCompletionToolCallChunk(
|
||||
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=processed_chunk.choices[0]
|
||||
.delta.tool_calls[0] # type: ignore
|
||||
.function.name,
|
||||
arguments=processed_chunk.choices[0]
|
||||
.delta.tool_calls[0] # type: ignore
|
||||
.function.arguments,
|
||||
),
|
||||
index=processed_chunk.choices[0].index,
|
||||
)
|
||||
|
||||
if processed_chunk.choices[0].finish_reason is not None:
|
||||
is_finished = True
|
||||
finish_reason = processed_chunk.choices[0].finish_reason
|
||||
|
||||
if hasattr(processed_chunk, "usage") and isinstance(
|
||||
processed_chunk.usage, litellm.Usage
|
||||
):
|
||||
usage_chunk: litellm.Usage = processed_chunk.usage
|
||||
|
||||
usage = ChatCompletionUsageBlock(
|
||||
prompt_tokens=usage_chunk.prompt_tokens,
|
||||
completion_tokens=usage_chunk.completion_tokens,
|
||||
total_tokens=usage_chunk.total_tokens,
|
||||
)
|
||||
|
||||
return GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=0,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
self.response_iterator = self.streaming_response
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
chunk = chunk.replace("data:", "")
|
||||
chunk = chunk.strip()
|
||||
if len(chunk) > 0:
|
||||
json_chunk = json.loads(chunk)
|
||||
return self.chunk_parser(chunk=json_chunk)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
chunk = chunk.replace("data:", "")
|
||||
chunk = chunk.strip()
|
||||
if chunk == "[DONE]":
|
||||
raise StopAsyncIteration
|
||||
if len(chunk) > 0:
|
||||
json_chunk = json.loads(chunk)
|
||||
return self.chunk_parser(chunk=json_chunk)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
is_finished=False,
|
||||
finish_reason="",
|
||||
usage=None,
|
||||
index=0,
|
||||
tool_use=None,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
|
@ -95,6 +95,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
|
|||
from .llms.databricks.chat import DatabricksChatCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
|
||||
from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
|
||||
from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
|
@ -161,6 +162,7 @@ from litellm.utils import (
|
|||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_chat_completions = OpenAIChatCompletion()
|
||||
openai_text_completions = OpenAITextCompletion()
|
||||
openai_o1_chat_completions = OpenAIO1ChatCompletion()
|
||||
openai_audio_transcriptions = OpenAIAudioTranscription()
|
||||
databricks_chat_completions = DatabricksChatCompletion()
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
@ -1366,25 +1368,46 @@ def completion(
|
|||
|
||||
## COMPLETION CALL
|
||||
try:
|
||||
response = openai_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
|
||||
response = openai_o1_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
else:
|
||||
response = openai_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
headers=headers,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
timeout=timeout, # type: ignore
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
client=client, # pass AsyncOpenAI, OpenAI client
|
||||
organization=organization,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
except Exception as e:
|
||||
## LOGGING - log the original exception returned
|
||||
logging.post_call(
|
||||
|
|
@@ -3,7 +3,6 @@ model_list:
     litellm_params:
       model: anthropic.claude-3-sonnet-20240229-v1:0
       api_base: https://exampleopenaiendpoint-production.up.railway.app
-      # aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
       aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
       aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
   - model_name: gemini-vision
@@ -17,4 +16,34 @@ model_list:
     litellm_params:
       model: gpt-3.5-turbo
       api_base: https://exampleopenaiendpoint-production.up.railway.app
+
+  - model_name: o1-preview
+    litellm_params:
+      model: o1-preview
+
+litellm_settings:
+  drop_params: True
+  json_logs: True
+  store_audit_logs: True
+  log_raw_request_response: True
+  return_response_headers: True
+  num_retries: 5
+  request_timeout: 200
+  callbacks: ["custom_callbacks.proxy_handler_instance"]
+
+guardrails:
+  - guardrail_name: "presidio-pre-guard"
+    litellm_params:
+      guardrail: presidio  # supported values: "aporia", "bedrock", "lakera", "presidio"
+      mode: "logging_only"
+      mock_redacted_text: {
+        "text": "My name is <PERSON>, who are you? Say my name in your response",
+        "items": [
+          {
+            "start": 11,
+            "end": 19,
+            "entity_type": "PERSON",
+            "text": "<PERSON>",
+            "operator": "replace",
+          }
+        ],
+      }
|
@ -446,7 +446,6 @@ async def user_api_key_auth(
|
|||
and request.headers.get(key=header_key) is not None # type: ignore
|
||||
):
|
||||
api_key = request.headers.get(key=header_key) # type: ignore
|
||||
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return UserAPIKeyAuth(
|
||||
|
|
|
@@ -2,9 +2,19 @@ from litellm.integrations.custom_logger import CustomLogger


 class MyCustomHandler(CustomLogger):
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         # print("Call failed")
         pass
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         # input_tokens = response_obj.get("usage", {}).get("prompt_tokens", 0)
         # output_tokens = response_obj.get("usage", {}).get("completion_tokens", 0)
+        input_tokens = (
+            response_obj.usage.prompt_tokens
+            if hasattr(response_obj.usage, "prompt_tokens")
+            else 0
+        )
+        output_tokens = (
+            response_obj.usage.completion_tokens
+            if hasattr(response_obj.usage, "completion_tokens")
+            else 0
+        )


 proxy_handler_instance = MyCustomHandler()
|
@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
response = await self.async_handler.post(
|
||||
url=prepared_request.url,
|
||||
json=request_data, # type: ignore
|
||||
headers=dict(prepared_request.headers),
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
)
|
||||
verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
|
||||
if response.status_code == 200:
|
||||
|
@ -254,7 +254,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
|
|||
from litellm.proxy.common_utils.callback_utils import (
|
||||
add_guardrail_to_applied_guardrails_header,
|
||||
)
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
|
||||
if self.should_run_guardrail(data=data, event_type=event_type) is not True:
|
||||
|
|
|
@ -189,6 +189,7 @@ class lakeraAI_Moderation(CustomGuardrail):
|
|||
# Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
|
||||
# Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
|
||||
# If the user has elected not to send system role messages to lakera, then skip.
|
||||
|
||||
if system_message is not None:
|
||||
if not litellm.add_function_to_prompt:
|
||||
content = system_message.get("content")
|
||||
|
|
|
@ -19,6 +19,7 @@ from fastapi import HTTPException
|
|||
from pydantic import BaseModel
|
||||
|
||||
import litellm # noqa: E401
|
||||
from litellm import get_secret
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
@ -58,7 +59,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
|||
self.pii_tokens: dict = (
|
||||
{}
|
||||
) # mapping of PII token to original text - only used with Presidio `replace` operation
|
||||
|
||||
self.mock_redacted_text = mock_redacted_text
|
||||
self.output_parse_pii = output_parse_pii or False
|
||||
if mock_testing is True: # for testing purposes only
|
||||
|
@ -92,8 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
|||
presidio_anonymizer_api_base: Optional[str] = None,
|
||||
):
|
||||
self.presidio_analyzer_api_base: Optional[str] = (
|
||||
presidio_analyzer_api_base
|
||||
or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None)
|
||||
presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore
|
||||
)
|
||||
self.presidio_anonymizer_api_base: Optional[
|
||||
str
|
||||
|
@@ -198,12 +197,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
raise e

async def async_pre_call_hook(

@@ -254,9 +247,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
return data
except Exception as e:
verbose_proxy_logger.info(
f"An error occurred -",
)
raise e

async def async_logging_hook(

@@ -300,9 +290,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
)
kwargs["messages"] = messages

return kwargs, responses
return kwargs, result

async def async_post_call_success_hook(
async def async_post_call_success_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,

@@ -314,7 +304,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
verbose_proxy_logger.debug(
f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
)
if self.output_parse_pii == False:
if self.output_parse_pii is False:
return response

if isinstance(response, ModelResponse) and not isinstance(
@@ -1,10 +1,11 @@
import importlib
import traceback
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Optional

from pydantic import BaseModel, RootModel

import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy

@@ -16,7 +17,6 @@ from litellm.types.guardrails import (
GuardrailItemSpec,
LakeraCategoryThresholds,
LitellmParams,
guardrailConfig,
)

all_guardrails: List[GuardrailItem] = []
@@ -98,18 +98,13 @@ def init_guardrails_v2(
# Init litellm params for guardrail
litellm_params_data = guardrail["litellm_params"]
verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
litellm_params = LitellmParams(
guardrail=litellm_params_data["guardrail"],
mode=litellm_params_data["mode"],
api_key=litellm_params_data.get("api_key"),
api_base=litellm_params_data.get("api_base"),
guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
guardrailVersion=litellm_params_data.get("guardrailVersion"),
output_parse_pii=litellm_params_data.get("output_parse_pii"),
presidio_ad_hoc_recognizers=litellm_params_data.get(
"presidio_ad_hoc_recognizers"
),
)

_litellm_params_kwargs = {
k: litellm_params_data[k] if k in litellm_params_data else None
for k in LitellmParams.__annotations__.keys()
}

litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore

if (
"category_thresholds" in litellm_params_data
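The rewrite above builds the `LitellmParams` kwargs from the TypedDict's own annotations instead of naming every field by hand. A self-contained sketch of the same pattern, using a hypothetical `ExampleParams` TypedDict and config dict:

```python
from typing import Optional, TypedDict


class ExampleParams(TypedDict):
    guardrail: str
    mode: str
    api_key: Optional[str]


config = {"guardrail": "presidio", "mode": "pre_call"}

# Any annotated key missing from the config defaults to None, so new
# TypedDict fields are picked up without touching this construction code.
kwargs = {k: config[k] if k in config else None for k in ExampleParams.__annotations__.keys()}
params = ExampleParams(**kwargs)  # type: ignore[typeddict-item]
print(params)  # {'guardrail': 'presidio', 'mode': 'pre_call', 'api_key': None}
```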
@@ -122,15 +117,11 @@ def init_guardrails_v2(
if litellm_params["api_key"]:
if litellm_params["api_key"].startswith("os.environ/"):
litellm_params["api_key"] = litellm.get_secret(
litellm_params["api_key"]
)
litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore

if litellm_params["api_base"]:
if litellm_params["api_base"].startswith("os.environ/"):
litellm_params["api_base"] = litellm.get_secret(
litellm_params["api_base"]
)
litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore

# Init guardrail CustomLoggerClass
if litellm_params["guardrail"] == "aporia":

@@ -182,6 +173,7 @@ def init_guardrails_v2(
presidio_ad_hoc_recognizers=litellm_params[
"presidio_ad_hoc_recognizers"
],
mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
)

if litellm_params["output_parse_pii"] is True:
@@ -167,11 +167,11 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):

if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on
if self.prompt_injection_params.heuristics_check == True:
if self.prompt_injection_params.heuristics_check is True:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack == True:
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={

@@ -179,14 +179,14 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
},
)
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
if self.prompt_injection_params.vector_db_check == True:
if self.prompt_injection_params.vector_db_check is True:
pass
else:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)

if is_prompt_attack == True:
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={

@@ -201,19 +201,18 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
if (
e.status_code == 400
and isinstance(e.detail, dict)
and "error" in e.detail
and "error" in e.detail # type: ignore
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail.get("error")
raise e
except Exception as e:
verbose_proxy_logger.error(
verbose_proxy_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())

async def async_moderation_hook( # type: ignore
self,
@@ -195,7 +195,8 @@ async def user_auth(request: Request):
- os.environ["SMTP_PASSWORD"]
- os.environ["SMTP_SENDER_EMAIL"]
"""
from litellm.proxy.proxy_server import prisma_client, send_email
from litellm.proxy.proxy_server import prisma_client
from litellm.proxy.utils import send_email

data = await request.json() # type: ignore
user_email = data["user_email"]

@@ -212,7 +213,7 @@ async def user_auth(request: Request):
)
### if so - generate a 24 hr key with that user id
if response is not None:
user_id = response.user_id
user_id = response.user_id # type: ignore
response = await generate_key_helper_fn(
request_type="key",
**{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id}, # type: ignore

@@ -345,6 +346,7 @@ async def user_info(
for team in teams_1:
team_id_list.append(team.team_id)

teams_2: Optional[Any] = None
if user_info is not None:
# *NEW* get all teams in user 'teams' field
teams_2 = await prisma_client.get_data(

@@ -375,7 +377,7 @@ async def user_info(
),
user_api_key_dict=user_api_key_dict,
)
else:
elif caller_user_info is not None:
teams_2 = await prisma_client.get_data(
team_id_list=caller_user_info.teams,
table_name="team",

@@ -395,7 +397,7 @@ async def user_info(
query_type="find_all",
)

if user_info is None:
if user_info is None and keys is not None:
## make sure we still return a total spend ##
spend = 0
for k in keys:
@@ -404,32 +406,35 @@ async def user_info(

## REMOVE HASHED TOKEN INFO before returning ##
returned_keys = []
for key in keys:
if (
key.token == litellm_master_key_hash
and general_settings.get("disable_master_key_return", False)
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
):
continue
if keys is None:
pass
else:
for key in keys:
if (
key.token == litellm_master_key_hash
and general_settings.get("disable_master_key_return", False)
== True ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
):
continue

try:
key = key.model_dump() # noqa
except:
# if using pydantic v1
key = key.dict()
if (
"team_id" in key
and key["team_id"] is not None
and key["team_id"] != "litellm-dashboard"
):
team_info = await prisma_client.get_data(
team_id=key["team_id"], table_name="team"
)
team_alias = getattr(team_info, "team_alias", None)
key["team_alias"] = team_alias
else:
key["team_alias"] = "None"
returned_keys.append(key)
try:
key = key.model_dump() # noqa
except:
# if using pydantic v1
key = key.dict()
if (
"team_id" in key
and key["team_id"] is not None
and key["team_id"] != "litellm-dashboard"
):
team_info = await prisma_client.get_data(
team_id=key["team_id"], table_name="team"
)
team_alias = getattr(team_info, "team_alias", None)
key["team_alias"] = team_alias
else:
key["team_alias"] = "None"
returned_keys.append(key)

response_data = {
"user_id": user_id,
@@ -539,6 +544,7 @@ async def user_update(

## ADD USER, IF NEW ##
verbose_proxy_logger.debug("/user/update: Received data = %s", data)
response: Optional[Any] = None
if data.user_id is not None and len(data.user_id) > 0:
non_default_values["user_id"] = data.user_id # type: ignore
verbose_proxy_logger.debug("In update user, user_id condition block.")

@@ -573,7 +579,7 @@ async def user_update(
data=non_default_values,
table_name="user",
)
return response
return response # type: ignore
# update based on remaining passed in values
except Exception as e:
verbose_proxy_logger.error(
@@ -226,7 +226,6 @@ from litellm.proxy.utils import (
hash_token,
log_to_opentelemetry,
reset_budget,
send_email,
update_spend,
)
from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (

@@ -1434,7 +1434,7 @@ async def _get_spend_report_for_time_range(
except Exception as e:
verbose_proxy_logger.error(
"Exception in _get_daily_spend_reports {}".format(str(e))
) # noqa
)


@router.post(
@@ -1703,26 +1703,26 @@ async def view_spend_logs(
result: dict = {}
for record in response:
dt_object = datetime.strptime(
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ"
str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" # type: ignore
) # type: ignore
date = dt_object.date()
if date not in result:
result[date] = {"users": {}, "models": {}}
api_key = record["api_key"]
user_id = record["user"]
model = record["model"]
result[date]["spend"] = (
result[date].get("spend", 0) + record["_sum"]["spend"]
)
result[date][api_key] = (
result[date].get(api_key, 0) + record["_sum"]["spend"]
)
result[date]["users"][user_id] = (
result[date]["users"].get(user_id, 0) + record["_sum"]["spend"]
)
result[date]["models"][model] = (
result[date]["models"].get(model, 0) + record["_sum"]["spend"]
)
api_key = record["api_key"] # type: ignore
user_id = record["user"] # type: ignore
model = record["model"] # type: ignore
result[date]["spend"] = result[date].get("spend", 0) + record.get(
"_sum", {}
).get("spend", 0)
result[date][api_key] = result[date].get(api_key, 0) + record.get(
"_sum", {}
).get("spend", 0)
result[date]["users"][user_id] = result[date]["users"].get(
user_id, 0
) + record.get("_sum", {}).get("spend", 0)
result[date]["models"][model] = result[date]["models"].get(
model, 0
) + record.get("_sum", {}).get("spend", 0)
return_list = []
final_date = None
for k, v in sorted(result.items()):
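The rewritten aggregation guards against spend records that lack a `_sum` block by treating them as 0 instead of raising. A standalone illustration of the pattern (the `records` list below is made up for the example):

```python
records = [
    {"api_key": "sk-1", "_sum": {"spend": 1.25}},
    {"api_key": "sk-2"},  # no _sum block - should count as 0, not raise KeyError
]

totals: dict = {}
for record in records:
    api_key = record["api_key"]
    # .get("_sum", {}).get("spend", 0) keeps the loop alive when _sum is missing.
    totals[api_key] = totals.get(api_key, 0) + record.get("_sum", {}).get("spend", 0)

print(totals)  # {'sk-1': 1.25, 'sk-2': 0}
```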
@@ -1784,7 +1784,7 @@ async def view_spend_logs(
table_name="spend", query_type="find_all"
)

return spend_log
return spend_logs

return None
@@ -1843,6 +1843,88 @@ async def global_spend_reset():
}


@router.post(
"/global/spend/refresh",
tags=["Budget & Spend Tracking"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def global_spend_refresh():
"""
ADMIN ONLY / MASTER KEY Only Endpoint

Globally refresh spend MonthlyGlobalSpend view
"""
from litellm.proxy.proxy_server import prisma_client

if prisma_client is None:
raise ProxyException(
message="Prisma Client is not initialized",
type="internal_error",
param="None",
code=status.HTTP_401_UNAUTHORIZED,
)

## RESET GLOBAL SPEND VIEW ###
async def is_materialized_global_spend_view() -> bool:
"""
Return True if materialized view exists

Else False
"""
sql_query = """
SELECT relname, relkind
FROM pg_class
WHERE relname = 'MonthlyGlobalSpend';
"""
try:
resp = await prisma_client.db.query_raw(sql_query)

assert resp[0]["relkind"] == "m"
return True
except Exception:
return False

view_exists = await is_materialized_global_spend_view()

if view_exists:
# refresh materialized view
sql_query = """
REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";
"""
try:
from litellm.proxy._types import CommonProxyErrors
from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient

db_url = os.getenv("DATABASE_URL")
if db_url is None:
raise Exception(CommonProxyErrors.db_not_connected_error.value)
new_client = PrismaClient(
database_url=db_url,
proxy_logging_obj=proxy_logging_obj,
http_client={
"timeout": 6000,
},
)
await new_client.db.connect()
await new_client.db.query_raw(sql_query)
verbose_proxy_logger.info("MonthlyGlobalSpend view refreshed")
return {
"message": "MonthlyGlobalSpend view refreshed",
"status": "success",
}

except Exception as e:
verbose_proxy_logger.exception(
"Failed to refresh materialized view - {}".format(str(e))
)
return {
"message": "Failed to refresh materialized view",
"status": "failure",
}


async def global_spend_for_internal_user(
api_key: Optional[str] = None,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
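For reference, a minimal sketch of invoking the new `/global/spend/refresh` endpoint once the proxy is running; the base URL and key below are placeholders, not values from this change:

```python
import httpx

# Placeholders - substitute your proxy URL and an admin/master key.
PROXY_BASE_URL = "http://localhost:4000"
LITELLM_MASTER_KEY = "sk-1234"

resp = httpx.post(
    f"{PROXY_BASE_URL}/global/spend/refresh",
    headers={"Authorization": f"Bearer {LITELLM_MASTER_KEY}"},
    timeout=120,  # REFRESH MATERIALIZED VIEW can take a while on large spend tables
)
print(resp.status_code, resp.json())
```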
@@ -92,6 +92,7 @@ def safe_deep_copy(data):
if litellm.safe_memory_mode is True:
return data

litellm_parent_otel_span: Optional[Any] = None
# Step 1: Remove the litellm_parent_otel_span
litellm_parent_otel_span = None
if isinstance(data, dict):

@@ -101,7 +102,7 @@ def safe_deep_copy(data):
new_data = copy.deepcopy(data)

# Step 2: re-add the litellm_parent_otel_span after doing a deep copy
if isinstance(data, dict):
if isinstance(data, dict) and litellm_parent_otel_span is not None:
if "metadata" in data:
data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
return new_data
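The `safe_deep_copy` change follows the usual pattern of detaching an object that should not be deep-copied, copying the rest, then re-attaching it. A standalone sketch of that idea, using a plain sentinel class in place of an OTEL span (not the actual implementation):

```python
import copy
from typing import Any, Optional


class FakeSpan:
    """Stand-in for a non-copyable handle such as an OTEL span."""


def sketch_safe_deep_copy(data: Any) -> Any:
    parent_span: Optional[Any] = None
    if isinstance(data, dict) and "metadata" in data:
        # Step 1: pull the handle out so deepcopy never touches it.
        parent_span = data["metadata"].pop("litellm_parent_otel_span", None)
    new_data = copy.deepcopy(data)
    # Step 2: put the original handle back on the source dict only.
    if isinstance(data, dict) and parent_span is not None:
        data["metadata"]["litellm_parent_otel_span"] = parent_span
    return new_data


request = {"messages": ["hi"], "metadata": {"litellm_parent_otel_span": FakeSpan()}}
copied = sketch_safe_deep_copy(request)
assert "litellm_parent_otel_span" not in copied["metadata"]
assert "litellm_parent_otel_span" in request["metadata"]
```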
@@ -468,7 +469,7 @@ class ProxyLogging:

# V1 implementation - backwards compatibility
if callback.event_hook is None:
if callback.moderation_check == "pre_call":
if callback.moderation_check == "pre_call": # type: ignore
return
else:
# Main - V2 Guardrails implementation
@@ -881,7 +882,12 @@ class PrismaClient:
org_list_transactons: dict = {}
spend_log_transactions: List = []

def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
def __init__(
self,
database_url: str,
proxy_logging_obj: ProxyLogging,
http_client: Optional[Any] = None,
):
verbose_proxy_logger.debug(
"LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
)

@@ -912,7 +918,10 @@ class PrismaClient:
# Now you can import the Prisma Client
from prisma import Prisma # type: ignore
verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
self.db = Prisma() # Client to connect to Prisma db
if http_client is not None:
self.db = Prisma(http=http_client)
else:
self.db = Prisma() # Client to connect to Prisma db
verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")

def hash_token(self, token: str):
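With the new `http_client` parameter, callers can hand the Prisma client a custom HTTP config such as a longer timeout for slow maintenance queries. A minimal sketch, assuming `DATABASE_URL` is set and reusing the imports the refresh endpoint above uses; the timeout value mirrors the diff, and its exact units are whatever the underlying Prisma HTTP client expects:

```python
import os

from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient

# Generous timeout so long-running SQL (e.g. REFRESH MATERIALIZED VIEW) isn't cut off.
client = PrismaClient(
    database_url=os.environ["DATABASE_URL"],
    proxy_logging_obj=proxy_logging_obj,
    http_client={"timeout": 6000},
)
```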
@@ -987,7 +996,7 @@ class PrismaClient:
return
else:
## check if required view exists ##
if required_view not in ret[0]["view_names"]:
if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
await self.health_check() # make sure we can connect to db
await self.db.execute_raw(
"""

@@ -1009,7 +1018,9 @@ class PrismaClient:
else:
# don't block execution if these views are missing
# Convert lists to sets for efficient difference calculation
ret_view_names_set = set(ret[0]["view_names"])
ret_view_names_set = (
set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
)
expected_views_set = set(expected_views)
# Find missing views
missing_views = expected_views_set - ret_view_names_set
@@ -1291,13 +1302,13 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
)
hashed_token: Optional[str] = None
try:
response: Any = None
if (token is not None and table_name is None) or (
table_name is not None and table_name == "key"
):
# check if plain text or hash
hashed_token = None
if token is not None:
if isinstance(token, str):
hashed_token = token

@@ -1306,7 +1317,7 @@ class PrismaClient:
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
if query_type == "find_unique":
if query_type == "find_unique" and hashed_token is not None:
if token is None:
raise HTTPException(
status_code=400,

@@ -1706,7 +1717,7 @@ class PrismaClient:
updated_data = v
updated_data = json.dumps(updated_data)
updated_table_row = self.db.litellm_config.upsert(
where={"param_name": k},
where={"param_name": k}, # type: ignore
data={
"create": {"param_name": k, "param_value": updated_data}, # type: ignore
"update": {"param_value": updated_data},
@@ -2302,7 +2313,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
return instance
except ImportError as e:
# Re-raise the exception with a user-friendly message
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
if instance_name and module_name:
raise ImportError(
f"Could not import {instance_name} from {module_name}"
) from e
else:
raise e
except Exception as e:
raise e

@@ -2377,12 +2393,12 @@ async def send_email(receiver_email, subject, html):

try:
# Establish a secure connection with the SMTP server
with smtplib.SMTP(smtp_host, smtp_port) as server:
with smtplib.SMTP(smtp_host, smtp_port) as server: # type: ignore
if os.getenv("SMTP_TLS", "True") != "False":
server.starttls()

# Login to your email account
server.login(smtp_username, smtp_password)
server.login(smtp_username, smtp_password) # type: ignore

# Send the email
server.send_message(email_message)

@@ -2945,7 +2961,7 @@ async def update_spend(
if i >= n_retry_times: # If we've reached the maximum number of retries
raise # Re-raise the last exception
# Optionally, sleep for a bit before retrying
await asyncio.sleep(2**i) # Exponential backoff
await asyncio.sleep(2**i) # type: ignore
except Exception as e:
import traceback
@@ -2110,7 +2110,6 @@ async def test_hf_completion_tgi_stream():
def test_openai_chat_completion_call():
litellm.set_verbose = False
litellm.return_response_headers = True
print(f"making openai chat completion call")
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
assert isinstance(
response._hidden_params["additional_headers"][
@@ -2318,6 +2317,57 @@ def test_together_ai_completion_call_mistral():
pass


# # test on together ai completion call - starcoder
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_openai_o1_completion_call_streaming(sync_mode):
try:
litellm.set_verbose = False
if sync_mode:
response = completion(
model="o1-preview",
messages=messages,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
has_finish_reason = False
for idx, chunk in enumerate(response):
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
if has_finish_reason is False:
raise Exception("Finish reason not set for last chunk")
if complete_response == "":
raise Exception("Empty response received")
else:
response = await acompletion(
model="o1-preview",
messages=messages,
stream=True,
)
complete_response = ""
print(f"returned response object: {response}")
has_finish_reason = False
idx = 0
async for chunk in response:
chunk, finished = streaming_format_tests(idx, chunk)
has_finish_reason = finished
if finished:
break
complete_response += chunk
idx += 1
if has_finish_reason is False:
raise Exception("Finish reason not set for last chunk")
if complete_response == "":
raise Exception("Empty response received")
print(f"complete response: {complete_response}")
except Exception:
pytest.fail(f"error occurred: {traceback.format_exc()}")


def test_together_ai_completion_call_starcoder_bad_key():
try:
api_key = "bad-key"
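The new `test_openai_o1_completion_call_streaming` test above exercises streaming for OpenAI o1 models, which this release serves via fake streaming. A minimal usage sketch (the prompt content is illustrative):

```python
import litellm

# o1 models don't stream natively; per the changelog, LiteLLM fakes the stream,
# so the iteration pattern is the same as for any other streamed model.
response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```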
@@ -71,7 +71,7 @@ class LakeraCategoryThresholds(TypedDict, total=False):
jailbreak: float


class LitellmParams(TypedDict, total=False):
class LitellmParams(TypedDict):
guardrail: str
mode: str
api_key: str

@@ -87,6 +87,7 @@ class LitellmParams(TypedDict, total=False):
# Presidio params
output_parse_pii: Optional[bool]
presidio_ad_hoc_recognizers: Optional[str]
mock_redacted_text: Optional[dict]


class Guardrail(TypedDict):
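Given the `LitellmParams` fields above, the `litellm_params` block that `init_guardrails_v2` receives for a Presidio guardrail could look roughly like the following dict once the config is parsed (all values are illustrative, not taken from this change):

```python
# Illustrative only - mirrors the fields declared on LitellmParams above.
litellm_params_data = {
    "guardrail": "presidio",
    "mode": "pre_call",
    "api_key": None,
    "api_base": None,
    "output_parse_pii": True,
    "presidio_ad_hoc_recognizers": "./recognizers.json",
    "mock_redacted_text": None,
}
```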
@@ -120,11 +120,26 @@ with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json")
# Convert to str (if necessary)
claude_json_str = json.dumps(json_data)
import importlib.metadata
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
)

from openai import OpenAIError as OriginalError

from ._logging import verbose_logger
from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .exceptions import (
APIConnectionError,
APIError,

@@ -150,31 +165,6 @@ from .types.llms.openai import (
)
from .types.router import LiteLLM_Params

try:
from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
GenericAPILogger,
)
except Exception as e:
verbose_logger.debug(f"Exception import enterprise features {str(e)}")

from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
get_args,
)

from .caching import Cache

####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100
tests/llm_translation/Readme.md (new file, 1 line)
@@ -0,0 +1 @@
More tests under `litellm/litellm/tests/*`.

tests/llm_translation/test_databricks.py (new file, 502 lines)
@@ -0,0 +1,502 @@
import asyncio
import httpx
import json
import pytest
import sys
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch

import litellm
from litellm.exceptions import BadRequestError, InternalServerError
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import CustomStreamWrapper


def mock_chat_response() -> Dict[str, Any]:
return {
"id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
"object": "chat.completion",
"created": 1726285449,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
"function_call": None,
"tool_calls": None,
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 38,
"total_tokens": 268,
"completion_tokens_details": None,
},
"system_fingerprint": None,
}


def mock_chat_streaming_response_chunks() -> List[str]:
return [
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "Hello"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"content": " world"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
json.dumps(
{
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
"object": "chat.completion.chunk",
"created": 1726469651,
"model": "dbrx-instruct-071224",
"choices": [
{
"index": 0,
"delta": {"content": "!"},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {
"prompt_tokens": 230,
"completion_tokens": 1,
"total_tokens": 231,
},
}
),
]


def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
string_chunks = mock_chat_streaming_response_chunks()
bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
# Simulate the end of the stream
bytes_chunks.append(b"")
return bytes_chunks


def mock_http_handler_chat_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks()

def mock_iter_lines():
for chunk in mock_stream_chunks:
for line in chunk.splitlines():
yield line

mock_response = MagicMock()
mock_response.iter_lines.side_effect = mock_iter_lines
mock_response.status_code = 200

return mock_response


def mock_http_handler_chat_async_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks()

async def mock_iter_lines():
for chunk in mock_stream_chunks:
for line in chunk.splitlines():
yield line

mock_response = MagicMock()
mock_response.aiter_lines.return_value = mock_iter_lines()
mock_response.status_code = 200

return mock_response


def mock_databricks_client_chat_streaming_response() -> MagicMock:
mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()

def mock_read_from_stream(size=-1):
if mock_stream_chunks:
return mock_stream_chunks.pop(0)
return b""

mock_response = MagicMock()
streaming_response_mock = MagicMock()
streaming_response_iterator_mock = MagicMock()
# Mock the __getitem__("content") method to return the streaming response
mock_response.__getitem__.return_value = streaming_response_mock
# Mock the streaming response __enter__ method to return the streaming response iterator
streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock

streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
streaming_response_iterator_mock.closed = False

return mock_response


def mock_embedding_response() -> Dict[str, Any]:
return {
"object": "list",
"model": "bge-large-en-v1.5",
"data": [
{
"index": 0,
"object": "embedding",
"embedding": [
0.06768798828125,
-0.01291656494140625,
-0.0501708984375,
0.0245361328125,
-0.030364990234375,
],
}
],
"usage": {
"prompt_tokens": 8,
"total_tokens": 8,
"completion_tokens": 0,
"completion_tokens_details": None,
},
}


@pytest.mark.parametrize("set_base", [True, False])
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
if set_base:
monkeypatch.setenv(
"DATABRICKS_API_BASE",
"https://my.workspace.cloud.databricks.com/serving-endpoints",
)
monkeypatch.delenv(
"DATABRICKS_API_KEY",
)
err_msg = "A call is being made to LLM Provider but no key is set"
else:
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
monkeypatch.delenv("DATABRICKS_API_BASE")
err_msg = "A call is being made to LLM Provider but no api base is set"

with pytest.raises(BadRequestError) as exc:
litellm.completion(
model="databricks/dbrx-instruct-071224",
messages={"role": "user", "content": "How are you?"},
)
assert err_msg in str(exc)

with pytest.raises(BadRequestError) as exc:
litellm.embedding(
model="databricks/bge-12312",
input=["Hello", "World"],
)
assert err_msg in str(exc)


def test_completions_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()

expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}

messages = [{"role": "user", "content": "How are you?"}]

with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
assert response.to_dict() == expected_response_json

mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)


def test_completions_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

async_handler = AsyncHTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_chat_response()

expected_response_json = {
**mock_chat_response(),
**{
"model": "databricks/dbrx-instruct-071224",
},
}

messages = [{"role": "user", "content": "How are you?"}]

with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response = asyncio.run(
litellm.acompletion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=async_handler,
temperature=0.5,
extraparam="testpassingextraparam",
)
)
assert response.to_dict() == expected_response_json

mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"extraparam": "testpassingextraparam",
"stream": False,
}
),
)


def test_completions_streaming_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

sync_handler = HTTPHandler()

messages = [{"role": "user", "content": "How are you?"}]
mock_response = mock_http_handler_chat_streaming_response()

with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response_stream: CustomStreamWrapper = litellm.completion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=sync_handler,
temperature=0.5,
extraparam="testpassingextraparam",
stream=True,
)
response = list(response_stream)
assert "dbrx-instruct-071224" in str(response)
assert "chatcmpl" in str(response)
assert len(response) == 4

mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
stream=True,
)


def test_completions_streaming_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

async_handler = AsyncHTTPHandler()

messages = [{"role": "user", "content": "How are you?"}]
mock_response = mock_http_handler_chat_async_streaming_response()

with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response_stream: CustomStreamWrapper = asyncio.run(
litellm.acompletion(
model="databricks/dbrx-instruct-071224",
messages=messages,
client=async_handler,
temperature=0.5,
extraparam="testpassingextraparam",
stream=True,
)
)

# Use async list gathering for the response
async def gather_responses():
return [item async for item in response_stream]

response = asyncio.run(gather_responses())
assert "dbrx-instruct-071224" in str(response)
assert "chatcmpl" in str(response)
assert len(response) == 4

mock_post.assert_called_once_with(
f"{base_url}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "dbrx-instruct-071224",
"messages": messages,
"temperature": 0.5,
"stream": True,
"extraparam": "testpassingextraparam",
}
),
stream=True,
)


def test_embeddings_with_sync_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

sync_handler = HTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()

inputs = ["Hello", "World"]

with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
response = litellm.embedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=sync_handler,
extraparam="testpassingextraparam",
)
assert response.to_dict() == mock_embedding_response()

mock_post.assert_called_once_with(
f"{base_url}/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)


def test_embeddings_with_async_http_handler(monkeypatch):
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
api_key = "dapimykey"
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

async_handler = AsyncHTTPHandler()
mock_response = Mock(spec=httpx.Response)
mock_response.status_code = 200
mock_response.json.return_value = mock_embedding_response()

inputs = ["Hello", "World"]

with patch.object(
AsyncHTTPHandler, "post", return_value=mock_response
) as mock_post:
response = asyncio.run(
litellm.aembedding(
model="databricks/bge-large-en-v1.5",
input=inputs,
client=async_handler,
extraparam="testpassingextraparam",
)
)
assert response.to_dict() == mock_embedding_response()

mock_post.assert_called_once_with(
f"{base_url}/embeddings",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
data=json.dumps(
{
"model": "bge-large-en-v1.5",
"input": inputs,
"extraparam": "testpassingextraparam",
}
),
)