forked from phoenix/litellm-mirror
* LiteLLM Minor Fixes & Improvements (09/16/2024) (#5723)

  * coverage (#5713)
    Signed-off-by: dbczumar <corey.zumar@databricks.com>
  * Move (#5714)
    Signed-off-by: dbczumar <corey.zumar@databricks.com>
  * fix(litellm_logging.py): fix logging client re-init (#5710)
    Fixes https://github.com/BerriAI/litellm/issues/5695
  * fix(presidio.py): Fix logging_hook response and add support for additional presidio variables in guardrails config
    Fixes https://github.com/BerriAI/litellm/issues/5682
  * feat(o1_handler.py): fake streaming for openai o1 models
    Fixes https://github.com/BerriAI/litellm/issues/5694
  * docs: deprecated traceloop integration in favor of native otel (#5249)
  * fix: fix linting errors
  * fix: fix linting errors
  * fix(main.py): fix o1 import

  ---------

  Signed-off-by: dbczumar <corey.zumar@databricks.com>
  Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
  Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>

* feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view (#5730)

  * feat(spend_management_endpoints.py): expose `/global/spend/refresh` endpoint for updating material view
    Supports having `MonthlyGlobalSpend` view be a material view, and exposes an endpoint to refresh it
  * fix(custom_logger.py): reset calltype
  * fix: fix linting errors
  * fix: fix linting error
  * fix: fix import
  * test(test_databricks.py): fix databricks tests

  ---------

  Signed-off-by: dbczumar <corey.zumar@databricks.com>
  Co-authored-by: Corey Zumar <39497902+dbczumar@users.noreply.github.com>
  Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
This commit is contained in:
parent 1e59395280
commit 234185ec13
34 changed files with 1387 additions and 502 deletions
@@ -6,10 +6,14 @@ import asyncio
 import os

 # Enter your DATABASE_URL here
 os.environ["DATABASE_URL"] = "postgresql://xxxxxxx"

 from prisma import Prisma

-db = Prisma()
+db = Prisma(
+    http={
+        "timeout": 60000,
+    },
+)


 async def check_view_exists():
@@ -47,25 +51,22 @@ async def check_view_exists():

         print("LiteLLM_VerificationTokenView Created!")  # noqa

-    try:
-        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
-        print("MonthlyGlobalSpend Exists!")  # noqa
-    except Exception as e:
-        sql_query = """
-        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
-        SELECT
-            DATE("startTime") AS date,
-            SUM("spend") AS spend
-        FROM
-            "LiteLLM_SpendLogs"
-        WHERE
-            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
-        GROUP BY
-            DATE("startTime");
-        """
-        await db.execute_raw(query=sql_query)
+    sql_query = """
+    CREATE MATERIALIZED VIEW IF NOT EXISTS "MonthlyGlobalSpend" AS
+    SELECT
+        DATE_TRUNC('day', "startTime") AS date,
+        SUM("spend") AS spend
+    FROM
+        "LiteLLM_SpendLogs"
+    WHERE
+        "startTime" >= CURRENT_DATE - INTERVAL '30 days'
+    GROUP BY
+        DATE_TRUNC('day', "startTime");
+    """
+    # Execute the queries
+    await db.execute_raw(query=sql_query)

     print("MonthlyGlobalSpend Created!")  # noqa

     try:
         await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
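Because `MonthlyGlobalSpend` is now a materialized view rather than a plain view, it stores a snapshot and must be refreshed explicitly; the commit message mentions a new `/global/spend/refresh` endpoint for exactly this. A minimal sketch of the underlying refresh, reusing the prisma-client-py pattern from the script above (the connection handling here is illustrative, not the endpoint's actual implementation):

```python
# Sketch only: refresh the materialized view so spend numbers stay current.
# Uses the same prisma-client-py API (`execute_raw`) as the script above.
import asyncio

from prisma import Prisma


async def refresh_monthly_global_spend() -> None:
    db = Prisma()
    await db.connect()
    try:
        # Re-computes the stored result set of the "MonthlyGlobalSpend" view.
        await db.execute_raw('REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";')
    finally:
        await db.disconnect()


if __name__ == "__main__":
    asyncio.run(refresh_monthly_global_spend())
```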
@@ -0,0 +1,78 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# OpenTelemetry - Tracing LLMs with any observability tool

OpenTelemetry is a CNCF standard for observability. It connects to any observability tool, such as Jaeger, Zipkin, Datadog, New Relic, Traceloop, and others.

<Image img={require('../../img/traceloop_dash.png')} />

## Getting Started

Install the OpenTelemetry SDK:

```
pip install opentelemetry-api opentelemetry-sdk opentelemetry-exporter-otlp
```

Set the environment variables (different providers may require different variables):

<Tabs>

<TabItem value="traceloop" label="Log to Traceloop Cloud">

```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="https://api.traceloop.com"
OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
```

</TabItem>

<TabItem value="otel-col" label="Log to OTEL HTTP Collector">

```shell
OTEL_EXPORTER="otlp_http"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```

</TabItem>

<TabItem value="otel-col-grpc" label="Log to OTEL GRPC Collector">

```shell
OTEL_EXPORTER="otlp_grpc"
OTEL_ENDPOINT="http://0.0.0.0:4317"
```

</TabItem>

</Tabs>

Add a single line of code to instantly log your LLM responses **across all providers** with OpenTelemetry:

```python
litellm.callbacks = ["otel"]
```

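For context, the one-liner above is all that is needed on the SDK side; a minimal end-to-end sketch (the model name and prompt are placeholders, and the `OTEL_*` variables from the tabs above are assumed to already be exported):

```python
# Minimal sketch: enable the native OTEL callback, then call any provider.
# Assumes OTEL_EXPORTER / OTEL_ENDPOINT (and OTEL_HEADERS, if needed) are set.
import litellm

litellm.callbacks = ["otel"]

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response.choices[0].message.content)
```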
## Redacting Messages, Response Content from OpenTelemetry Logging

### Redact Messages and Responses from all OpenTelemetry Logging

Set `litellm.turn_off_message_logging=True`. This prevents messages and responses from being logged to OpenTelemetry; request metadata is still logged.

### Redact Messages and Responses from specific OpenTelemetry Logging

In the metadata typically passed for text completion or embedding calls, you can set specific keys to mask the messages and responses for this call.

Setting `mask_input` to `True` will mask the input from being logged for this call.

Setting `mask_output` to `True` will mask the output from being logged for this call.

Be aware that if you are continuing an existing trace, and you set `update_trace_keys` to include either `input` or `output` and you set the corresponding `mask_input` or `mask_output`, then that trace will have its existing input and/or output replaced with a redacted message.

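A sketch of what per-call masking could look like based on the keys described above; the exact placement of `mask_input`/`mask_output` inside `metadata` is an assumption here, so verify it against the logging reference:

```python
# Sketch only: ask that this specific call's input and output be masked.
# The key names come from the section above; the metadata nesting is assumed.
import litellm

litellm.callbacks = ["otel"]

response = litellm.completion(
    model="gpt-3.5-turbo",  # placeholder model name
    messages=[{"role": "user", "content": "some sensitive prompt"}],
    metadata={
        "mask_input": True,   # redact the request messages for this call
        "mask_output": True,  # redact the response content for this call
    },
)
```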
## Support

For any question or issue with the integration, you can reach out to the OpenLLMetry maintainers on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
@@ -1,36 +0,0 @@
-import Image from '@theme/IdealImage';
-
-# Traceloop (OpenLLMetry) - Tracing LLMs with OpenTelemetry
-
-[Traceloop](https://traceloop.com) is a platform for monitoring and debugging the quality of your LLM outputs.
-It provides you with a way to track the performance of your LLM application; rollout changes with confidence; and debug issues in production.
-It is based on [OpenTelemetry](https://opentelemetry.io), so it can provide full visibility to your LLM requests, as well vector DB usage, and other infra in your stack.
-
-<Image img={require('../../img/traceloop_dash.png')} />
-
-## Getting Started
-
-Install the Traceloop SDK:
-
-```
-pip install traceloop-sdk
-```
-
-Use just 2 lines of code, to instantly log your LLM responses with OpenTelemetry:
-
-```python
-Traceloop.init(app_name=<YOUR APP NAME>, disable_batch=True)
-litellm.success_callback = ["traceloop"]
-```
-
-Make sure to properly set a destination to your traces. See [OpenLLMetry docs](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for options.
-
-To get better visualizations on how your code behaves, you may want to annotate specific parts of your LLM chain. See [Traceloop docs on decorators](https://traceloop.com/docs/python-sdk/decorators) for more information.
-
-## Exporting traces to other systems (e.g. Datadog, New Relic, and others)
-
-Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.
-
-## Support
-
-For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
@@ -600,6 +600,52 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \

 </TabItem>

+<TabItem value="traceloop" label="Log to Traceloop Cloud">
+
+#### Quick Start - Log to Traceloop
+
+**Step 1:**
+Add the following to your env
+
+```shell
+OTEL_EXPORTER="otlp_http"
+OTEL_ENDPOINT="https://api.traceloop.com"
+OTEL_HEADERS="Authorization=Bearer%20<your-api-key>"
+```
+
+**Step 2:** Add `otel` as a callback
+
+```shell
+litellm_settings:
+  callbacks: ["otel"]
+```
+
+**Step 3**: Start the proxy, make a test request
+
+Start proxy
+
+```shell
+litellm --config config.yaml --detailed_debug
+```
+
+Test Request
+
+```shell
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Content-Type: application/json' \
+    --data ' {
+    "model": "gpt-3.5-turbo",
+    "messages": [
+        {
+        "role": "user",
+        "content": "what llm are you"
+        }
+    ]
+    }'
+```
+
+</TabItem>
+
 <TabItem value="otel-col" label="Log to OTEL HTTP Collector">

 #### Quick Start - Log to OTEL Collector
@@ -694,52 +740,6 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \

 </TabItem>

-<TabItem value="traceloop" label="Log to Traceloop Cloud">
-
-#### Quick Start - Log to Traceloop
-
-**Step 1:** Install the `traceloop-sdk` SDK
-
-```shell
-pip install traceloop-sdk==0.21.2
-```
-
-**Step 2:** Add `traceloop` as a success_callback
-
-```shell
-litellm_settings:
-  success_callback: ["traceloop"]
-
-environment_variables:
-  TRACELOOP_API_KEY: "XXXXX"
-```
-
-**Step 3**: Start the proxy, make a test request
-
-Start proxy
-
-```shell
-litellm --config config.yaml --detailed_debug
-```
-
-Test Request
-
-```shell
-curl --location 'http://0.0.0.0:4000/chat/completions' \
-    --header 'Content-Type: application/json' \
-    --data ' {
-    "model": "gpt-3.5-turbo",
-    "messages": [
-        {
-        "role": "user",
-        "content": "what llm are you"
-        }
-    ]
-    }'
-```
-
-</TabItem>
-
 </Tabs>

 ** 🎉 Expect to see this trace logged in your OTEL collector**
@@ -255,6 +255,7 @@ const sidebars = {
       type: "category",
       label: "Logging & Observability",
       items: [
+        "observability/opentelemetry_integration",
         "observability/langfuse_integration",
         "observability/logfire_integration",
         "observability/gcs_bucket_integration",
@@ -271,7 +272,7 @@ const sidebars = {
         "observability/openmeter",
         "observability/promptlayer_integration",
         "observability/wandb_integration",
-        "observability/traceloop_integration",
+        "observability/slack_integration",
         "observability/athina_integration",
         "observability/lunary_integration",
         "observability/greenscale_integration",
@@ -948,10 +948,10 @@ from .llms.OpenAI.openai import (
     AzureAIStudioConfig,
 )
 from .llms.mistral.mistral_chat_transformation import MistralConfig
-from .llms.OpenAI.o1_transformation import (
+from .llms.OpenAI.chat.o1_transformation import (
     OpenAIO1Config,
 )
-from .llms.OpenAI.gpt_transformation import (
+from .llms.OpenAI.chat.gpt_transformation import (
     OpenAIGPTConfig,
 )
 from .llms.nvidia_nim import NvidiaNimConfig
@@ -1616,6 +1616,9 @@ Model Info:

         :param time_range: A string specifying the time range, e.g., "1d", "7d", "30d"
         """
+        if self.alerting is None or "spend_reports" not in self.alert_types:
+            return
+
         try:
             from litellm.proxy.spend_tracking.spend_management_endpoints import (
                 _get_spend_report_for_time_range,
@@ -11,7 +11,7 @@ class CustomGuardrail(CustomLogger):
         self,
         guardrail_name: Optional[str] = None,
         event_hook: Optional[GuardrailEventHooks] = None,
-        **kwargs
+        **kwargs,
     ):
         self.guardrail_name = guardrail_name
         self.event_hook: Optional[GuardrailEventHooks] = event_hook
@@ -28,10 +28,13 @@ class CustomGuardrail(CustomLogger):
             requested_guardrails,
         )

-        if self.guardrail_name not in requested_guardrails:
+        if (
+            self.guardrail_name not in requested_guardrails
+            and event_type.value != "logging_only"
+        ):
             return False

-        if self.event_hook != event_type:
+        if self.event_hook != event_type.value:
             return False

         return True
@@ -4,6 +4,11 @@ import litellm


 class TraceloopLogger:
+    """
+    WARNING: DEPRECATED
+    Use the OpenTelemetry standard integration instead
+    """
+
     def __init__(self):
         try:
             from traceloop.sdk.tracing.tracing import TracerWrapper
@@ -90,6 +90,13 @@ from ..integrations.supabase import Supabase
 from ..integrations.traceloop import TraceloopLogger
 from ..integrations.weights_biases import WeightsBiasesLogger

+try:
+    from ..proxy.enterprise.enterprise_callbacks.generic_api_callback import (
+        GenericAPILogger,
+    )
+except Exception as e:
+    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
+
 _in_memory_loggers: List[Any] = []

 ### GLOBAL VARIABLES ###
@@ -145,7 +152,41 @@ class ServiceTraceIDCache:
         return None


+import hashlib
+
+
+class DynamicLoggingCache:
+    """
+    Prevent memory leaks caused by initializing new logging clients on each request.
+
+    Relevant Issue: https://github.com/BerriAI/litellm/issues/5695
+    """
+
+    def __init__(self) -> None:
+        self.cache = InMemoryCache()
+
+    def get_cache_key(self, args: dict) -> str:
+        args_str = json.dumps(args, sort_keys=True)
+        cache_key = hashlib.sha256(args_str.encode("utf-8")).hexdigest()
+        return cache_key
+
+    def get_cache(self, credentials: dict, service_name: str) -> Optional[Any]:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        response = self.cache.get_cache(key=key_name)
+        return response
+
+    def set_cache(self, credentials: dict, service_name: str, logging_obj: Any) -> None:
+        key_name = self.get_cache_key(
+            args={**credentials, "service_name": service_name}
+        )
+        self.cache.set_cache(key=key_name, value=logging_obj)
+        return None
+
+
 in_memory_trace_id_cache = ServiceTraceIDCache()
+in_memory_dynamic_logger_cache = DynamicLoggingCache()


 class Logging:
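To make the intent of `DynamicLoggingCache` concrete, here is a small sketch of the lookup-then-build pattern it enables; it assumes the module-level `in_memory_dynamic_logger_cache` and the `LangFuseLogger` import from this file, and the credential values are placeholders (the real wiring is the langfuse hunk further down):

```python
# Sketch: reuse a logging client for identical credentials instead of
# constructing a new one per request (see issue #5695 referenced above).
credentials = {
    "langfuse_public_key": "pk-lf-...",  # placeholder
    "langfuse_secret": "sk-lf-...",      # placeholder
    "langfuse_host": "https://cloud.langfuse.com",
}

logger = in_memory_dynamic_logger_cache.get_cache(
    credentials=credentials, service_name="langfuse"
)
if logger is None:
    logger = LangFuseLogger(**credentials)
    in_memory_dynamic_logger_cache.set_cache(
        credentials=credentials, service_name="langfuse", logging_obj=logger
    )
```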
@@ -324,10 +365,10 @@ class Logging:
             print_verbose(f"\033[92m{curl_command}\033[0m\n", log_level="DEBUG")
         # log raw request to provider (like LangFuse) -- if opted in.
         if log_raw_request_response is True:
+            _litellm_params = self.model_call_details.get("litellm_params", {})
+            _metadata = _litellm_params.get("metadata", {}) or {}
             try:
                 # [Non-blocking Extra Debug Information in metadata]
-                _litellm_params = self.model_call_details.get("litellm_params", {})
-                _metadata = _litellm_params.get("metadata", {}) or {}
                 if (
                     turn_off_message_logging is not None
                     and turn_off_message_logging is True
@@ -362,7 +403,7 @@ class Logging:
         callbacks = litellm.input_callback + self.dynamic_input_callbacks
         for callback in callbacks:
             try:
-                if callback == "supabase":
+                if callback == "supabase" and supabaseClient is not None:
                     verbose_logger.debug("reaches supabase for logging!")
                     model = self.model_call_details["model"]
                     messages = self.model_call_details["input"]
@@ -396,7 +437,9 @@ class Logging:
                         messages=self.messages,
                         kwargs=self.model_call_details,
                     )
-                elif callable(callback):  # custom logger functions
+                elif (
+                    callable(callback) and customLogger is not None
+                ):  # custom logger functions
                     customLogger.log_input_event(
                         model=self.model,
                         messages=self.messages,
@@ -615,7 +658,7 @@ class Logging:

                 self.model_call_details["litellm_params"]["metadata"][
                     "hidden_params"
-                ] = result._hidden_params
+                ] = getattr(result, "_hidden_params", {})
                 ## STANDARDIZED LOGGING PAYLOAD

                 self.model_call_details["standard_logging_object"] = (
@@ -645,6 +688,7 @@ class Logging:
                 litellm.max_budget
                 and self.stream is False
                 and result is not None
+                and isinstance(result, dict)
                 and "content" in result
             ):
                 time_diff = (end_time - start_time).total_seconds()
@@ -652,7 +696,7 @@ class Logging:
                 litellm._current_cost += litellm.completion_cost(
                     model=self.model,
                     prompt="",
-                    completion=result["content"],
+                    completion=getattr(result, "content", ""),
                     total_time=float_diff,
                 )
@@ -758,7 +802,7 @@ class Logging:
             ):
                 print_verbose("no-log request, skipping logging")
                 continue
-            if callback == "lite_debugger":
+            if callback == "lite_debugger" and liteDebuggerClient is not None:
                 print_verbose("reaches lite_debugger for logging!")
                 print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
                 print_verbose(
@@ -774,7 +818,7 @@ class Logging:
                     call_type=self.call_type,
                     stream=self.stream,
                 )
-            if callback == "promptlayer":
+            if callback == "promptlayer" and promptLayerLogger is not None:
                 print_verbose("reaches promptlayer for logging!")
                 promptLayerLogger.log_event(
                     kwargs=self.model_call_details,
@@ -783,7 +827,7 @@ class Logging:
                     end_time=end_time,
                     print_verbose=print_verbose,
                 )
-            if callback == "supabase":
+            if callback == "supabase" and supabaseClient is not None:
                 print_verbose("reaches supabase for logging!")
                 kwargs = self.model_call_details

@@ -811,7 +855,7 @@ class Logging:
                     ),
                     print_verbose=print_verbose,
                 )
-            if callback == "wandb":
+            if callback == "wandb" and weightsBiasesLogger is not None:
                 print_verbose("reaches wandb for logging!")
                 weightsBiasesLogger.log_event(
                     kwargs=self.model_call_details,
@@ -820,8 +864,7 @@ class Logging:
                     end_time=end_time,
                     print_verbose=print_verbose,
                 )
-            if callback == "logfire":
-                global logfireLogger
+            if callback == "logfire" and logfireLogger is not None:
                 verbose_logger.debug("reaches logfire for success logging!")
                 kwargs = {}
                 for k, v in self.model_call_details.items():
@@ -844,10 +887,10 @@ class Logging:
                     start_time=start_time,
                     end_time=end_time,
                     print_verbose=print_verbose,
-                    level=LogfireLevel.INFO.value,
+                    level=LogfireLevel.INFO.value,  # type: ignore
                 )

-            if callback == "lunary":
+            if callback == "lunary" and lunaryLogger is not None:
                 print_verbose("reaches lunary for logging!")
                 model = self.model
                 kwargs = self.model_call_details
@@ -882,7 +925,7 @@ class Logging:
                     run_id=self.litellm_call_id,
                     print_verbose=print_verbose,
                 )
-            if callback == "helicone":
+            if callback == "helicone" and heliconeLogger is not None:
                 print_verbose("reaches helicone for logging!")
                 model = self.model
                 messages = self.model_call_details["input"]
@@ -924,6 +967,7 @@ class Logging:
                 else:
                     print_verbose("reaches langfuse for streaming logging!")
                     result = kwargs["complete_streaming_response"]

+                temp_langfuse_logger = langFuseLogger
                 if langFuseLogger is None or (
                     (
@@ -941,27 +985,45 @@ class Logging:
                         and self.langfuse_host != langFuseLogger.langfuse_host
                     )
                 ):
-                    temp_langfuse_logger = LangFuseLogger(
-                        langfuse_public_key=self.langfuse_public_key,
-                        langfuse_secret=self.langfuse_secret,
-                        langfuse_host=self.langfuse_host,
-                    )
-                _response = temp_langfuse_logger.log_event(
-                    kwargs=kwargs,
-                    response_obj=result,
-                    start_time=start_time,
-                    end_time=end_time,
-                    user_id=kwargs.get("user", None),
-                    print_verbose=print_verbose,
-                )
-                if _response is not None and isinstance(_response, dict):
-                    _trace_id = _response.get("trace_id", None)
-                    if _trace_id is not None:
-                        in_memory_trace_id_cache.set_cache(
-                            litellm_call_id=self.litellm_call_id,
-                            service_name="langfuse",
-                            trace_id=_trace_id,
-                        )
+                    credentials = {
+                        "langfuse_public_key": self.langfuse_public_key,
+                        "langfuse_secret": self.langfuse_secret,
+                        "langfuse_host": self.langfuse_host,
+                    }
+                    temp_langfuse_logger = (
+                        in_memory_dynamic_logger_cache.get_cache(
+                            credentials=credentials, service_name="langfuse"
+                        )
+                    )
+                    if temp_langfuse_logger is None:
+                        temp_langfuse_logger = LangFuseLogger(
+                            langfuse_public_key=self.langfuse_public_key,
+                            langfuse_secret=self.langfuse_secret,
+                            langfuse_host=self.langfuse_host,
+                        )
+                        in_memory_dynamic_logger_cache.set_cache(
+                            credentials=credentials,
+                            service_name="langfuse",
+                            logging_obj=temp_langfuse_logger,
+                        )
+
+                if temp_langfuse_logger is not None:
+                    _response = temp_langfuse_logger.log_event(
+                        kwargs=kwargs,
+                        response_obj=result,
+                        start_time=start_time,
+                        end_time=end_time,
+                        user_id=kwargs.get("user", None),
+                        print_verbose=print_verbose,
+                    )
+                    if _response is not None and isinstance(_response, dict):
+                        _trace_id = _response.get("trace_id", None)
+                        if _trace_id is not None:
+                            in_memory_trace_id_cache.set_cache(
+                                litellm_call_id=self.litellm_call_id,
+                                service_name="langfuse",
+                                trace_id=_trace_id,
+                            )
             if callback == "generic":
                 global genericAPILogger
                 verbose_logger.debug("reaches langfuse for success logging!")
@@ -982,7 +1044,7 @@ class Logging:
                     print_verbose("reaches langfuse for streaming logging!")
                     result = kwargs["complete_streaming_response"]
                 if genericAPILogger is None:
-                    genericAPILogger = GenericAPILogger()
+                    genericAPILogger = GenericAPILogger()  # type: ignore
                 genericAPILogger.log_event(
                     kwargs=kwargs,
                     response_obj=result,
@@ -1022,7 +1084,7 @@ class Logging:
                     user_id=kwargs.get("user", None),
                     print_verbose=print_verbose,
                 )
-            if callback == "greenscale":
+            if callback == "greenscale" and greenscaleLogger is not None:
                 kwargs = {}
                 for k, v in self.model_call_details.items():
                     if (
@@ -1066,7 +1128,7 @@ class Logging:
                         result = kwargs["complete_streaming_response"]
                         # only add to cache once we have a complete streaming response
                         litellm.cache.add_cache(result, **kwargs)
-            if callback == "athina":
+            if callback == "athina" and athinaLogger is not None:
                 deep_copy = {}
                 for k, v in self.model_call_details.items():
                     deep_copy[k] = v
@@ -1224,6 +1286,7 @@ class Logging:
                         "atranscription", False
                     )
                     is not True
+                    and customLogger is not None
                 ):  # custom logger functions
                     print_verbose(
                         f"success callbacks: Running Custom Callback Function"
@@ -1423,9 +1486,8 @@ class Logging:
                             await litellm.cache.async_add_cache(result, **kwargs)
                         else:
                             litellm.cache.add_cache(result, **kwargs)
-                if callback == "openmeter":
-                    global openMeterLogger
-                    if self.stream == True:
+                if callback == "openmeter" and openMeterLogger is not None:
+                    if self.stream is True:
                         if (
                             "async_complete_streaming_response"
                             in self.model_call_details
@@ -1645,33 +1707,9 @@ class Logging:
             )
             for callback in callbacks:
                 try:
-                    if callback == "lite_debugger":
-                        print_verbose("reaches lite_debugger for logging!")
-                        print_verbose(f"liteDebuggerClient: {liteDebuggerClient}")
-                        result = {
-                            "model": self.model,
-                            "created": time.time(),
-                            "error": traceback_exception,
-                            "usage": {
-                                "prompt_tokens": prompt_token_calculator(
-                                    self.model, messages=self.messages
-                                ),
-                                "completion_tokens": 0,
-                            },
-                        }
-                        liteDebuggerClient.log_event(
-                            model=self.model,
-                            messages=self.messages,
-                            end_user=self.model_call_details.get("user", "default"),
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            litellm_call_id=self.litellm_call_id,
-                            print_verbose=print_verbose,
-                            call_type=self.call_type,
-                            stream=self.stream,
-                        )
-                    if callback == "lunary":
+                    if callback == "lite_debugger" and liteDebuggerClient is not None:
+                        pass
+                    elif callback == "lunary" and lunaryLogger is not None:
                         print_verbose("reaches lunary for logging error!")

                         model = self.model
@@ -1685,6 +1723,7 @@ class Logging:
                         )

                         lunaryLogger.log_event(
+                            kwargs=self.model_call_details,
                             type=_type,
                             event="error",
                             user_id=self.model_call_details.get("user", "default"),
@@ -1704,22 +1743,11 @@ class Logging:
                         print_verbose(
                             f"capture exception not initialized: {capture_exception}"
                         )
-                    elif callback == "supabase":
+                    elif callback == "supabase" and supabaseClient is not None:
                         print_verbose("reaches supabase for logging!")
                         print_verbose(f"supabaseClient: {supabaseClient}")
-                        result = {
-                            "model": model,
-                            "created": time.time(),
-                            "error": traceback_exception,
-                            "usage": {
-                                "prompt_tokens": prompt_token_calculator(
-                                    model, messages=self.messages
-                                ),
-                                "completion_tokens": 0,
-                            },
-                        }
                         supabaseClient.log_event(
-                            model=self.model,
+                            model=self.model if hasattr(self, "model") else "",
                             messages=self.messages,
                             end_user=self.model_call_details.get("user", "default"),
                             response_obj=result,
@@ -1728,7 +1756,9 @@ class Logging:
                             litellm_call_id=self.model_call_details["litellm_call_id"],
                             print_verbose=print_verbose,
                         )
-                    if callable(callback):  # custom logger functions
+                    if (
+                        callable(callback) and customLogger is not None
+                    ):  # custom logger functions
                         customLogger.log_event(
                             kwargs=self.model_call_details,
                             response_obj=result,
@@ -1809,13 +1839,13 @@ class Logging:
                         start_time=start_time,
                         end_time=end_time,
                         response_obj=None,
-                        user_id=kwargs.get("user", None),
+                        user_id=self.model_call_details.get("user", None),
                         print_verbose=print_verbose,
                         status_message=str(exception),
                         level="ERROR",
                         kwargs=self.model_call_details,
                     )
-                if callback == "logfire":
+                if callback == "logfire" and logfireLogger is not None:
                     verbose_logger.debug("reaches logfire for failure logging!")
                     kwargs = {}
                     for k, v in self.model_call_details.items():
@@ -1830,7 +1860,7 @@ class Logging:
                         response_obj=result,
                         start_time=start_time,
                         end_time=end_time,
-                        level=LogfireLevel.ERROR.value,
+                        level=LogfireLevel.ERROR.value,  # type: ignore
                         print_verbose=print_verbose,
                     )
@@ -1873,7 +1903,9 @@ class Logging:
                         start_time=start_time,
                         end_time=end_time,
                     )  # type: ignore
-                if callable(callback):  # custom logger functions
+                if (
+                    callable(callback) and customLogger is not None
+                ):  # custom logger functions
                     await customLogger.async_log_event(
                         kwargs=self.model_call_details,
                         response_obj=result,
@@ -1966,7 +1998,7 @@ def set_callbacks(callback_list, function_id=None):
             )
             sentry_sdk_instance.init(
                 dsn=os.environ.get("SENTRY_DSN"),
-                traces_sample_rate=float(sentry_trace_rate),
+                traces_sample_rate=float(sentry_trace_rate),  # type: ignore
             )
             capture_exception = sentry_sdk_instance.capture_exception
             add_breadcrumb = sentry_sdk_instance.add_breadcrumb
@@ -2411,12 +2443,11 @@ def get_standard_logging_object_payload(

     saved_cache_cost: Optional[float] = None
     if cache_hit is True:
-        import time

         id = f"{id}_cache_hit{time.time()}"  # do not duplicate the request id

         saved_cache_cost = logging_obj._response_cost_calculator(
-            result=init_response_obj, cache_hit=False
+            result=init_response_obj, cache_hit=False  # type: ignore
         )

     ## Get model cost information ##
@@ -2473,7 +2504,7 @@ def get_standard_logging_object_payload(
         model_id=_model_id,
         requester_ip_address=clean_metadata.get("requester_ip_address", None),
         messages=kwargs.get("messages"),
-        response=(
+        response=(  # type: ignore
             response_obj if len(response_obj.keys()) > 0 else init_response_obj
         ),
         model_parameters=kwargs.get("optional_params", None),
95  litellm/llms/OpenAI/chat/o1_handler.py  (new file)
@@ -0,0 +1,95 @@
"""
Handler file for calls to OpenAI's o1 family of models

Written separately to handle faking streaming for o1 models.
"""

import asyncio
from typing import Any, Callable, List, Optional, Union

from httpx._config import Timeout

from litellm.llms.bedrock.chat.invoke_handler import MockResponseIterator
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper


class OpenAIO1ChatCompletion(OpenAIChatCompletion):

    async def mock_async_streaming(
        self,
        response: Any,
        model: Optional[str],
        logging_obj: Any,
    ):
        model_response = await response
        completion_stream = MockResponseIterator(model_response=model_response)
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="openai",
            logging_obj=logging_obj,
        )
        return streaming_response

    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, Timeout],
        optional_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable[..., Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
    ):
        stream: Optional[bool] = optional_params.pop("stream", False)
        response = super().completion(
            model_response,
            timeout,
            optional_params,
            logging_obj,
            model,
            messages,
            print_verbose,
            api_key,
            api_base,
            acompletion,
            litellm_params,
            logger_fn,
            headers,
            custom_prompt_dict,
            client,
            organization,
            custom_llm_provider,
            drop_params,
        )

        if stream is True:
            if asyncio.iscoroutine(response):
                return self.mock_async_streaming(
                    response=response, model=model, logging_obj=logging_obj  # type: ignore
                )

            completion_stream = MockResponseIterator(model_response=response)
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="openai",
                logging_obj=logging_obj,
            )

            return streaming_response
        else:
            return response
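For context, the effect of the handler above is that a caller can still pass `stream=True` for o1 models even though the upstream request is made without streaming: the full response is buffered and replayed as a single-chunk stream. A minimal sketch of the caller's view (the model name is a placeholder for any o1-family model):

```python
# Sketch: o1 models don't stream upstream, so litellm fakes it by wrapping the
# buffered response in MockResponseIterator + CustomStreamWrapper (see above).
import litellm

response = litellm.completion(
    model="o1-preview",  # placeholder o1-family model name
    messages=[{"role": "user", "content": "what llm are you"}],
    stream=True,  # served by the fake-streaming path added in this commit
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")
```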
@@ -15,6 +15,8 @@ import requests  # type: ignore
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.databricks.exceptions import DatabricksError
+from litellm.llms.databricks.streaming_utils import ModelResponseIterator
 from litellm.types.llms.openai import (
     ChatCompletionDeltaChunk,
     ChatCompletionResponseMessage,
@@ -33,17 +35,6 @@ from ..base import BaseLLM
 from ..prompt_templates.factory import custom_prompt, prompt_factory


-class DatabricksError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
 class DatabricksConfig:
     """
     Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request
@@ -367,7 +358,7 @@ class DatabricksChatCompletion(BaseLLM):
                 status_code=e.response.status_code,
                 message=e.response.text,
             )
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise DatabricksError(status_code=408, message="Timeout error occurred.")
         except Exception as e:
             raise DatabricksError(status_code=500, message=str(e))
@@ -380,7 +371,7 @@ class DatabricksChatCompletion(BaseLLM):
         )
         response = ModelResponse(**response_json)

-        response.model = custom_llm_provider + "/" + response.model
+        response.model = custom_llm_provider + "/" + (response.model or "")

         if base_model is not None:
             response._hidden_params["model"] = base_model
@@ -529,7 +520,7 @@ class DatabricksChatCompletion(BaseLLM):
             response_json = response.json()
         except httpx.HTTPStatusError as e:
             raise DatabricksError(
-                status_code=e.response.status_code, message=response.text
+                status_code=e.response.status_code, message=e.response.text
             )
         except httpx.TimeoutException as e:
             raise DatabricksError(
@@ -540,7 +531,7 @@ class DatabricksChatCompletion(BaseLLM):

         response = ModelResponse(**response_json)

-        response.model = custom_llm_provider + "/" + response.model
+        response.model = custom_llm_provider + "/" + (response.model or "")

         if base_model is not None:
             response._hidden_params["model"] = base_model
@@ -657,7 +648,7 @@ class DatabricksChatCompletion(BaseLLM):
         except httpx.HTTPStatusError as e:
             raise DatabricksError(
                 status_code=e.response.status_code,
-                message=response.text if response else str(e),
+                message=e.response.text,
             )
         except httpx.TimeoutException as e:
             raise DatabricksError(status_code=408, message="Timeout error occurred.")
@@ -673,136 +664,3 @@ class DatabricksChatCompletion(BaseLLM):
         )

         return litellm.EmbeddingResponse(**response_json)
-
-
-class ModelResponseIterator:
-    def __init__(self, streaming_response, sync_stream: bool):
-        self.streaming_response = streaming_response
(… remaining deleted lines collapsed: the removed body of ModelResponseIterator is identical to the new litellm/llms/databricks/streaming_utils.py, shown in full below)
12  litellm/llms/databricks/exceptions.py  (new file)
@@ -0,0 +1,12 @@
import httpx


class DatabricksError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(method="POST", url="https://docs.databricks.com/")
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
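For illustration, the error-mapping pattern the chat handler hunks above apply can be sketched as a standalone helper; the function name and the plain `httpx.Client` usage here are illustrative, not part of the commit:

```python
# Sketch: translate httpx failures into DatabricksError the same way the
# handler hunks above do (status/text from the HTTP error, 408 on timeout).
import httpx

from litellm.llms.databricks.exceptions import DatabricksError


def post_or_raise_databricks_error(client: httpx.Client, url: str, payload: dict) -> dict:
    try:
        response = client.post(url, json=payload)
        response.raise_for_status()
        return response.json()
    except httpx.HTTPStatusError as e:
        raise DatabricksError(
            status_code=e.response.status_code, message=e.response.text
        )
    except httpx.TimeoutException:
        raise DatabricksError(status_code=408, message="Timeout error occurred.")
    except Exception as e:
        raise DatabricksError(status_code=500, message=str(e))
```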
145  litellm/llms/databricks/streaming_utils.py  (new file)
@ -0,0 +1,145 @@
|
||||||
|
import json
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.types.llms.openai import (
|
||||||
|
ChatCompletionDeltaChunk,
|
||||||
|
ChatCompletionResponseMessage,
|
||||||
|
ChatCompletionToolCallChunk,
|
||||||
|
ChatCompletionToolCallFunctionChunk,
|
||||||
|
ChatCompletionUsageBlock,
|
||||||
|
)
|
||||||
|
from litellm.types.utils import GenericStreamingChunk
|
||||||
|
|
||||||
|
|
||||||
|
class ModelResponseIterator:
|
||||||
|
def __init__(self, streaming_response, sync_stream: bool):
|
||||||
|
self.streaming_response = streaming_response
|
||||||
|
|
||||||
|
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||||
|
try:
|
||||||
|
processed_chunk = litellm.ModelResponse(**chunk, stream=True) # type: ignore
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||||
|
is_finished = False
|
||||||
|
finish_reason = ""
|
||||||
|
usage: Optional[ChatCompletionUsageBlock] = None
|
||||||
|
|
||||||
|
if processed_chunk.choices[0].delta.content is not None: # type: ignore
|
||||||
|
text = processed_chunk.choices[0].delta.content # type: ignore
|
||||||
|
|
||||||
|
if (
|
||||||
|
processed_chunk.choices[0].delta.tool_calls is not None # type: ignore
|
||||||
|
and len(processed_chunk.choices[0].delta.tool_calls) > 0 # type: ignore
|
||||||
|
and processed_chunk.choices[0].delta.tool_calls[0].function is not None # type: ignore
|
||||||
|
and processed_chunk.choices[0].delta.tool_calls[0].function.arguments # type: ignore
|
||||||
|
is not None
|
||||||
|
):
|
||||||
|
tool_use = ChatCompletionToolCallChunk(
|
||||||
|
id=processed_chunk.choices[0].delta.tool_calls[0].id, # type: ignore
|
||||||
|
type="function",
|
||||||
|
function=ChatCompletionToolCallFunctionChunk(
|
||||||
|
name=processed_chunk.choices[0]
|
||||||
|
.delta.tool_calls[0] # type: ignore
|
||||||
|
.function.name,
|
||||||
|
arguments=processed_chunk.choices[0]
|
||||||
|
.delta.tool_calls[0] # type: ignore
|
||||||
|
.function.arguments,
|
||||||
|
),
|
||||||
|
index=processed_chunk.choices[0].index,
|
||||||
|
)
|
||||||
|
|
||||||
|
if processed_chunk.choices[0].finish_reason is not None:
|
||||||
|
is_finished = True
|
||||||
|
finish_reason = processed_chunk.choices[0].finish_reason
|
||||||
|
|
||||||
|
if hasattr(processed_chunk, "usage") and isinstance(
|
||||||
|
processed_chunk.usage, litellm.Usage
|
||||||
|
):
|
||||||
|
usage_chunk: litellm.Usage = processed_chunk.usage
|
||||||
|
|
||||||
|
usage = ChatCompletionUsageBlock(
|
||||||
|
prompt_tokens=usage_chunk.prompt_tokens,
|
||||||
|
completion_tokens=usage_chunk.completion_tokens,
|
||||||
|
total_tokens=usage_chunk.total_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return GenericStreamingChunk(
|
||||||
|
text=text,
|
||||||
|
tool_use=tool_use,
|
||||||
|
is_finished=is_finished,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
index=0,
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||||
|
|
||||||
|
# Sync iterator
|
||||||
|
def __iter__(self):
|
||||||
|
self.response_iterator = self.streaming_response
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
try:
|
||||||
|
chunk = self.response_iterator.__next__()
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = chunk.replace("data:", "")
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if len(chunk) > 0:
|
||||||
|
json_chunk = json.loads(chunk)
|
||||||
|
return self.chunk_parser(chunk=json_chunk)
|
||||||
|
else:
|
||||||
|
return GenericStreamingChunk(
|
||||||
|
text="",
|
||||||
|
is_finished=False,
|
||||||
|
finish_reason="",
|
||||||
|
usage=None,
|
||||||
|
index=0,
|
||||||
|
tool_use=None,
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
raise StopIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||||
|
|
||||||
|
# Async iterator
|
||||||
|
def __aiter__(self):
|
||||||
|
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
try:
|
||||||
|
chunk = await self.async_response_iterator.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = chunk.replace("data:", "")
|
||||||
|
chunk = chunk.strip()
|
||||||
|
if chunk == "[DONE]":
|
||||||
|
raise StopAsyncIteration
|
||||||
|
if len(chunk) > 0:
|
||||||
|
json_chunk = json.loads(chunk)
|
||||||
|
return self.chunk_parser(chunk=json_chunk)
|
||||||
|
else:
|
||||||
|
return GenericStreamingChunk(
|
||||||
|
text="",
|
||||||
|
is_finished=False,
|
||||||
|
finish_reason="",
|
||||||
|
usage=None,
|
||||||
|
index=0,
|
||||||
|
tool_use=None,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
except ValueError as e:
|
||||||
|
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
|
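A minimal sketch (not part of the diff) of how the sync side of the new ModelResponseIterator could be exercised. The SSE-style payload below is made up for illustration, and it assumes litellm.ModelResponse accepts a chunk shaped like the Databricks responses above.

# hypothetical usage sketch for the new streaming iterator
from litellm.llms.databricks.streaming_utils import ModelResponseIterator

fake_lines = [
    'data: {"id": "chatcmpl_1", "object": "chat.completion.chunk", "created": 1,'
    ' "model": "dbrx-instruct-071224", "choices": [{"index": 0,'
    ' "delta": {"role": "assistant", "content": "Hi"}, "finish_reason": null}]}',
]

iterator = ModelResponseIterator(streaming_response=iter(fake_lines), sync_stream=True)
for parsed in iterator:
    # each item is a GenericStreamingChunk built by chunk_parser()
    print(parsed)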
@@ -95,6 +95,7 @@ from .llms.custom_llm import CustomLLM, custom_chat_llm_router
 from .llms.databricks.chat import DatabricksChatCompletion
 from .llms.huggingface_restapi import Huggingface
 from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
+from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
 from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
 from .llms.predibase import PredibaseChatCompletion
 from .llms.prompt_templates.factory import (
@@ -161,6 +162,7 @@ from litellm.utils import (
 ####### ENVIRONMENT VARIABLES ###################
 openai_chat_completions = OpenAIChatCompletion()
 openai_text_completions = OpenAITextCompletion()
+openai_o1_chat_completions = OpenAIO1ChatCompletion()
 openai_audio_transcriptions = OpenAIAudioTranscription()
 databricks_chat_completions = DatabricksChatCompletion()
 anthropic_chat_completions = AnthropicChatCompletion()
@@ -1366,25 +1368,46 @@ def completion(
         ## COMPLETION CALL
         try:
-            response = openai_chat_completions.completion(
-                model=model,
-                messages=messages,
-                headers=headers,
-                model_response=model_response,
-                print_verbose=print_verbose,
-                api_key=api_key,
-                api_base=api_base,
-                acompletion=acompletion,
-                logging_obj=logging,
-                optional_params=optional_params,
-                litellm_params=litellm_params,
-                logger_fn=logger_fn,
-                timeout=timeout,  # type: ignore
-                custom_prompt_dict=custom_prompt_dict,
-                client=client,  # pass AsyncOpenAI, OpenAI client
-                organization=organization,
-                custom_llm_provider=custom_llm_provider,
-            )
+            if litellm.OpenAIO1Config().is_model_o1_reasoning_model(model=model):
+                response = openai_o1_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
+            else:
+                response = openai_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    headers=headers,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    api_key=api_key,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    logging_obj=logging,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    timeout=timeout,  # type: ignore
+                    custom_prompt_dict=custom_prompt_dict,
+                    client=client,  # pass AsyncOpenAI, OpenAI client
+                    organization=organization,
+                    custom_llm_provider=custom_llm_provider,
+                )
         except Exception as e:
             ## LOGGING - log the original exception returned
             logging.post_call(
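The routing above keeps the public API unchanged: o1 requests still go through completion(), and the handler is picked internally. A rough usage sketch, with the API key value as a placeholder; per the commit notes, streaming for o1 models is simulated rather than native.

# hedged sketch: calling an o1 model through the same completion() entry point
import os
import litellm

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

response = litellm.completion(
    model="o1-preview",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,  # streaming is faked for o1 models in this PR
)
for chunk in response:
    print(chunk)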
@@ -3,7 +3,6 @@ model_list:
     litellm_params:
       model: anthropic.claude-3-sonnet-20240229-v1:0
       api_base: https://exampleopenaiendpoint-production.up.railway.app
-      # aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
       aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
       aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
   - model_name: gemini-vision
@@ -17,4 +16,34 @@ model_list:
     litellm_params:
       model: gpt-3.5-turbo
       api_base: https://exampleopenaiendpoint-production.up.railway.app
+  - model_name: o1-preview
+    litellm_params:
+      model: o1-preview
+
+litellm_settings:
+  drop_params: True
+  json_logs: True
+  store_audit_logs: True
+  log_raw_request_response: True
+  return_response_headers: True
+  num_retries: 5
+  request_timeout: 200
+  callbacks: ["custom_callbacks.proxy_handler_instance"]
+
+guardrails:
+  - guardrail_name: "presidio-pre-guard"
+    litellm_params:
+      guardrail: presidio  # supported values: "aporia", "bedrock", "lakera", "presidio"
+      mode: "logging_only"
+      mock_redacted_text: {
+          "text": "My name is <PERSON>, who are you? Say my name in your response",
+          "items": [
+              {
+                  "start": 11,
+                  "end": 19,
+                  "entity_type": "PERSON",
+                  "text": "<PERSON>",
+                  "operator": "replace",
+              }
+          ],
+      }
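For clarity, the mock_redacted_text "items" above describe character offsets into the original request text. A small illustrative sketch (not from the diff; "Jane Doe" is a made-up 8-character name matching the 11-19 offsets) of how such an item maps onto a string:

# illustrative only: applying one Presidio-style replacement item by hand
original = "My name is Jane Doe, who are you? Say my name in your response"
item = {"start": 11, "end": 19, "entity_type": "PERSON", "operator": "replace"}

masked = original[: item["start"]] + "<" + item["entity_type"] + ">" + original[item["end"] :]
print(masked)  # -> "My name is <PERSON>, who are you? Say my name in your response"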
@@ -446,7 +446,6 @@ async def user_api_key_auth(
             and request.headers.get(key=header_key) is not None  # type: ignore
         ):
             api_key = request.headers.get(key=header_key)  # type: ignore
-
         if master_key is None:
             if isinstance(api_key, str):
                 return UserAPIKeyAuth(
@@ -2,9 +2,19 @@ from litellm.integrations.custom_logger import CustomLogger

 class MyCustomHandler(CustomLogger):
-    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
-        # print("Call failed")
-        pass
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        # input_tokens = response_obj.get("usage", {}).get("prompt_tokens", 0)
+        # output_tokens = response_obj.get("usage", {}).get("completion_tokens", 0)
+        input_tokens = (
+            response_obj.usage.prompt_tokens
+            if hasattr(response_obj.usage, "prompt_tokens")
+            else 0
+        )
+        output_tokens = (
+            response_obj.usage.completion_tokens
+            if hasattr(response_obj.usage, "completion_tokens")
+            else 0
+        )

 proxy_handler_instance = MyCustomHandler()
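The proxy wires this handler in through the callbacks entry in the config above. A hedged sketch of using the same handler directly from the SDK, assuming it runs outside the proxy:

# sketch: registering the custom handler when using litellm directly
import litellm
from custom_callbacks import proxy_handler_instance  # module name taken from the config above

litellm.callbacks = [proxy_handler_instance]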
@@ -218,7 +218,7 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         response = await self.async_handler.post(
             url=prepared_request.url,
             json=request_data,  # type: ignore
-            headers=dict(prepared_request.headers),
+            headers=prepared_request.headers,  # type: ignore
         )
         verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
         if response.status_code == 200:
@@ -254,7 +254,6 @@ class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
         from litellm.proxy.common_utils.callback_utils import (
             add_guardrail_to_applied_guardrails_header,
         )
-        from litellm.types.guardrails import GuardrailEventHooks

         event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
         if self.should_run_guardrail(data=data, event_type=event_type) is not True:
@@ -189,6 +189,7 @@ class lakeraAI_Moderation(CustomGuardrail):
         # Alternatively, a user can opt to have these messages added to the system prompt instead (ignore these, since they are in system already)
         # Finally, if the user did not elect to add them to the system message themselves, and they are there, then add them to system so they can be checked.
         # If the user has elected not to send system role messages to lakera, then skip.
+
         if system_message is not None:
             if not litellm.add_function_to_prompt:
                 content = system_message.get("content")
@@ -19,6 +19,7 @@ from fastapi import HTTPException
 from pydantic import BaseModel

 import litellm  # noqa: E401
+from litellm import get_secret
 from litellm._logging import verbose_proxy_logger
 from litellm.caching import DualCache
 from litellm.integrations.custom_guardrail import CustomGuardrail
@@ -58,7 +59,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         self.pii_tokens: dict = (
             {}
         )  # mapping of PII token to original text - only used with Presidio `replace` operation
-
         self.mock_redacted_text = mock_redacted_text
         self.output_parse_pii = output_parse_pii or False
         if mock_testing is True:  # for testing purposes only
@@ -92,8 +92,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         presidio_anonymizer_api_base: Optional[str] = None,
     ):
         self.presidio_analyzer_api_base: Optional[str] = (
-            presidio_analyzer_api_base
-            or litellm.get_secret("PRESIDIO_ANALYZER_API_BASE", None)
+            presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None)  # type: ignore
         )
         self.presidio_anonymizer_api_base: Optional[
             str
@@ -198,12 +197,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
             else:
                 raise Exception(f"Invalid anonymizer response: {redacted_text}")
         except Exception as e:
-            verbose_proxy_logger.error(
-                "litellm.proxy.hooks.presidio_pii_masking.py::async_pre_call_hook(): Exception occured - {}".format(
-                    str(e)
-                )
-            )
-            verbose_proxy_logger.debug(traceback.format_exc())
             raise e

     async def async_pre_call_hook(
@@ -254,9 +247,6 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
             )
             return data
         except Exception as e:
-            verbose_proxy_logger.info(
-                f"An error occurred -",
-            )
             raise e

     async def async_logging_hook(
@@ -300,9 +290,9 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
             )
             kwargs["messages"] = messages

-        return kwargs, responses
+        return kwargs, result

-    async def async_post_call_success_hook(
+    async def async_post_call_success_hook(  # type: ignore
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
@@ -314,7 +304,7 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
         verbose_proxy_logger.debug(
             f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}"
         )
-        if self.output_parse_pii == False:
+        if self.output_parse_pii is False:
             return response

         if isinstance(response, ModelResponse) and not isinstance(
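The api_base change above is the usual "explicit value, else env-backed secret" fallback. A standalone sketch of the same pattern, with the local variable names being illustrative; the two-argument get_secret call mirrors the one used in the diff:

# sketch of the fallback pattern used by the Presidio guardrail constructor
from litellm import get_secret

passed_api_base = None  # e.g. nothing supplied by the caller
resolved_api_base = passed_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None)
print(resolved_api_base)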
@@ -1,10 +1,11 @@
 import importlib
 import traceback
-from typing import Dict, List, Literal
+from typing import Dict, List, Literal, Optional

 from pydantic import BaseModel, RootModel

 import litellm
+from litellm import get_secret
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
@@ -16,7 +17,6 @@ from litellm.types.guardrails import (
     GuardrailItemSpec,
     LakeraCategoryThresholds,
     LitellmParams,
-    guardrailConfig,
 )

 all_guardrails: List[GuardrailItem] = []
@@ -98,18 +98,13 @@ def init_guardrails_v2(
         # Init litellm params for guardrail
         litellm_params_data = guardrail["litellm_params"]
         verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data)
-        litellm_params = LitellmParams(
-            guardrail=litellm_params_data["guardrail"],
-            mode=litellm_params_data["mode"],
-            api_key=litellm_params_data.get("api_key"),
-            api_base=litellm_params_data.get("api_base"),
-            guardrailIdentifier=litellm_params_data.get("guardrailIdentifier"),
-            guardrailVersion=litellm_params_data.get("guardrailVersion"),
-            output_parse_pii=litellm_params_data.get("output_parse_pii"),
-            presidio_ad_hoc_recognizers=litellm_params_data.get(
-                "presidio_ad_hoc_recognizers"
-            ),
-        )
+        _litellm_params_kwargs = {
+            k: litellm_params_data[k] if k in litellm_params_data else None
+            for k in LitellmParams.__annotations__.keys()
+        }
+
+        litellm_params = LitellmParams(**_litellm_params_kwargs)  # type: ignore

         if (
             "category_thresholds" in litellm_params_data
@@ -122,15 +117,11 @@ def init_guardrails_v2(

         if litellm_params["api_key"]:
             if litellm_params["api_key"].startswith("os.environ/"):
-                litellm_params["api_key"] = litellm.get_secret(
-                    litellm_params["api_key"]
-                )
+                litellm_params["api_key"] = str(get_secret(litellm_params["api_key"]))  # type: ignore

         if litellm_params["api_base"]:
             if litellm_params["api_base"].startswith("os.environ/"):
-                litellm_params["api_base"] = litellm.get_secret(
-                    litellm_params["api_base"]
-                )
+                litellm_params["api_base"] = str(get_secret(litellm_params["api_base"]))  # type: ignore

         # Init guardrail CustomLoggerClass
         if litellm_params["guardrail"] == "aporia":
@@ -182,6 +173,7 @@ def init_guardrails_v2(
                 presidio_ad_hoc_recognizers=litellm_params[
                     "presidio_ad_hoc_recognizers"
                 ],
+                mock_redacted_text=litellm_params.get("mock_redacted_text") or None,
             )

         if litellm_params["output_parse_pii"] is True:
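The _litellm_params_kwargs construction above fills every key declared on the TypedDict, defaulting missing keys to None, so new fields such as mock_redacted_text flow through without touching this loader again. A self-contained sketch of the same idea, with an illustrative TypedDict name:

# standalone sketch of building kwargs from a TypedDict's declared keys
from typing import Optional, TypedDict


class ExampleParams(TypedDict):
    guardrail: str
    mode: str
    api_key: Optional[str]


raw = {"guardrail": "presidio", "mode": "logging_only"}
kwargs = {k: raw[k] if k in raw else None for k in ExampleParams.__annotations__.keys()}
params = ExampleParams(**kwargs)  # api_key ends up as None
print(params)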
@@ -167,11 +167,11 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):

         if self.prompt_injection_params is not None:
             # 1. check if heuristics check turned on
-            if self.prompt_injection_params.heuristics_check == True:
+            if self.prompt_injection_params.heuristics_check is True:
                 is_prompt_attack = self.check_user_input_similarity(
                     user_input=formatted_prompt
                 )
-                if is_prompt_attack == True:
+                if is_prompt_attack is True:
                     raise HTTPException(
                         status_code=400,
                         detail={
@@ -179,14 +179,14 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
                         },
                     )
             # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
-            if self.prompt_injection_params.vector_db_check == True:
+            if self.prompt_injection_params.vector_db_check is True:
                 pass
         else:
             is_prompt_attack = self.check_user_input_similarity(
                 user_input=formatted_prompt
             )

-            if is_prompt_attack == True:
+            if is_prompt_attack is True:
                 raise HTTPException(
                     status_code=400,
                     detail={
@@ -201,19 +201,18 @@ class _OPTIONAL_PromptInjectionDetection(CustomLogger):
             if (
                 e.status_code == 400
                 and isinstance(e.detail, dict)
-                and "error" in e.detail
+                and "error" in e.detail  # type: ignore
                 and self.prompt_injection_params is not None
                 and self.prompt_injection_params.reject_as_response
             ):
                 return e.detail.get("error")
             raise e
         except Exception as e:
-            verbose_proxy_logger.error(
+            verbose_proxy_logger.exception(
                 "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
                     str(e)
                 )
             )
-            verbose_proxy_logger.debug(traceback.format_exc())

     async def async_moderation_hook(  # type: ignore
         self,
@@ -195,7 +195,8 @@ async def user_auth(request: Request):
    - os.environ["SMTP_PASSWORD"]
    - os.environ["SMTP_SENDER_EMAIL"]
    """
-    from litellm.proxy.proxy_server import prisma_client, send_email
+    from litellm.proxy.proxy_server import prisma_client
+    from litellm.proxy.utils import send_email

    data = await request.json()  # type: ignore
    user_email = data["user_email"]
@@ -212,7 +213,7 @@ async def user_auth(request: Request):
    )
    ### if so - generate a 24 hr key with that user id
    if response is not None:
-        user_id = response.user_id
+        user_id = response.user_id  # type: ignore
        response = await generate_key_helper_fn(
            request_type="key",
            **{"duration": "24hr", "models": [], "aliases": {}, "config": {}, "spend": 0, "user_id": user_id},  # type: ignore
@@ -345,6 +346,7 @@ async def user_info(
        for team in teams_1:
            team_id_list.append(team.team_id)

+        teams_2: Optional[Any] = None
        if user_info is not None:
            # *NEW* get all teams in user 'teams' field
            teams_2 = await prisma_client.get_data(
@@ -375,7 +377,7 @@ async def user_info(
                ),
                user_api_key_dict=user_api_key_dict,
            )
-        else:
+        elif caller_user_info is not None:
            teams_2 = await prisma_client.get_data(
                team_id_list=caller_user_info.teams,
                table_name="team",
@@ -395,7 +397,7 @@ async def user_info(
            query_type="find_all",
        )

-        if user_info is None:
+        if user_info is None and keys is not None:
            ## make sure we still return a total spend ##
            spend = 0
            for k in keys:
@@ -404,32 +406,35 @@ async def user_info(

        ## REMOVE HASHED TOKEN INFO before returning ##
        returned_keys = []
-        for key in keys:
-            if (
-                key.token == litellm_master_key_hash
-                and general_settings.get("disable_master_key_return", False)
-                == True  ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
-            ):
-                continue
+        if keys is None:
+            pass
+        else:
+            for key in keys:
+                if (
+                    key.token == litellm_master_key_hash
+                    and general_settings.get("disable_master_key_return", False)
+                    == True  ## [IMPORTANT] used by hosted proxy-ui to prevent sharing master key on ui
+                ):
+                    continue

                try:
                    key = key.model_dump()  # noqa
                except:
                    # if using pydantic v1
                    key = key.dict()
                if (
                    "team_id" in key
                    and key["team_id"] is not None
                    and key["team_id"] != "litellm-dashboard"
                ):
                    team_info = await prisma_client.get_data(
                        team_id=key["team_id"], table_name="team"
                    )
                    team_alias = getattr(team_info, "team_alias", None)
                    key["team_alias"] = team_alias
                else:
                    key["team_alias"] = "None"
                returned_keys.append(key)

        response_data = {
            "user_id": user_id,
@@ -539,6 +544,7 @@ async def user_update(

    ## ADD USER, IF NEW ##
    verbose_proxy_logger.debug("/user/update: Received data = %s", data)
+    response: Optional[Any] = None
    if data.user_id is not None and len(data.user_id) > 0:
        non_default_values["user_id"] = data.user_id  # type: ignore
        verbose_proxy_logger.debug("In update user, user_id condition block.")
@@ -573,7 +579,7 @@ async def user_update(
                data=non_default_values,
                table_name="user",
            )
-            return response
+            return response  # type: ignore
        # update based on remaining passed in values
    except Exception as e:
        verbose_proxy_logger.error(
@@ -226,7 +226,6 @@ from litellm.proxy.utils import (
    hash_token,
    log_to_opentelemetry,
    reset_budget,
-    send_email,
    update_spend,
)
from litellm.proxy.vertex_ai_endpoints.google_ai_studio_endpoints import (
@@ -1434,7 +1434,7 @@ async def _get_spend_report_for_time_range(
    except Exception as e:
        verbose_proxy_logger.error(
            "Exception in _get_daily_spend_reports {}".format(str(e))
-        )  # noqa
+        )


@router.post(
@@ -1703,26 +1703,26 @@ async def view_spend_logs(
            result: dict = {}
            for record in response:
                dt_object = datetime.strptime(
-                    str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ"
+                    str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ"  # type: ignore
                )  # type: ignore
                date = dt_object.date()
                if date not in result:
                    result[date] = {"users": {}, "models": {}}
-                api_key = record["api_key"]
-                user_id = record["user"]
-                model = record["model"]
+                api_key = record["api_key"]  # type: ignore
+                user_id = record["user"]  # type: ignore
+                model = record["model"]  # type: ignore
-                result[date]["spend"] = (
-                    result[date].get("spend", 0) + record["_sum"]["spend"]
-                )
-                result[date][api_key] = (
-                    result[date].get(api_key, 0) + record["_sum"]["spend"]
-                )
-                result[date]["users"][user_id] = (
-                    result[date]["users"].get(user_id, 0) + record["_sum"]["spend"]
-                )
-                result[date]["models"][model] = (
-                    result[date]["models"].get(model, 0) + record["_sum"]["spend"]
-                )
+                result[date]["spend"] = result[date].get("spend", 0) + record.get(
+                    "_sum", {}
+                ).get("spend", 0)
+                result[date][api_key] = result[date].get(api_key, 0) + record.get(
+                    "_sum", {}
+                ).get("spend", 0)
+                result[date]["users"][user_id] = result[date]["users"].get(
+                    user_id, 0
+                ) + record.get("_sum", {}).get("spend", 0)
+                result[date]["models"][model] = result[date]["models"].get(
+                    model, 0
+                ) + record.get("_sum", {}).get("spend", 0)
            return_list = []
            final_date = None
            for k, v in sorted(result.items()):
@@ -1784,7 +1784,7 @@ async def view_spend_logs(
                table_name="spend", query_type="find_all"
            )

-            return spend_log
+            return spend_logs

        return None

@@ -1843,6 +1843,88 @@ async def global_spend_reset():
    }


+@router.post(
+    "/global/spend/refresh",
+    tags=["Budget & Spend Tracking"],
+    dependencies=[Depends(user_api_key_auth)],
+    include_in_schema=False,
+)
+async def global_spend_refresh():
+    """
+    ADMIN ONLY / MASTER KEY Only Endpoint
+
+    Globally refresh spend MonthlyGlobalSpend view
+    """
+    from litellm.proxy.proxy_server import prisma_client
+
+    if prisma_client is None:
+        raise ProxyException(
+            message="Prisma Client is not initialized",
+            type="internal_error",
+            param="None",
+            code=status.HTTP_401_UNAUTHORIZED,
+        )
+
+    ## RESET GLOBAL SPEND VIEW ###
+    async def is_materialized_global_spend_view() -> bool:
+        """
+        Return True if materialized view exists
+
+        Else False
+        """
+        sql_query = """
+        SELECT relname, relkind
+        FROM pg_class
+        WHERE relname = 'MonthlyGlobalSpend';
+        """
+        try:
+            resp = await prisma_client.db.query_raw(sql_query)
+
+            assert resp[0]["relkind"] == "m"
+            return True
+        except Exception:
+            return False
+
+    view_exists = await is_materialized_global_spend_view()
+
+    if view_exists:
+        # refresh materialized view
+        sql_query = """
+        REFRESH MATERIALIZED VIEW "MonthlyGlobalSpend";
+        """
+        try:
+            from litellm.proxy._types import CommonProxyErrors
+            from litellm.proxy.proxy_server import proxy_logging_obj
+            from litellm.proxy.utils import PrismaClient
+
+            db_url = os.getenv("DATABASE_URL")
+            if db_url is None:
+                raise Exception(CommonProxyErrors.db_not_connected_error.value)
+            new_client = PrismaClient(
+                database_url=db_url,
+                proxy_logging_obj=proxy_logging_obj,
+                http_client={
+                    "timeout": 6000,
+                },
+            )
+            await new_client.db.connect()
+            await new_client.db.query_raw(sql_query)
+            verbose_proxy_logger.info("MonthlyGlobalSpend view refreshed")
+            return {
+                "message": "MonthlyGlobalSpend view refreshed",
+                "status": "success",
+            }
+
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "Failed to refresh materialized view - {}".format(str(e))
+            )
+            return {
+                "message": "Failed to refresh materialized view",
+                "status": "failure",
+            }
+
+
async def global_spend_for_internal_user(
    api_key: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
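A hedged client-side sketch (not part of the diff) of calling the new admin-only /global/spend/refresh endpoint against a locally running proxy; the URL and master key are placeholders:

# sketch: triggering a MonthlyGlobalSpend refresh from a client
import httpx

resp = httpx.post(
    "http://localhost:4000/global/spend/refresh",
    headers={"Authorization": "Bearer sk-1234"},  # proxy master key (placeholder)
)
print(resp.status_code, resp.json())  # expects {"message": ..., "status": "success" | "failure"}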
@@ -92,6 +92,7 @@ def safe_deep_copy(data):
    if litellm.safe_memory_mode is True:
        return data

+    litellm_parent_otel_span: Optional[Any] = None
    # Step 1: Remove the litellm_parent_otel_span
    litellm_parent_otel_span = None
    if isinstance(data, dict):
@@ -101,7 +102,7 @@ def safe_deep_copy(data):
    new_data = copy.deepcopy(data)

    # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
-    if isinstance(data, dict):
+    if isinstance(data, dict) and litellm_parent_otel_span is not None:
        if "metadata" in data:
            data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
    return new_data
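The safe_deep_copy change above hardens the usual remove / deepcopy / restore pattern for objects that cannot be deep-copied. A standalone sketch of that pattern (names are illustrative; the real helper also honors litellm.safe_memory_mode):

# sketch: deep copy a request dict without copying a non-copyable span object
import copy

def copy_without_span(data: dict) -> dict:
    span = data.get("metadata", {}).pop("litellm_parent_otel_span", None)
    new_data = copy.deepcopy(data)  # the span object is not reachable during the copy
    if span is not None:
        data["metadata"]["litellm_parent_otel_span"] = span  # restore it on the original
    return new_data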
@@ -468,7 +469,7 @@ class ProxyLogging:

                # V1 implementation - backwards compatibility
                if callback.event_hook is None:
-                    if callback.moderation_check == "pre_call":
+                    if callback.moderation_check == "pre_call":  # type: ignore
                        return
                else:
                    # Main - V2 Guardrails implementation
@@ -881,7 +882,12 @@ class PrismaClient:
    org_list_transactons: dict = {}
    spend_log_transactions: List = []

-    def __init__(self, database_url: str, proxy_logging_obj: ProxyLogging):
+    def __init__(
+        self,
+        database_url: str,
+        proxy_logging_obj: ProxyLogging,
+        http_client: Optional[Any] = None,
+    ):
        verbose_proxy_logger.debug(
            "LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'"
        )
@@ -912,7 +918,10 @@ class PrismaClient:
        # Now you can import the Prisma Client
        from prisma import Prisma  # type: ignore
        verbose_proxy_logger.debug("Connecting Prisma Client to DB..")
-        self.db = Prisma()  # Client to connect to Prisma db
+        if http_client is not None:
+            self.db = Prisma(http=http_client)
+        else:
+            self.db = Prisma()  # Client to connect to Prisma db
        verbose_proxy_logger.debug("Success - Connected Prisma Client to DB")

    def hash_token(self, token: str):
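Sketch of the new constructor parameter in use: the http_client dict is passed straight through to prisma.Prisma(http=...), which the refresh endpoint above uses to get a longer query timeout. The database URL below is a placeholder.

# sketch: constructing PrismaClient with a custom HTTP timeout
from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.utils import PrismaClient

client = PrismaClient(
    database_url="postgresql://user:pass@localhost:5432/litellm",  # placeholder
    proxy_logging_obj=proxy_logging_obj,
    http_client={"timeout": 6000},
)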
@@ -987,7 +996,7 @@ class PrismaClient:
                return
            else:
                ## check if required view exists ##
-                if required_view not in ret[0]["view_names"]:
+                if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
                    await self.health_check()  # make sure we can connect to db
                    await self.db.execute_raw(
                        """
@@ -1009,7 +1018,9 @@ class PrismaClient:
            else:
                # don't block execution if these views are missing
                # Convert lists to sets for efficient difference calculation
-                ret_view_names_set = set(ret[0]["view_names"])
+                ret_view_names_set = (
+                    set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
+                )
                expected_views_set = set(expected_views)
                # Find missing views
                missing_views = expected_views_set - ret_view_names_set
@@ -1291,13 +1302,13 @@ class PrismaClient:
        verbose_proxy_logger.debug(
            f"PrismaClient: get_data - args_passed_in: {args_passed_in}"
        )
+        hashed_token: Optional[str] = None
        try:
            response: Any = None
            if (token is not None and table_name is None) or (
                table_name is not None and table_name == "key"
            ):
                # check if plain text or hash
-                hashed_token = None
                if token is not None:
                    if isinstance(token, str):
                        hashed_token = token
@@ -1306,7 +1317,7 @@ class PrismaClient:
                verbose_proxy_logger.debug(
                    f"PrismaClient: find_unique for token: {hashed_token}"
                )
-                if query_type == "find_unique":
+                if query_type == "find_unique" and hashed_token is not None:
                    if token is None:
                        raise HTTPException(
                            status_code=400,
@@ -1706,7 +1717,7 @@ class PrismaClient:
                        updated_data = v
                        updated_data = json.dumps(updated_data)
                        updated_table_row = self.db.litellm_config.upsert(
-                            where={"param_name": k},
+                            where={"param_name": k},  # type: ignore
                            data={
                                "create": {"param_name": k, "param_value": updated_data},  # type: ignore
                                "update": {"param_value": updated_data},
@@ -2302,7 +2313,12 @@ def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
        return instance
    except ImportError as e:
        # Re-raise the exception with a user-friendly message
-        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
+        if instance_name and module_name:
+            raise ImportError(
+                f"Could not import {instance_name} from {module_name}"
+            ) from e
+        else:
+            raise e
    except Exception as e:
        raise e

@@ -2377,12 +2393,12 @@ async def send_email(receiver_email, subject, html):

    try:
        # Establish a secure connection with the SMTP server
-        with smtplib.SMTP(smtp_host, smtp_port) as server:
+        with smtplib.SMTP(smtp_host, smtp_port) as server:  # type: ignore
            if os.getenv("SMTP_TLS", "True") != "False":
                server.starttls()

            # Login to your email account
-            server.login(smtp_username, smtp_password)
+            server.login(smtp_username, smtp_password)  # type: ignore

            # Send the email
            server.send_message(email_message)
@@ -2945,7 +2961,7 @@ async def update_spend(
                if i >= n_retry_times:  # If we've reached the maximum number of retries
                    raise  # Re-raise the last exception
                # Optionally, sleep for a bit before retrying
-                await asyncio.sleep(2**i)  # Exponential backoff
+                await asyncio.sleep(2**i)  # type: ignore
            except Exception as e:
                import traceback
@@ -2110,7 +2110,6 @@ async def test_hf_completion_tgi_stream():
def test_openai_chat_completion_call():
    litellm.set_verbose = False
    litellm.return_response_headers = True
-    print(f"making openai chat completion call")
    response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
    assert isinstance(
        response._hidden_params["additional_headers"][
@@ -2318,6 +2317,57 @@ def test_together_ai_completion_call_mistral():
        pass


+# # test on together ai completion call - starcoder
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_openai_o1_completion_call_streaming(sync_mode):
+    try:
+        litellm.set_verbose = False
+        if sync_mode:
+            response = completion(
+                model="o1-preview",
+                messages=messages,
+                stream=True,
+            )
+            complete_response = ""
+            print(f"returned response object: {response}")
+            has_finish_reason = False
+            for idx, chunk in enumerate(response):
+                chunk, finished = streaming_format_tests(idx, chunk)
+                has_finish_reason = finished
+                if finished:
+                    break
+                complete_response += chunk
+            if has_finish_reason is False:
+                raise Exception("Finish reason not set for last chunk")
+            if complete_response == "":
+                raise Exception("Empty response received")
+        else:
+            response = await acompletion(
+                model="o1-preview",
+                messages=messages,
+                stream=True,
+            )
+            complete_response = ""
+            print(f"returned response object: {response}")
+            has_finish_reason = False
+            idx = 0
+            async for chunk in response:
+                chunk, finished = streaming_format_tests(idx, chunk)
+                has_finish_reason = finished
+                if finished:
+                    break
+                complete_response += chunk
+                idx += 1
+            if has_finish_reason is False:
+                raise Exception("Finish reason not set for last chunk")
+            if complete_response == "":
+                raise Exception("Empty response received")
+        print(f"complete response: {complete_response}")
+    except Exception:
+        pytest.fail(f"error occurred: {traceback.format_exc()}")
+
+
def test_together_ai_completion_call_starcoder_bad_key():
    try:
        api_key = "bad-key"
@@ -71,7 +71,7 @@ class LakeraCategoryThresholds(TypedDict, total=False):
    jailbreak: float


-class LitellmParams(TypedDict, total=False):
+class LitellmParams(TypedDict):
    guardrail: str
    mode: str
    api_key: str
@@ -87,6 +87,7 @@ class LitellmParams(TypedDict, total=False):
    # Presidio params
    output_parse_pii: Optional[bool]
    presidio_ad_hoc_recognizers: Optional[str]
+    mock_redacted_text: Optional[dict]


class Guardrail(TypedDict):
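Sketch of the new mock_redacted_text key being carried through LitellmParams; at runtime this is an ordinary dict, and the guardrail loader shown earlier fills any keys omitted here with None (static type checkers may still expect the full key set now that total=False was dropped).

# sketch: a LitellmParams-shaped dict including the new key
from litellm.types.guardrails import LitellmParams

params = {
    "guardrail": "presidio",
    "mode": "logging_only",
    "mock_redacted_text": {"text": "My name is <PERSON>", "items": []},
}
print(params["mock_redacted_text"])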
@@ -120,11 +120,26 @@ with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json")
    # Convert to str (if necessary)
    claude_json_str = json.dumps(json_data)
import importlib.metadata
+from concurrent.futures import ThreadPoolExecutor
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+    cast,
+    get_args,
+)

from openai import OpenAIError as OriginalError

from ._logging import verbose_logger
-from .caching import QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
+from .caching import Cache, QdrantSemanticCache, RedisCache, RedisSemanticCache, S3Cache
from .exceptions import (
    APIConnectionError,
    APIError,
@@ -150,31 +165,6 @@ from .types.llms.openai import (
)
from .types.router import LiteLLM_Params

-try:
-    from .proxy.enterprise.enterprise_callbacks.generic_api_callback import (
-        GenericAPILogger,
-    )
-except Exception as e:
-    verbose_logger.debug(f"Exception import enterprise features {str(e)}")
-
-from concurrent.futures import ThreadPoolExecutor
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Literal,
-    Optional,
-    Tuple,
-    Type,
-    Union,
-    cast,
-    get_args,
-)
-
-from .caching import Cache
-
####### ENVIRONMENT VARIABLES ####################
# Adjust to your specific application needs / system capabilities.
MAX_THREADS = 100
1  tests/llm_translation/Readme.md  Normal file
@@ -0,0 +1 @@
More tests under `litellm/litellm/tests/*`.
502
tests/llm_translation/test_databricks.py
Normal file
502
tests/llm_translation/test_databricks.py
Normal file
|
@ -0,0 +1,502 @@
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import litellm
|
||||||
|
from litellm.exceptions import BadRequestError, InternalServerError
|
||||||
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
|
from litellm.utils import CustomStreamWrapper
|
||||||
|
|
||||||
|
|
||||||
|
def mock_chat_response() -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": "chatcmpl_3f78f09a-489c-4b8d-a587-f162c7497891",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1726285449,
|
||||||
|
"model": "dbrx-instruct-071224",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello! I'm an AI assistant. I'm doing well. How can I help?",
|
||||||
|
"function_call": None,
|
||||||
|
"tool_calls": None,
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 230,
|
||||||
|
"completion_tokens": 38,
|
||||||
|
"total_tokens": 268,
|
||||||
|
"completion_tokens_details": None,
|
||||||
|
},
|
||||||
|
"system_fingerprint": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def mock_chat_streaming_response_chunks() -> List[str]:
|
||||||
|
return [
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 1726469651,
|
||||||
|
"model": "dbrx-instruct-071224",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant", "content": "Hello"},
|
||||||
|
"finish_reason": None,
|
||||||
|
"logprobs": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 230,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 231,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 1726469651,
|
||||||
|
"model": "dbrx-instruct-071224",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": " world"},
|
||||||
|
"finish_reason": None,
|
||||||
|
"logprobs": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 230,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 231,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"id": "chatcmpl_8a7075d1-956e-4960-b3a6-892cd4649ff3",
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 1726469651,
|
||||||
|
"model": "dbrx-instruct-071224",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": "!"},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"logprobs": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 230,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 231,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def mock_chat_streaming_response_chunks_bytes() -> List[bytes]:
|
||||||
|
string_chunks = mock_chat_streaming_response_chunks()
|
||||||
|
bytes_chunks = [chunk.encode("utf-8") + b"\n" for chunk in string_chunks]
|
||||||
|
# Simulate the end of the stream
|
||||||
|
bytes_chunks.append(b"")
|
||||||
|
return bytes_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def mock_http_handler_chat_streaming_response() -> MagicMock:
|
||||||
|
mock_stream_chunks = mock_chat_streaming_response_chunks()
|
||||||
|
|
||||||
|
def mock_iter_lines():
|
||||||
|
for chunk in mock_stream_chunks:
|
||||||
|
for line in chunk.splitlines():
|
||||||
|
yield line
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.iter_lines.side_effect = mock_iter_lines
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
|
||||||
|
def mock_http_handler_chat_async_streaming_response() -> MagicMock:
|
||||||
|
mock_stream_chunks = mock_chat_streaming_response_chunks()
|
||||||
|
|
||||||
|
async def mock_iter_lines():
|
||||||
|
for chunk in mock_stream_chunks:
|
||||||
|
for line in chunk.splitlines():
|
||||||
|
yield line
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.aiter_lines.return_value = mock_iter_lines()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
|
||||||
|
def mock_databricks_client_chat_streaming_response() -> MagicMock:
    mock_stream_chunks = mock_chat_streaming_response_chunks_bytes()

    def mock_read_from_stream(size=-1):
        if mock_stream_chunks:
            return mock_stream_chunks.pop(0)
        return b""

    mock_response = MagicMock()
    streaming_response_mock = MagicMock()
    streaming_response_iterator_mock = MagicMock()
    # Mock the __getitem__("content") method to return the streaming response
    mock_response.__getitem__.return_value = streaming_response_mock
    # Mock the streaming response __enter__ method to return the streaming response iterator
    streaming_response_mock.__enter__.return_value = streaming_response_iterator_mock

    streaming_response_iterator_mock.read1.side_effect = mock_read_from_stream
    streaming_response_iterator_mock.closed = False

    return mock_response

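# Canned /embeddings response payload (embedding vector truncated), used as the mocked response body in the embedding tests below.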
def mock_embedding_response() -> Dict[str, Any]:
    return {
        "object": "list",
        "model": "bge-large-en-v1.5",
        "data": [
            {
                "index": 0,
                "object": "embedding",
                "embedding": [
                    0.06768798828125,
                    -0.01291656494140625,
                    -0.0501708984375,
                    0.0245361328125,
                    -0.030364990234375,
                ],
            }
        ],
        "usage": {
            "prompt_tokens": 8,
            "total_tokens": 8,
            "completion_tokens": 0,
            "completion_tokens_details": None,
        },
    }

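# Setting only one of DATABRICKS_API_BASE / DATABRICKS_API_KEY should fail fast with a BadRequestError for both completion and embedding calls.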
@pytest.mark.parametrize("set_base", [True, False])
|
||||||
|
def test_throws_if_only_one_of_api_base_or_api_key_set(monkeypatch, set_base):
|
||||||
|
if set_base:
|
||||||
|
monkeypatch.setenv(
|
||||||
|
"DATABRICKS_API_BASE",
|
||||||
|
"https://my.workspace.cloud.databricks.com/serving-endpoints",
|
||||||
|
)
|
||||||
|
monkeypatch.delenv(
|
||||||
|
"DATABRICKS_API_KEY",
|
||||||
|
)
|
||||||
|
err_msg = "A call is being made to LLM Provider but no key is set"
|
||||||
|
else:
|
||||||
|
monkeypatch.setenv("DATABRICKS_API_KEY", "dapimykey")
|
||||||
|
monkeypatch.delenv("DATABRICKS_API_BASE")
|
||||||
|
err_msg = "A call is being made to LLM Provider but no api base is set"
|
||||||
|
|
||||||
|
with pytest.raises(BadRequestError) as exc:
|
||||||
|
litellm.completion(
|
||||||
|
model="databricks/dbrx-instruct-071224",
|
||||||
|
messages={"role": "user", "content": "How are you?"},
|
||||||
|
)
|
||||||
|
assert err_msg in str(exc)
|
||||||
|
|
||||||
|
with pytest.raises(BadRequestError) as exc:
|
||||||
|
litellm.embedding(
|
||||||
|
model="databricks/bge-12312",
|
||||||
|
input=["Hello", "World"],
|
||||||
|
)
|
||||||
|
assert err_msg in str(exc)
|
||||||
|
|
||||||
|
|
||||||
|
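# Non-streaming completion through an injected synchronous HTTPHandler; checks both the returned payload and the exact request sent to /chat/completions.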
def test_completions_with_sync_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()

    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }

    messages = [{"role": "user", "content": "How are you?"}]

    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == expected_response_json

        mock_post.assert_called_once_with(
            f"{base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "dbrx-instruct-071224",
                    "messages": messages,
                    "temperature": 0.5,
                    "extraparam": "testpassingextraparam",
                    "stream": False,
                }
            ),
        )

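# Same flow as above, but via litellm.acompletion with an injected AsyncHTTPHandler.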
def test_completions_with_async_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    async_handler = AsyncHTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_chat_response()

    expected_response_json = {
        **mock_chat_response(),
        **{
            "model": "databricks/dbrx-instruct-071224",
        },
    }

    messages = [{"role": "user", "content": "How are you?"}]

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response = asyncio.run(
            litellm.acompletion(
                model="databricks/dbrx-instruct-071224",
                messages=messages,
                client=async_handler,
                temperature=0.5,
                extraparam="testpassingextraparam",
            )
        )
        assert response.to_dict() == expected_response_json

        mock_post.assert_called_once_with(
            f"{base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "dbrx-instruct-071224",
                    "messages": messages,
                    "temperature": 0.5,
                    "extraparam": "testpassingextraparam",
                    "stream": False,
                }
            ),
        )

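# Streaming completion over the sync handler; verifies both the collected chunks and that the request body sets "stream": True.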
def test_completions_streaming_with_sync_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()

    messages = [{"role": "user", "content": "How are you?"}]
    mock_response = mock_http_handler_chat_streaming_response()

    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response_stream: CustomStreamWrapper = litellm.completion(
            model="databricks/dbrx-instruct-071224",
            messages=messages,
            client=sync_handler,
            temperature=0.5,
            extraparam="testpassingextraparam",
            stream=True,
        )
        response = list(response_stream)
        assert "dbrx-instruct-071224" in str(response)
        assert "chatcmpl" in str(response)
        assert len(response) == 4

        mock_post.assert_called_once_with(
            f"{base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "dbrx-instruct-071224",
                    "messages": messages,
                    "temperature": 0.5,
                    "stream": True,
                    "extraparam": "testpassingextraparam",
                }
            ),
            stream=True,
        )

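# Async streaming variant: the stream is drained with an async comprehension before the same assertions run.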
def test_completions_streaming_with_async_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    async_handler = AsyncHTTPHandler()

    messages = [{"role": "user", "content": "How are you?"}]
    mock_response = mock_http_handler_chat_async_streaming_response()

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response_stream: CustomStreamWrapper = asyncio.run(
            litellm.acompletion(
                model="databricks/dbrx-instruct-071224",
                messages=messages,
                client=async_handler,
                temperature=0.5,
                extraparam="testpassingextraparam",
                stream=True,
            )
        )

        # Use async list gathering for the response
        async def gather_responses():
            return [item async for item in response_stream]

        response = asyncio.run(gather_responses())
        assert "dbrx-instruct-071224" in str(response)
        assert "chatcmpl" in str(response)
        assert len(response) == 4

        mock_post.assert_called_once_with(
            f"{base_url}/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "dbrx-instruct-071224",
                    "messages": messages,
                    "temperature": 0.5,
                    "stream": True,
                    "extraparam": "testpassingextraparam",
                }
            ),
            stream=True,
        )

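# Embedding request over the sync handler; asserts model, inputs, and extra params pass through to /embeddings unchanged.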
def test_embeddings_with_sync_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    sync_handler = HTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_embedding_response()

    inputs = ["Hello", "World"]

    with patch.object(HTTPHandler, "post", return_value=mock_response) as mock_post:
        response = litellm.embedding(
            model="databricks/bge-large-en-v1.5",
            input=inputs,
            client=sync_handler,
            extraparam="testpassingextraparam",
        )
        assert response.to_dict() == mock_embedding_response()

        mock_post.assert_called_once_with(
            f"{base_url}/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "bge-large-en-v1.5",
                    "input": inputs,
                    "extraparam": "testpassingextraparam",
                }
            ),
        )

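# Async embedding variant via litellm.aembedding with an injected AsyncHTTPHandler.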
def test_embeddings_with_async_http_handler(monkeypatch):
    base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
    api_key = "dapimykey"
    monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
    monkeypatch.setenv("DATABRICKS_API_KEY", api_key)

    async_handler = AsyncHTTPHandler()
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.json.return_value = mock_embedding_response()

    inputs = ["Hello", "World"]

    with patch.object(
        AsyncHTTPHandler, "post", return_value=mock_response
    ) as mock_post:
        response = asyncio.run(
            litellm.aembedding(
                model="databricks/bge-large-en-v1.5",
                input=inputs,
                client=async_handler,
                extraparam="testpassingextraparam",
            )
        )
        assert response.to_dict() == mock_embedding_response()

        mock_post.assert_called_once_with(
            f"{base_url}/embeddings",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            data=json.dumps(
                {
                    "model": "bge-large-en-v1.5",
                    "input": inputs,
                    "extraparam": "testpassingextraparam",
                }
            ),
        )