diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md
index 207b1abe1..ef3a7b940 100644
--- a/docs/my-website/docs/proxy/prometheus.md
+++ b/docs/my-website/docs/proxy/prometheus.md
@@ -27,8 +27,7 @@ model_list:
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
- success_callback: ["prometheus"]
- failure_callback: ["prometheus"]
+ callbacks: ["prometheus"]
```
Start the proxy
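
The proxy config above now registers Prometheus through the unified `callbacks` list, which hooks both success and failure events. A minimal SDK-side sketch of the equivalent setup, assuming the standard `litellm.callbacks` setting (model name and message are illustrative):

import litellm

# "prometheus" registered via `callbacks` covers both success and failure
# logging, replacing the separate success_callback / failure_callback lists.
litellm.callbacks = ["prometheus"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
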
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9379caac1..4a9b9c312 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -53,6 +53,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"arize",
"langtrace",
"gcs_bucket",
+ "s3",
"opik",
]
_known_custom_logger_compatible_callbacks: List = list(
@@ -931,7 +932,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation impor
vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
VertexAIAnthropicConfig,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py
index aa7f0bba2..7ef63d25c 100644
--- a/litellm/integrations/custom_batch_logger.py
+++ b/litellm/integrations/custom_batch_logger.py
@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
self,
flush_lock: Optional[asyncio.Lock] = None,
batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+ flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
**kwargs,
) -> None:
"""
@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
"""
self.log_queue: List = []
- self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds
+ self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
self.last_flush_time = time.time()
self.flush_lock = flush_lock
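
The new `flush_interval` argument lets subclasses tune how often the periodic flush drains the queue, alongside the existing `batch_size`. A hypothetical subclass sketch (the class name and destination are illustrative, not part of this change):

import asyncio
from typing import List, Optional

from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchLogger(CustomBatchLogger):
    def __init__(self, flush_lock: Optional[asyncio.Lock] = None):
        # Flush every 5 seconds OR once 500 events are queued, whichever comes first.
        super().__init__(flush_lock=flush_lock, batch_size=500, flush_interval=5)
        self.log_queue: List[dict] = []

    async def async_send_batch(self):
        # Ship whatever is queued to the destination; the base class's
        # flush_queue() is expected to empty self.log_queue afterwards.
        for item in self.log_queue:
            ...  # send `item` to the backing store
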
diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index 19740c741..177ff735c 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -235,6 +235,14 @@ class LangFuseLogger:
):
input = prompt
output = response_obj.results
+ elif (
+ kwargs.get("call_type") is not None
+ and kwargs.get("call_type") == "_arealtime"
+ and response_obj is not None
+ and isinstance(response_obj, list)
+ ):
+ input = kwargs.get("input")
+ output = response_obj
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"
diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py
index 1f82406e1..397b53cd5 100644
--- a/litellm/integrations/s3.py
+++ b/litellm/integrations/s3.py
@@ -1,43 +1,67 @@
-#### What this does ####
-# On success + failure, log events to Supabase
+"""
+s3 Bucket Logging Integration
-import datetime
-import os
-import subprocess
-import sys
-import traceback
-import uuid
-from typing import Optional
+async_log_success_event: Processes the event, holds it in memory until the flush interval elapses (default 10 seconds) or the batch size is reached, then flushes to s3
+
+NOTE 1: S3 does not provide a batch PUT API endpoint, so we create a task to upload each element individually
+NOTE 2: We create an httpx client with a concurrent limit of 1 to upload to s3, so uploads do not impact the latency of the LLM calling logic
+"""
+
+import asyncio
+import json
+from datetime import datetime
+from typing import Dict, List, Optional
import litellm
from litellm._logging import print_verbose, verbose_logger
+from litellm.llms.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.s3 import s3BatchLoggingElement
from litellm.types.utils import StandardLoggingPayload
+from .custom_batch_logger import CustomBatchLogger
-class S3Logger:
+# Default Flush interval and batch size for s3
+# Flush to s3 every 10 seconds OR every 1K requests in memory
+DEFAULT_S3_FLUSH_INTERVAL_SECONDS = 10
+DEFAULT_S3_BATCH_SIZE = 1000
+
+
+class S3Logger(CustomBatchLogger, BaseAWSLLM):
# Class variables or attributes
def __init__(
self,
- s3_bucket_name=None,
- s3_path=None,
- s3_region_name=None,
- s3_api_version=None,
- s3_use_ssl=True,
- s3_verify=None,
- s3_endpoint_url=None,
- s3_aws_access_key_id=None,
- s3_aws_secret_access_key=None,
- s3_aws_session_token=None,
+ s3_bucket_name: Optional[str] = None,
+ s3_path: Optional[str] = None,
+ s3_region_name: Optional[str] = None,
+ s3_api_version: Optional[str] = None,
+ s3_use_ssl: bool = True,
+ s3_verify: Optional[bool] = None,
+ s3_endpoint_url: Optional[str] = None,
+ s3_aws_access_key_id: Optional[str] = None,
+ s3_aws_secret_access_key: Optional[str] = None,
+ s3_aws_session_token: Optional[str] = None,
+ s3_flush_interval: Optional[int] = DEFAULT_S3_FLUSH_INTERVAL_SECONDS,
+ s3_batch_size: Optional[int] = DEFAULT_S3_BATCH_SIZE,
s3_config=None,
**kwargs,
):
- import boto3
-
try:
verbose_logger.debug(
f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
)
+            # IMPORTANT: We use a concurrent limit of 1 to upload to s3 so that
+            # uploads do not impact the latency of the LLM calling logic
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback,
+ params={"concurrent_limit": 1},
+ )
+
if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items():
@@ -63,107 +87,282 @@ class S3Logger:
s3_path = litellm.s3_callback_params.get("s3_path")
# done reading litellm.s3_callback_params
+ s3_flush_interval = litellm.s3_callback_params.get(
+ "s3_flush_interval", DEFAULT_S3_FLUSH_INTERVAL_SECONDS
+ )
+ s3_batch_size = litellm.s3_callback_params.get(
+ "s3_batch_size", DEFAULT_S3_BATCH_SIZE
+ )
+
self.bucket_name = s3_bucket_name
self.s3_path = s3_path
verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
- # Create an S3 client with custom endpoint URL
- self.s3_client = boto3.client(
- "s3",
- region_name=s3_region_name,
- endpoint_url=s3_endpoint_url,
- api_version=s3_api_version,
- use_ssl=s3_use_ssl,
- verify=s3_verify,
- aws_access_key_id=s3_aws_access_key_id,
- aws_secret_access_key=s3_aws_secret_access_key,
- aws_session_token=s3_aws_session_token,
- config=s3_config,
- **kwargs,
+ self.s3_bucket_name = s3_bucket_name
+ self.s3_region_name = s3_region_name
+ self.s3_api_version = s3_api_version
+ self.s3_use_ssl = s3_use_ssl
+ self.s3_verify = s3_verify
+ self.s3_endpoint_url = s3_endpoint_url
+ self.s3_aws_access_key_id = s3_aws_access_key_id
+ self.s3_aws_secret_access_key = s3_aws_secret_access_key
+ self.s3_aws_session_token = s3_aws_session_token
+ self.s3_config = s3_config
+ self.init_kwargs = kwargs
+
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
+
+ verbose_logger.debug(
+ f"s3 flush interval: {s3_flush_interval}, s3 batch size: {s3_batch_size}"
)
+ # Call CustomLogger's __init__
+ CustomBatchLogger.__init__(
+ self,
+ flush_lock=self.flush_lock,
+ flush_interval=s3_flush_interval,
+ batch_size=s3_batch_size,
+ )
+ self.log_queue: List[s3BatchLoggingElement] = []
+
+ # Call BaseAWSLLM's __init__
+ BaseAWSLLM.__init__(self)
+
except Exception as e:
print_verbose(f"Got exception on init s3 client {str(e)}")
raise e
- async def _async_log_event(
- self, kwargs, response_obj, start_time, end_time, print_verbose
- ):
- self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
-
- def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
verbose_logger.debug(
f"s3 Logging - Enters logging function for model {kwargs}"
)
- # construct payload to send to s3
- # follows the same params as langfuse.py
- litellm_params = kwargs.get("litellm_params", {})
- metadata = (
- litellm_params.get("metadata", {}) or {}
- ) # if litellm_params['metadata'] == None
-
- # Clean Metadata before logging - never log raw metadata
- # the raw metadata can contain circular references which leads to infinite recursion
- # we clean out all extra litellm metadata params before logging
- clean_metadata = {}
- if isinstance(metadata, dict):
- for key, value in metadata.items():
- # clean litellm metadata before logging
- if key in [
- "headers",
- "endpoint",
- "caching_groups",
- "previous_models",
- ]:
- continue
- else:
- clean_metadata[key] = value
-
- # Ensure everything in the payload is converted to str
- payload: Optional[StandardLoggingPayload] = kwargs.get(
- "standard_logging_object", None
+ s3_batch_logging_element = self.create_s3_batch_logging_element(
+ start_time=start_time,
+ standard_logging_payload=kwargs.get("standard_logging_object", None),
+ s3_path=self.s3_path,
)
- if payload is None:
- return
+ if s3_batch_logging_element is None:
+ raise ValueError("s3_batch_logging_element is None")
- s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
- s3_object_key = (
- (self.s3_path.rstrip("/") + "/" if self.s3_path else "")
- + start_time.strftime("%Y-%m-%d")
- + "/"
- + s3_file_name
- ) # we need the s3 key to include the time, so we log cache hits too
- s3_object_key += ".json"
-
- s3_object_download_filename = (
- "time-"
- + start_time.strftime("%Y-%m-%dT%H-%M-%S-%f")
- + "_"
- + payload["id"]
- + ".json"
+ verbose_logger.debug(
+ "\ns3 Logger - Logging payload = %s", s3_batch_logging_element
)
- import json
-
- payload_str = json.dumps(payload)
-
- print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
-
- response = self.s3_client.put_object(
- Bucket=self.bucket_name,
- Key=s3_object_key,
- Body=payload_str,
- ContentType="application/json",
- ContentLanguage="en",
- ContentDisposition=f'inline; filename="{s3_object_download_filename}"',
- CacheControl="private, immutable, max-age=31536000, s-maxage=0",
+ self.log_queue.append(s3_batch_logging_element)
+ verbose_logger.debug(
+ "s3 logging: queue length %s, batch size %s",
+ len(self.log_queue),
+ self.batch_size,
)
-
- print_verbose(f"Response from s3:{str(response)}")
-
- print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
- return response
+ if len(self.log_queue) >= self.batch_size:
+ await self.flush_queue()
except Exception as e:
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
pass
+
+ def log_success_event(self, kwargs, response_obj, start_time, end_time):
+ """
+ Synchronous logging function to log to s3
+
+        Does not batch logging requests; uploads immediately to the s3 bucket
+ """
+ try:
+ s3_batch_logging_element = self.create_s3_batch_logging_element(
+ start_time=start_time,
+ standard_logging_payload=kwargs.get("standard_logging_object", None),
+ s3_path=self.s3_path,
+ )
+
+ if s3_batch_logging_element is None:
+ raise ValueError("s3_batch_logging_element is None")
+
+ verbose_logger.debug(
+ "\ns3 Logger - Logging payload = %s", s3_batch_logging_element
+ )
+
+            # upload the element using the sync httpx client
+ self.upload_data_to_s3(s3_batch_logging_element)
+ except Exception as e:
+ verbose_logger.exception(f"s3 Layer Error - {str(e)}")
+ pass
+
+ async def async_upload_data_to_s3(
+ self, batch_logging_element: s3BatchLoggingElement
+ ):
+ try:
+ import hashlib
+
+ import boto3
+ import requests
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+ except ImportError:
+            raise ImportError("Missing boto3 to call s3. Run 'pip install boto3'.")
+ try:
+ credentials: Credentials = self.get_credentials(
+ aws_access_key_id=self.s3_aws_access_key_id,
+ aws_secret_access_key=self.s3_aws_secret_access_key,
+ aws_session_token=self.s3_aws_session_token,
+ aws_region_name=self.s3_region_name,
+ )
+
+ # Prepare the URL
+ url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
+
+ if self.s3_endpoint_url:
+ url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
+
+ # Convert JSON to string
+ json_string = json.dumps(batch_logging_element.payload)
+
+ # Calculate SHA256 hash of the content
+ content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
+
+ # Prepare the request
+ headers = {
+ "Content-Type": "application/json",
+ "x-amz-content-sha256": content_hash,
+ "Content-Language": "en",
+ "Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
+ "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
+ }
+ req = requests.Request("PUT", url, data=json_string, headers=headers)
+ prepped = req.prepare()
+
+ # Sign the request
+ aws_request = AWSRequest(
+ method=prepped.method,
+ url=prepped.url,
+ data=prepped.body,
+ headers=prepped.headers,
+ )
+ SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
+
+ # Prepare the signed headers
+ signed_headers = dict(aws_request.headers.items())
+
+ # Make the request
+ response = await self.async_httpx_client.put(
+ url, data=json_string, headers=signed_headers
+ )
+ response.raise_for_status()
+ except Exception as e:
+ verbose_logger.exception(f"Error uploading to s3: {str(e)}")
+
+ async def async_send_batch(self):
+        """
+        Uploads the payloads in self.log_queue to s3.
+
+        Returns: None
+
+        Raises: Does not raise an exception, will only verbose_logger.exception()
+        """
+ if not self.log_queue:
+ return
+
+ for payload in self.log_queue:
+ asyncio.create_task(self.async_upload_data_to_s3(payload))
+
+ def create_s3_batch_logging_element(
+ self,
+ start_time: datetime,
+ standard_logging_payload: Optional[StandardLoggingPayload],
+ s3_path: Optional[str],
+ ) -> Optional[s3BatchLoggingElement]:
+ """
+ Helper function to create an s3BatchLoggingElement.
+
+ Args:
+ start_time (datetime): The start time of the logging event.
+ standard_logging_payload (Optional[StandardLoggingPayload]): The payload to be logged.
+ s3_path (Optional[str]): The S3 path prefix.
+
+ Returns:
+ Optional[s3BatchLoggingElement]: The created s3BatchLoggingElement, or None if payload is None.
+ """
+ if standard_logging_payload is None:
+ return None
+
+ s3_file_name = (
+ litellm.utils.get_logging_id(start_time, standard_logging_payload) or ""
+ )
+ s3_object_key = (
+ (s3_path.rstrip("/") + "/" if s3_path else "")
+ + start_time.strftime("%Y-%m-%d")
+ + "/"
+ + s3_file_name
+ + ".json"
+ )
+
+ s3_object_download_filename = f"time-{start_time.strftime('%Y-%m-%dT%H-%M-%S-%f')}_{standard_logging_payload['id']}.json"
+
+ return s3BatchLoggingElement(
+ payload=standard_logging_payload, # type: ignore
+ s3_object_key=s3_object_key,
+ s3_object_download_filename=s3_object_download_filename,
+ )
+
+ def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement):
+ try:
+ import hashlib
+
+ import boto3
+ import requests
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+ except ImportError:
+            raise ImportError("Missing boto3 to call s3. Run 'pip install boto3'.")
+ try:
+ credentials: Credentials = self.get_credentials(
+ aws_access_key_id=self.s3_aws_access_key_id,
+ aws_secret_access_key=self.s3_aws_secret_access_key,
+ aws_session_token=self.s3_aws_session_token,
+ aws_region_name=self.s3_region_name,
+ )
+
+ # Prepare the URL
+ url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
+
+ if self.s3_endpoint_url:
+ url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
+
+ # Convert JSON to string
+ json_string = json.dumps(batch_logging_element.payload)
+
+ # Calculate SHA256 hash of the content
+ content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
+
+ # Prepare the request
+ headers = {
+ "Content-Type": "application/json",
+ "x-amz-content-sha256": content_hash,
+ "Content-Language": "en",
+ "Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
+ "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
+ }
+ req = requests.Request("PUT", url, data=json_string, headers=headers)
+ prepped = req.prepare()
+
+ # Sign the request
+ aws_request = AWSRequest(
+ method=prepped.method,
+ url=prepped.url,
+ data=prepped.body,
+ headers=prepped.headers,
+ )
+ SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
+
+ # Prepare the signed headers
+ signed_headers = dict(aws_request.headers.items())
+
+ httpx_client = _get_httpx_client()
+ # Make the request
+ response = httpx_client.put(url, data=json_string, headers=signed_headers)
+ response.raise_for_status()
+ except Exception as e:
+ verbose_logger.exception(f"Error uploading to s3: {str(e)}")
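
With the s3 integration now registered through the custom-logger-compatible path (see the litellm_logging.py changes below), enabling it together with the new batching knobs looks roughly like this; bucket, region, and path values are illustrative:

import litellm

litellm.callbacks = ["s3"]
litellm.s3_callback_params = {
    "s3_bucket_name": "my-litellm-logs",   # example bucket
    "s3_region_name": "us-west-2",
    "s3_path": "prod",
    "s3_flush_interval": 10,   # flush queued payloads every 10 seconds...
    "s3_batch_size": 1000,     # ...or as soon as 1,000 payloads are queued
}

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
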
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index ce97f1c6f..a641be019 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -116,7 +116,6 @@ lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
-s3Logger = None
genericAPILogger = None
clickHouseLogger = None
greenscaleLogger = None
@@ -1346,36 +1345,6 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
- if callback == "s3":
- global s3Logger
- if s3Logger is None:
- s3Logger = S3Logger()
- if self.stream:
- if "complete_streaming_response" in self.model_call_details:
- print_verbose(
- "S3Logger Logger: Got Stream Event - Completed Stream Response"
- )
- s3Logger.log_event(
- kwargs=self.model_call_details,
- response_obj=self.model_call_details[
- "complete_streaming_response"
- ],
- start_time=start_time,
- end_time=end_time,
- print_verbose=print_verbose,
- )
- else:
- print_verbose(
- "S3Logger Logger: Got Stream Event - No complete stream response as yet"
- )
- else:
- s3Logger.log_event(
- kwargs=self.model_call_details,
- response_obj=result,
- start_time=start_time,
- end_time=end_time,
- print_verbose=print_verbose,
- )
if (
callback == "openmeter"
and self.model_call_details.get("litellm_params", {}).get(
@@ -2245,7 +2214,7 @@ def set_callbacks(callback_list, function_id=None):
"""
Globally sets the callback client
"""
- global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
+ global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
try:
for callback in callback_list:
@@ -2319,8 +2288,6 @@ def set_callbacks(callback_list, function_id=None):
dataDogLogger = DataDogLogger()
elif callback == "dynamodb":
dynamoLogger = DyanmoDBLogger()
- elif callback == "s3":
- s3Logger = S3Logger()
elif callback == "wandb":
weightsBiasesLogger = WeightsBiasesLogger()
elif callback == "logfire":
@@ -2357,7 +2324,6 @@ def _init_custom_logger_compatible_class(
llm_router: Optional[
Any
], # expect litellm.Router, but typing errors due to circular import
- premium_user: Optional[bool] = None,
) -> Optional[CustomLogger]:
if logging_integration == "lago":
for callback in _in_memory_loggers:
@@ -2404,17 +2370,9 @@ def _init_custom_logger_compatible_class(
if isinstance(callback, PrometheusLogger):
return callback # type: ignore
- if premium_user:
- _prometheus_logger = PrometheusLogger()
- _in_memory_loggers.append(_prometheus_logger)
- return _prometheus_logger # type: ignore
- elif premium_user is False:
- verbose_logger.warning(
- f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
- )
- return None
- else:
- return None
+ _prometheus_logger = PrometheusLogger()
+ _in_memory_loggers.append(_prometheus_logger)
+ return _prometheus_logger # type: ignore
elif logging_integration == "datadog":
for callback in _in_memory_loggers:
if isinstance(callback, DataDogLogger):
@@ -2423,6 +2381,14 @@ def _init_custom_logger_compatible_class(
_datadog_logger = DataDogLogger()
_in_memory_loggers.append(_datadog_logger)
return _datadog_logger # type: ignore
+ elif logging_integration == "s3":
+ for callback in _in_memory_loggers:
+ if isinstance(callback, S3Logger):
+ return callback # type: ignore
+
+ _s3_logger = S3Logger()
+ _in_memory_loggers.append(_s3_logger)
+ return _s3_logger # type: ignore
elif logging_integration == "gcs_bucket":
for callback in _in_memory_loggers:
if isinstance(callback, GCSBucketLogger):
@@ -2589,6 +2555,10 @@ def get_custom_logger_compatible_class(
for callback in _in_memory_loggers:
if isinstance(callback, PrometheusLogger):
return callback
+ elif logging_integration == "s3":
+ for callback in _in_memory_loggers:
+ if isinstance(callback, S3Logger):
+ return callback
elif logging_integration == "datadog":
for callback in _in_memory_loggers:
if isinstance(callback, DataDogLogger):
diff --git a/litellm/litellm_core_utils/realtime_streaming.py b/litellm/litellm_core_utils/realtime_streaming.py
new file mode 100644
index 000000000..922f90e36
--- /dev/null
+++ b/litellm/litellm_core_utils/realtime_streaming.py
@@ -0,0 +1,112 @@
+"""
+async with websockets.connect( # type: ignore
+ url,
+ extra_headers={
+ "api-key": api_key, # type: ignore
+ },
+ ) as backend_ws:
+ forward_task = asyncio.create_task(
+ forward_messages(websocket, backend_ws)
+ )
+
+ try:
+ while True:
+ message = await websocket.receive_text()
+ await backend_ws.send(message)
+ except websockets.exceptions.ConnectionClosed: # type: ignore
+ forward_task.cancel()
+ finally:
+ if not forward_task.done():
+ forward_task.cancel()
+ try:
+ await forward_task
+ except asyncio.CancelledError:
+ pass
+"""
+
+import asyncio
+import concurrent.futures
+import traceback
+from asyncio import Task
+from typing import Any, Dict, List, Optional, Union
+
+from .litellm_logging import Logging as LiteLLMLogging
+
+# Create a thread pool with a maximum of 10 threads
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
+
+
+class RealTimeStreaming:
+ def __init__(
+ self,
+ websocket: Any,
+ backend_ws: Any,
+ logging_obj: Optional[LiteLLMLogging] = None,
+ ):
+ self.websocket = websocket
+ self.backend_ws = backend_ws
+ self.logging_obj = logging_obj
+ self.messages: List = []
+ self.input_message: Dict = {}
+
+ def store_message(self, message: Union[str, bytes]):
+ """Store message in list"""
+ self.messages.append(message)
+
+ def store_input(self, message: dict):
+ """Store input message"""
+ self.input_message = message
+ if self.logging_obj:
+ self.logging_obj.pre_call(input=message, api_key="")
+
+ async def log_messages(self):
+ """Log messages in list"""
+ if self.logging_obj:
+ ## ASYNC LOGGING
+ # Create an event loop for the new thread
+ asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
+ ## SYNC LOGGING
+            executor.submit(self.logging_obj.success_handler, self.messages)
+
+ async def backend_to_client_send_messages(self):
+ import websockets
+
+ try:
+ while True:
+ message = await self.backend_ws.recv()
+ await self.websocket.send_text(message)
+
+ ## LOGGING
+ self.store_message(message)
+ except websockets.exceptions.ConnectionClosed: # type: ignore
+ pass
+ except Exception:
+ pass
+ finally:
+ await self.log_messages()
+
+    async def client_ack_messages(self):
+        import websockets
+
+        try:
+            while True:
+                message = await self.websocket.receive_text()
+                ## LOGGING
+                self.store_input(message=message)
+                ## FORWARD TO BACKEND
+                await self.backend_ws.send(message)
+        except websockets.exceptions.ConnectionClosed:  # type: ignore
+            pass
+
+    async def bidirectional_forward(self):
+        import websockets
+
+        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
+        try:
+            await self.client_ack_messages()
+        except websockets.exceptions.ConnectionClosed:  # type: ignore
+            forward_task.cancel()
+ finally:
+ if not forward_task.done():
+ forward_task.cancel()
+ try:
+ await forward_task
+ except asyncio.CancelledError:
+ pass
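
A rough sketch of how a realtime handler is expected to wire this class up (the Azure and OpenAI handler changes below do exactly this); the URL, header name, and logging object are placeholders:

import websockets

from litellm.litellm_core_utils.realtime_streaming import RealTimeStreaming


async def proxy_realtime(client_ws, url: str, api_key: str, logging_obj=None):
    # Open the backend websocket, then let RealTimeStreaming forward traffic in
    # both directions and log the collected messages once the stream closes.
    async with websockets.connect(url, extra_headers={"api-key": api_key}) as backend_ws:
        realtime_streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
        await realtime_streaming.bidirectional_forward()
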
diff --git a/litellm/llms/AzureOpenAI/realtime/handler.py b/litellm/llms/AzureOpenAI/realtime/handler.py
index 7d58ee78f..bf45c53fb 100644
--- a/litellm/llms/AzureOpenAI/realtime/handler.py
+++ b/litellm/llms/AzureOpenAI/realtime/handler.py
@@ -7,6 +7,8 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
import asyncio
from typing import Any, Optional
+from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ..azure import AzureChatCompletion
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
@@ -44,6 +46,7 @@ class AzureOpenAIRealtime(AzureChatCompletion):
api_version: Optional[str] = None,
azure_ad_token: Optional[str] = None,
client: Optional[Any] = None,
+ logging_obj: Optional[LiteLLMLogging] = None,
timeout: Optional[float] = None,
):
import websockets
@@ -62,23 +65,10 @@ class AzureOpenAIRealtime(AzureChatCompletion):
"api-key": api_key, # type: ignore
},
) as backend_ws:
- forward_task = asyncio.create_task(
- forward_messages(websocket, backend_ws)
+ realtime_streaming = RealTimeStreaming(
+ websocket, backend_ws, logging_obj
)
-
- try:
- while True:
- message = await websocket.receive_text()
- await backend_ws.send(message)
- except websockets.exceptions.ConnectionClosed: # type: ignore
- forward_task.cancel()
- finally:
- if not forward_task.done():
- forward_task.cancel()
- try:
- await forward_task
- except asyncio.CancelledError:
- pass
+ await realtime_streaming.bidirectional_forward()
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
await websocket.close(code=e.status_code, reason=str(e))
diff --git a/litellm/llms/OpenAI/realtime/handler.py b/litellm/llms/OpenAI/realtime/handler.py
index 08e5fa0b9..a790b1800 100644
--- a/litellm/llms/OpenAI/realtime/handler.py
+++ b/litellm/llms/OpenAI/realtime/handler.py
@@ -7,20 +7,11 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
import asyncio
from typing import Any, Optional
+from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ..openai import OpenAIChatCompletion
-async def forward_messages(client_ws: Any, backend_ws: Any):
- import websockets
-
- try:
- while True:
- message = await backend_ws.recv()
- await client_ws.send_text(message)
- except websockets.exceptions.ConnectionClosed: # type: ignore
- pass
-
-
class OpenAIRealtime(OpenAIChatCompletion):
def _construct_url(self, api_base: str, model: str) -> str:
"""
@@ -35,6 +26,7 @@ class OpenAIRealtime(OpenAIChatCompletion):
self,
model: str,
websocket: Any,
+ logging_obj: LiteLLMLogging,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
client: Optional[Any] = None,
@@ -57,25 +49,26 @@ class OpenAIRealtime(OpenAIChatCompletion):
"OpenAI-Beta": "realtime=v1",
},
) as backend_ws:
- forward_task = asyncio.create_task(
- forward_messages(websocket, backend_ws)
+ realtime_streaming = RealTimeStreaming(
+ websocket, backend_ws, logging_obj
)
-
- try:
- while True:
- message = await websocket.receive_text()
- await backend_ws.send(message)
- except websockets.exceptions.ConnectionClosed: # type: ignore
- forward_task.cancel()
- finally:
- if not forward_task.done():
- forward_task.cancel()
- try:
- await forward_task
- except asyncio.CancelledError:
- pass
+ await realtime_streaming.bidirectional_forward()
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
await websocket.close(code=e.status_code, reason=str(e))
except Exception as e:
- await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
+ try:
+ await websocket.close(
+ code=1011, reason=f"Internal server error: {str(e)}"
+ )
+ except RuntimeError as close_error:
+ if "already completed" in str(close_error) or "websocket.close" in str(
+ close_error
+ ):
+ # The WebSocket is already closed or the response is completed, so we can ignore this error
+ pass
+ else:
+ # If it's a different RuntimeError, we might want to log it or handle it differently
+ raise Exception(
+ f"Unexpected error while closing WebSocket: {close_error}"
+ )
diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py
index 89798eef5..de09df19c 100644
--- a/litellm/llms/custom_llm.py
+++ b/litellm/llms/custom_llm.py
@@ -153,6 +153,8 @@ class CustomLLM(BaseLLM):
self,
model: str,
prompt: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
@@ -166,6 +168,12 @@ class CustomLLM(BaseLLM):
model: str,
prompt: str,
model_response: ImageResponse,
+ api_key: Optional[
+ str
+ ], # dynamically set api_key - https://docs.litellm.ai/docs/set_keys#api_key
+ api_base: Optional[
+ str
+ ], # dynamically set api_base - https://docs.litellm.ai/docs/set_keys#api_base
optional_params: dict,
logging_obj: Any,
timeout: Optional[Union[float, httpx.Timeout]] = None,
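
Because `image_generation` / `aimage_generation` now receive `api_key` and `api_base`, custom providers can honor per-request credentials. A hypothetical subclass sketch; the class name and body are illustrative, and the import paths assume the current `CustomLLM` / `ImageResponse` locations:

from typing import Optional, Union

import httpx

from litellm import CustomLLM
from litellm.types.utils import ImageResponse


class MyImageProvider(CustomLLM):
    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        api_key: Optional[str],
        api_base: Optional[str],
        optional_params: dict,
        logging_obj,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client=None,
    ) -> ImageResponse:
        # api_key / api_base are forwarded from the image_generation() call,
        # e.g. litellm.image_generation(..., api_key=..., api_base=...).
        ...
        return model_response
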
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py
deleted file mode 100644
index 0f98c34c2..000000000
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py
+++ /dev/null
@@ -1,464 +0,0 @@
-# What is this?
-## Handler file for calling claude-3 on vertex ai
-import copy
-import json
-import os
-import time
-import types
-import uuid
-from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple, Union
-
-import httpx # type: ignore
-import requests # type: ignore
-
-import litellm
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import (
- AnthropicMessagesTool,
- AnthropicMessagesToolChoice,
-)
-from litellm.types.llms.openai import (
- ChatCompletionToolParam,
- ChatCompletionToolParamFunctionChunk,
-)
-from litellm.types.utils import ResponseFormatChunk
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
-
-from ..prompt_templates.factory import (
- construct_tool_use_system_prompt,
- contains_tag,
- custom_prompt,
- extract_between_tags,
- parse_xml_params,
- prompt_factory,
- response_schema_prompt,
-)
-
-
-class VertexAIError(Exception):
- def __init__(self, status_code, message):
- self.status_code = status_code
- self.message = message
- self.request = httpx.Request(
- method="POST", url=" https://cloud.google.com/vertex-ai/"
- )
- self.response = httpx.Response(status_code=status_code, request=self.request)
- super().__init__(
- self.message
- ) # Call the base class constructor with the parameters it needs
-
-
-class VertexAIAnthropicConfig:
- """
- Reference:https://docs.anthropic.com/claude/reference/messages_post
-
- Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
-
- - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
-
- The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
-
- - `max_tokens` Required (integer) max tokens,
- - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- - `temperature` Optional (float) The amount of randomness injected into the response
- - `top_p` Optional (float) Use nucleus sampling.
- - `top_k` Optional (int) Only sample from the top K options for each subsequent token
- - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
-
- Note: Please make sure to modify the default parameters as required for your use case.
- """
-
- max_tokens: Optional[int] = (
- 4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
- )
- system: Optional[str] = None
- temperature: Optional[float] = None
- top_p: Optional[float] = None
- top_k: Optional[int] = None
- stop_sequences: Optional[List[str]] = None
-
- def __init__(
- self,
- max_tokens: Optional[int] = None,
- anthropic_version: Optional[str] = None,
- ) -> None:
- locals_ = locals()
- for key, value in locals_.items():
- if key == "max_tokens" and value is None:
- value = self.max_tokens
- if key != "self" and value is not None:
- setattr(self.__class__, key, value)
-
- @classmethod
- def get_config(cls):
- return {
- k: v
- for k, v in cls.__dict__.items()
- if not k.startswith("__")
- and not isinstance(
- v,
- (
- types.FunctionType,
- types.BuiltinFunctionType,
- classmethod,
- staticmethod,
- ),
- )
- and v is not None
- }
-
- def get_supported_openai_params(self):
- return [
- "max_tokens",
- "max_completion_tokens",
- "tools",
- "tool_choice",
- "stream",
- "stop",
- "temperature",
- "top_p",
- "response_format",
- ]
-
- def map_openai_params(self, non_default_params: dict, optional_params: dict):
- for param, value in non_default_params.items():
- if param == "max_tokens" or param == "max_completion_tokens":
- optional_params["max_tokens"] = value
- if param == "tools":
- optional_params["tools"] = value
- if param == "tool_choice":
- _tool_choice: Optional[AnthropicMessagesToolChoice] = None
- if value == "auto":
- _tool_choice = {"type": "auto"}
- elif value == "required":
- _tool_choice = {"type": "any"}
- elif isinstance(value, dict):
- _tool_choice = {"type": "tool", "name": value["function"]["name"]}
-
- if _tool_choice is not None:
- optional_params["tool_choice"] = _tool_choice
- if param == "stream":
- optional_params["stream"] = value
- if param == "stop":
- optional_params["stop_sequences"] = value
- if param == "temperature":
- optional_params["temperature"] = value
- if param == "top_p":
- optional_params["top_p"] = value
- if param == "response_format" and isinstance(value, dict):
- json_schema: Optional[dict] = None
- if "response_schema" in value:
- json_schema = value["response_schema"]
- elif "json_schema" in value:
- json_schema = value["json_schema"]["schema"]
- """
- When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- - You usually want to provide a single tool
- - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
- """
- _tool_choice = None
- _tool_choice = {"name": "json_tool_call", "type": "tool"}
-
- _tool = AnthropicMessagesTool(
- name="json_tool_call",
- input_schema={
- "type": "object",
- "properties": {"values": json_schema}, # type: ignore
- },
- )
-
- optional_params["tools"] = [_tool]
- optional_params["tool_choice"] = _tool_choice
- optional_params["json_mode"] = True
-
- return optional_params
-
-
-"""
-- Run client init
-- Support async completion, streaming
-"""
-
-
-def refresh_auth(
- credentials,
-) -> str: # used when user passes in credentials as json string
- from google.auth.transport.requests import Request # type: ignore[import-untyped]
-
- if credentials.token is None:
- credentials.refresh(Request())
-
- if not credentials.token:
- raise RuntimeError("Could not resolve API token from the credentials")
-
- return credentials.token
-
-
-def get_vertex_client(
- client: Any,
- vertex_project: Optional[str],
- vertex_location: Optional[str],
- vertex_credentials: Optional[str],
-) -> Tuple[Any, Optional[str]]:
- args = locals()
- from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
- VertexLLM,
- )
-
- try:
- from anthropic import AnthropicVertex
- except Exception:
- raise VertexAIError(
- status_code=400,
- message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
- )
-
- access_token: Optional[str] = None
-
- if client is None:
- _credentials, cred_project_id = VertexLLM().load_auth(
- credentials=vertex_credentials, project_id=vertex_project
- )
-
- vertex_ai_client = AnthropicVertex(
- project_id=vertex_project or cred_project_id,
- region=vertex_location or "us-central1",
- access_token=_credentials.token,
- )
- access_token = _credentials.token
- else:
- vertex_ai_client = client
- access_token = client.access_token
-
- return vertex_ai_client, access_token
-
-
-def create_vertex_anthropic_url(
- vertex_location: str, vertex_project: str, model: str, stream: bool
-) -> str:
- if stream is True:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
- else:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
-
-
-def completion(
- model: str,
- messages: list,
- model_response: ModelResponse,
- print_verbose: Callable,
- encoding,
- logging_obj,
- optional_params: dict,
- custom_prompt_dict: dict,
- headers: Optional[dict],
- timeout: Union[float, httpx.Timeout],
- vertex_project=None,
- vertex_location=None,
- vertex_credentials=None,
- litellm_params=None,
- logger_fn=None,
- acompletion: bool = False,
- client=None,
-):
- try:
- import vertexai
- except Exception:
- raise VertexAIError(
- status_code=400,
- message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
- )
-
- from anthropic import AnthropicVertex
-
- from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
- from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
- VertexLLM,
- )
-
- if not (
- hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
- ):
- raise VertexAIError(
- status_code=400,
- message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
- )
- try:
-
- vertex_httpx_logic = VertexLLM()
-
- access_token, project_id = vertex_httpx_logic._ensure_access_token(
- credentials=vertex_credentials,
- project_id=vertex_project,
- custom_llm_provider="vertex_ai",
- )
-
- anthropic_chat_completions = AnthropicChatCompletion()
-
- ## Load Config
- config = litellm.VertexAIAnthropicConfig.get_config()
- for k, v in config.items():
- if k not in optional_params:
- optional_params[k] = v
-
- ## CONSTRUCT API BASE
- stream = optional_params.get("stream", False)
-
- api_base = create_vertex_anthropic_url(
- vertex_location=vertex_location or "us-central1",
- vertex_project=vertex_project or project_id,
- model=model,
- stream=stream,
- )
-
- if headers is not None:
- vertex_headers = headers
- else:
- vertex_headers = {}
-
- vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
-
- optional_params.update(
- {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
- )
-
- return anthropic_chat_completions.completion(
- model=model,
- messages=messages,
- api_base=api_base,
- custom_prompt_dict=custom_prompt_dict,
- model_response=model_response,
- print_verbose=print_verbose,
- encoding=encoding,
- api_key=access_token,
- logging_obj=logging_obj,
- optional_params=optional_params,
- acompletion=acompletion,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- headers=vertex_headers,
- client=client,
- timeout=timeout,
- )
-
- except Exception as e:
- raise VertexAIError(status_code=500, message=str(e))
-
-
-async def async_completion(
- model: str,
- messages: list,
- data: dict,
- model_response: ModelResponse,
- print_verbose: Callable,
- logging_obj,
- vertex_project=None,
- vertex_location=None,
- optional_params=None,
- client=None,
- access_token=None,
-):
- from anthropic import AsyncAnthropicVertex
-
- if client is None:
- vertex_ai_client = AsyncAnthropicVertex(
- project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
- )
- else:
- vertex_ai_client = client
-
- ## LOGGING
- logging_obj.pre_call(
- input=messages,
- api_key=None,
- additional_args={
- "complete_input_dict": optional_params,
- },
- )
- message = await vertex_ai_client.messages.create(**data) # type: ignore
- text_content = message.content[0].text
- ## TOOL CALLING - OUTPUT PARSE
- if text_content is not None and contains_tag("invoke", text_content):
- function_name = extract_between_tags("tool_name", text_content)[0]
- function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
- function_arguments_str = f"{function_arguments_str}"
- function_arguments = parse_xml_params(function_arguments_str)
- _message = litellm.Message(
- tool_calls=[
- {
- "id": f"call_{uuid.uuid4()}",
- "type": "function",
- "function": {
- "name": function_name,
- "arguments": json.dumps(function_arguments),
- },
- }
- ],
- content=None,
- )
- model_response.choices[0].message = _message # type: ignore
- else:
- model_response.choices[0].message.content = text_content # type: ignore
- model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
-
- ## CALCULATING USAGE
- prompt_tokens = message.usage.input_tokens
- completion_tokens = message.usage.output_tokens
-
- model_response.created = int(time.time())
- model_response.model = model
- usage = Usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- )
- setattr(model_response, "usage", usage)
- return model_response
-
-
-async def async_streaming(
- model: str,
- messages: list,
- data: dict,
- model_response: ModelResponse,
- print_verbose: Callable,
- logging_obj,
- vertex_project=None,
- vertex_location=None,
- optional_params=None,
- client=None,
- access_token=None,
-):
- from anthropic import AsyncAnthropicVertex
-
- if client is None:
- vertex_ai_client = AsyncAnthropicVertex(
- project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
- )
- else:
- vertex_ai_client = client
-
- ## LOGGING
- logging_obj.pre_call(
- input=messages,
- api_key=None,
- additional_args={
- "complete_input_dict": optional_params,
- },
- )
- response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore
- logging_obj.post_call(input=messages, api_key=None, original_response=response)
-
- streamwrapper = CustomStreamWrapper(
- completion_stream=response,
- model=model,
- custom_llm_provider="vertex_ai",
- logging_obj=logging_obj,
- )
-
- return streamwrapper
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py
new file mode 100644
index 000000000..44b8af279
--- /dev/null
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py
@@ -0,0 +1,179 @@
+# What is this?
+## Handler file for calling claude-3 on vertex ai
+import copy
+import json
+import os
+import time
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx # type: ignore
+import requests # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import (
+ AnthropicMessagesTool,
+ AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+ ChatCompletionToolParam,
+ ChatCompletionToolParamFunctionChunk,
+)
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from ....prompt_templates.factory import (
+ construct_tool_use_system_prompt,
+ contains_tag,
+ custom_prompt,
+ extract_between_tags,
+ parse_xml_params,
+ prompt_factory,
+ response_schema_prompt,
+)
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url=" https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class VertexAIAnthropicConfig:
+ """
+ Reference:https://docs.anthropic.com/claude/reference/messages_post
+
+ Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
+
+ - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
+ - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
+
+ The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
+
+ - `max_tokens` Required (integer) max tokens,
+ - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
+ - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
+ - `temperature` Optional (float) The amount of randomness injected into the response
+ - `top_p` Optional (float) Use nucleus sampling.
+ - `top_k` Optional (int) Only sample from the top K options for each subsequent token
+ - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
+
+ Note: Please make sure to modify the default parameters as required for your use case.
+ """
+
+ max_tokens: Optional[int] = (
+ 4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
+ )
+ system: Optional[str] = None
+ temperature: Optional[float] = None
+ top_p: Optional[float] = None
+ top_k: Optional[int] = None
+ stop_sequences: Optional[List[str]] = None
+
+ def __init__(
+ self,
+ max_tokens: Optional[int] = None,
+ anthropic_version: Optional[str] = None,
+ ) -> None:
+ locals_ = locals()
+ for key, value in locals_.items():
+ if key == "max_tokens" and value is None:
+ value = self.max_tokens
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return [
+ "max_tokens",
+ "max_completion_tokens",
+ "tools",
+ "tool_choice",
+ "stream",
+ "stop",
+ "temperature",
+ "top_p",
+ "response_format",
+ ]
+
+ def map_openai_params(self, non_default_params: dict, optional_params: dict):
+ for param, value in non_default_params.items():
+ if param == "max_tokens" or param == "max_completion_tokens":
+ optional_params["max_tokens"] = value
+ if param == "tools":
+ optional_params["tools"] = value
+ if param == "tool_choice":
+ _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+ if value == "auto":
+ _tool_choice = {"type": "auto"}
+ elif value == "required":
+ _tool_choice = {"type": "any"}
+ elif isinstance(value, dict):
+ _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+ if _tool_choice is not None:
+ optional_params["tool_choice"] = _tool_choice
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "stop":
+ optional_params["stop_sequences"] = value
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "response_format" and isinstance(value, dict):
+ json_schema: Optional[dict] = None
+ if "response_schema" in value:
+ json_schema = value["response_schema"]
+ elif "json_schema" in value:
+ json_schema = value["json_schema"]["schema"]
+ """
+ When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+ - You usually want to provide a single tool
+ - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+ - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
+ """
+ _tool_choice = None
+ _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+ _tool = AnthropicMessagesTool(
+ name="json_tool_call",
+ input_schema={
+ "type": "object",
+ "properties": {"values": json_schema}, # type: ignore
+ },
+ )
+
+ optional_params["tools"] = [_tool]
+ optional_params["tool_choice"] = _tool_choice
+ optional_params["json_mode"] = True
+
+ return optional_params
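
A small usage sketch of the relocated config; the top-level import comes from the `litellm/__init__.py` change above, and the argument values are illustrative:

from litellm import VertexAIAnthropicConfig

config = VertexAIAnthropicConfig()

optional_params = config.map_openai_params(
    non_default_params={
        "max_tokens": 256,
        "temperature": 0.2,
        "tool_choice": "required",  # mapped to Anthropic's {"type": "any"}
    },
    optional_params={},
)
# -> {"max_tokens": 256, "temperature": 0.2, "tool_choice": {"type": "any"}}
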
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
index da54f6e1b..e8443e6f6 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py
@@ -9,13 +9,14 @@ import httpx # type: ignore
import litellm
from litellm.utils import ModelResponse
-from ...base import BaseLLM
+from ..vertex_llm_base import VertexBase
class VertexPartnerProvider(str, Enum):
mistralai = "mistralai"
llama = "llama"
ai21 = "ai21"
+ claude = "claude"
class VertexAIError(Exception):
@@ -31,31 +32,38 @@ class VertexAIError(Exception):
) # Call the base class constructor with the parameters it needs
-class VertexAIPartnerModels(BaseLLM):
+def create_vertex_url(
+ vertex_location: str,
+ vertex_project: str,
+ partner: VertexPartnerProvider,
+ stream: Optional[bool],
+ model: str,
+ api_base: Optional[str] = None,
+) -> str:
+ """Return the base url for the vertex partner models"""
+ if partner == VertexPartnerProvider.llama:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+ elif partner == VertexPartnerProvider.mistralai:
+ if stream:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
+ else:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+ elif partner == VertexPartnerProvider.ai21:
+ if stream:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+ else:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
+ elif partner == VertexPartnerProvider.claude:
+ if stream:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+ else:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
+
+
+class VertexAIPartnerModels(VertexBase):
def __init__(self) -> None:
pass
- def create_vertex_url(
- self,
- vertex_location: str,
- vertex_project: str,
- partner: VertexPartnerProvider,
- stream: Optional[bool],
- model: str,
- ) -> str:
- if partner == VertexPartnerProvider.llama:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
- elif partner == VertexPartnerProvider.mistralai:
- if stream:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
- else:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
- elif partner == VertexPartnerProvider.ai21:
- if stream:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
- else:
- return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
-
def completion(
self,
model: str,
@@ -64,6 +72,7 @@ class VertexAIPartnerModels(BaseLLM):
print_verbose: Callable,
encoding,
logging_obj,
+ api_base: Optional[str],
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
@@ -80,6 +89,7 @@ class VertexAIPartnerModels(BaseLLM):
import vertexai
from google.cloud import aiplatform
+ from litellm.llms.anthropic.chat import AnthropicChatCompletion
from litellm.llms.databricks.chat import DatabricksChatCompletion
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.llms.text_completion_codestral import CodestralTextCompletion
@@ -112,6 +122,7 @@ class VertexAIPartnerModels(BaseLLM):
openai_like_chat_completions = DatabricksChatCompletion()
codestral_fim_completions = CodestralTextCompletion()
+ anthropic_chat_completions = AnthropicChatCompletion()
## CONSTRUCT API BASE
stream: bool = optional_params.get("stream", False) or False
@@ -126,8 +137,10 @@ class VertexAIPartnerModels(BaseLLM):
elif "jamba" in model:
partner = VertexPartnerProvider.ai21
optional_params["custom_endpoint"] = True
+ elif "claude" in model:
+ partner = VertexPartnerProvider.claude
- api_base = self.create_vertex_url(
+ default_api_base = create_vertex_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
partner=partner, # type: ignore
@@ -135,6 +148,21 @@ class VertexAIPartnerModels(BaseLLM):
model=model,
)
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=stream,
+ auth_header=None,
+ url=default_api_base,
+ )
+
model = model.split("@")[0]
if "codestral" in model and litellm_params.get("text_completion") is True:
@@ -158,6 +186,35 @@ class VertexAIPartnerModels(BaseLLM):
timeout=timeout,
encoding=encoding,
)
+ elif "claude" in model:
+ if headers is None:
+ headers = {}
+ headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+ optional_params.update(
+ {
+ "anthropic_version": "vertex-2023-10-16",
+ "is_vertex_request": True,
+ }
+ )
+ return anthropic_chat_completions.completion(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ acompletion=acompletion,
+ custom_prompt_dict=litellm.custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ encoding=encoding, # for calculating input/output tokens
+ api_key=access_token,
+ logging_obj=logging_obj,
+ headers=headers,
+ timeout=timeout,
+ client=client,
+ )
return openai_like_chat_completions.completion(
model=model,
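
End-to-end, Claude on Vertex now flows through this partner-models handler instead of the deleted `vertex_ai_anthropic.py` (see the `model.startswith("claude")` routing added in `litellm/main.py` below). A hedged sketch; the model ID, project, and region are examples:

import litellm

response = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",  # example Vertex model ID
    messages=[{"role": "user", "content": "hi"}],
    vertex_project="my-gcp-project",
    vertex_location="us-east5",
)
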
diff --git a/litellm/main.py b/litellm/main.py
index b53db67f4..e5a22b3f5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -117,10 +117,7 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
-from .llms.vertex_ai_and_google_ai_studio import (
- vertex_ai_anthropic,
- vertex_ai_non_gemini,
-)
+from .llms.vertex_ai_and_google_ai_studio import vertex_ai_non_gemini
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@@ -747,6 +744,11 @@ def completion( # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
fallbacks = kwargs.get("fallbacks", None)
headers = kwargs.get("headers", None) or extra_headers
+ if headers is None:
+ headers = {}
+
+ if extra_headers is not None:
+ headers.update(extra_headers)
num_retries = kwargs.get(
"num_retries", None
) ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
@@ -964,7 +966,6 @@ def completion( # type: ignore
max_retries=max_retries,
logprobs=logprobs,
top_logprobs=top_logprobs,
- extra_headers=extra_headers,
api_version=api_version,
parallel_tool_calls=parallel_tool_calls,
messages=messages,
@@ -1067,6 +1068,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
+ if extra_headers is not None:
+ optional_params["extra_headers"] = extra_headers
+
if (
litellm.enable_preview_features
and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
@@ -1166,6 +1170,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
+ if extra_headers is not None:
+ optional_params["extra_headers"] = extra_headers
+
## LOAD CONFIG - if set
config = litellm.AzureOpenAIConfig.get_config()
for k, v in config.items():
@@ -1223,6 +1230,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
+ if extra_headers is not None:
+ optional_params["extra_headers"] = extra_headers
+
## LOAD CONFIG - if set
config = litellm.AzureAIStudioConfig.get_config()
for k, v in config.items():
@@ -1304,6 +1314,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
+ if extra_headers is not None:
+ optional_params["extra_headers"] = extra_headers
+
## LOAD CONFIG - if set
config = litellm.OpenAITextCompletionConfig.get_config()
for k, v in config.items():
@@ -1466,6 +1479,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
+ if extra_headers is not None:
+ optional_params["extra_headers"] = extra_headers
+
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
for k, v in config.items():
@@ -2228,31 +2244,12 @@ def completion( # type: ignore
)
new_params = deepcopy(optional_params)
- if "claude-3" in model:
- model_response = vertex_ai_anthropic.completion(
- model=model,
- messages=messages,
- model_response=model_response,
- print_verbose=print_verbose,
- optional_params=new_params,
- litellm_params=litellm_params,
- logger_fn=logger_fn,
- encoding=encoding,
- vertex_location=vertex_ai_location,
- vertex_project=vertex_ai_project,
- vertex_credentials=vertex_credentials,
- logging_obj=logging,
- acompletion=acompletion,
- headers=headers,
- custom_prompt_dict=custom_prompt_dict,
- timeout=timeout,
- client=client,
- )
- elif (
+ if (
model.startswith("meta/")
or model.startswith("mistral")
or model.startswith("codestral")
or model.startswith("jamba")
+ or model.startswith("claude")
):
model_response = vertex_partner_models_chat_completion.completion(
model=model,
@@ -2263,6 +2260,7 @@ def completion( # type: ignore
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
+ api_base=api_base,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
@@ -4848,6 +4846,8 @@ def image_generation(
model_response = custom_handler.aimage_generation( # type: ignore
model=model,
prompt=prompt,
+ api_key=api_key,
+ api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
@@ -4863,6 +4863,8 @@ def image_generation(
model_response = custom_handler.image_generation(
model=model,
prompt=prompt,
+ api_key=api_key,
+ api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
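The `extra_headers` changes above merge caller-supplied headers into `headers` and, on the OpenAI/Azure paths, forward them via `optional_params["extra_headers"]` rather than a separate keyword argument. A rough sketch of the caller-facing behavior (the header name is illustrative):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi"}],
    extra_headers={"X-Request-Source": "docs-example"},  # merged into the outbound request headers
    mock_response="ok",  # keeps the sketch offline
)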
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 536976ce4..dda23b7b4 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,12 +1,15 @@
model_list:
- - model_name: gpt-4o-mini
- litellm_params:
- model: azure/my-gpt-4o-mini
- api_key: os.environ/AZURE_API_KEY
- api_base: os.environ/AZURE_API_BASE
+ # - model_name: openai-gpt-4o-realtime-audio
+ # litellm_params:
+ # model: azure/gpt-4o-realtime-preview
+ # api_key: os.environ/AZURE_SWEDEN_API_KEY
+ # api_base: os.environ/AZURE_SWEDEN_API_BASE
+
+ - model_name: openai-gpt-4o-realtime-audio
+ litellm_params:
+ model: openai/gpt-4o-realtime-preview-2024-10-01
+ api_key: os.environ/OPENAI_API_KEY
+ api_base: http://localhost:8080
litellm_settings:
- turn_off_message_logging: true
- cache: True
- cache_params:
- type: local
\ No newline at end of file
+ success_callback: ["langfuse"]
\ No newline at end of file
diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py
index 59a4edb90..538a1eee1 100644
--- a/litellm/proxy/common_utils/callback_utils.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -241,6 +241,8 @@ def initialize_callbacks_on_proxy(
litellm.callbacks = imported_list # type: ignore
if "prometheus" in value:
+ if premium_user is not True:
+ raise Exception(CommonProxyErrors.not_premium_user.value)
from litellm.proxy.proxy_server import app
verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index d611aa87b..11ccc8561 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,16 +1,17 @@
model_list:
- model_name: db-openai-endpoint
litellm_params:
- model: openai/gpt-5
+ model: openai/gpt-4
api_key: fake-key
- api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
- - model_name: db-openai-endpoint
- litellm_params:
- model: openai/gpt-5
- api_key: fake-key
- api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/
+ api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
- callbacks: ["prometheus"]
+ success_callback: ["s3"]
+ turn_off_message_logging: true
+ s3_callback_params:
+ s3_bucket_name: load-testing-oct # AWS Bucket Name for S3
+ s3_region_name: us-west-2 # AWS Region Name for S3
+ s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/ to pass environment variables. This is the AWS Access Key ID for S3
+ s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d75a087da..8407b4e86 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1692,6 +1692,10 @@ class ProxyConfig:
else:
litellm.success_callback.append(callback)
if "prometheus" in callback:
+ if not premium_user:
+ raise Exception(
+ CommonProxyErrors.not_premium_user.value
+ )
verbose_proxy_logger.debug(
"Starting Prometheus Metrics on /metrics"
)
@@ -5433,12 +5437,11 @@ async def moderations(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
- verbose_proxy_logger.error(
+ verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
str(e)
)
)
- verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 666118039..25de59c8e 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -408,7 +408,6 @@ class ProxyLogging:
callback,
internal_usage_cache=self.internal_usage_cache.dual_cache,
llm_router=llm_router,
- premium_user=self.premium_user,
)
if callback is None:
continue
diff --git a/litellm/realtime_api/main.py b/litellm/realtime_api/main.py
index 5e512795f..9088a491c 100644
--- a/litellm/realtime_api/main.py
+++ b/litellm/realtime_api/main.py
@@ -8,13 +8,16 @@ from litellm import get_llm_provider
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
+from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
from ..llms.OpenAI.realtime.handler import OpenAIRealtime
+from ..utils import client as wrapper_client
azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
+@wrapper_client
async def _arealtime(
model: str,
websocket: Any, # fastapi websocket
@@ -31,6 +34,12 @@ async def _arealtime(
For PROXY use only.
"""
+ litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
+ litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+ proxy_server_request = kwargs.get("proxy_server_request", None)
+ model_info = kwargs.get("model_info", None)
+ metadata = kwargs.get("metadata", {})
+ user = kwargs.get("user", None)
litellm_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
@@ -39,6 +48,21 @@ async def _arealtime(
api_key=api_key,
)
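+ # seed the logging object with call metadata (call id, proxy request, model info,
+ # metadata, user) so success/failure callbacks can log realtime sessions like any other call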
+ litellm_logging_obj.update_environment_variables(
+ model=model,
+ user=user,
+ optional_params={},
+ litellm_params={
+ "litellm_call_id": litellm_call_id,
+ "proxy_server_request": proxy_server_request,
+ "model_info": model_info,
+ "metadata": metadata,
+ "preset_cache_key": None,
+ "stream_response": {},
+ },
+ custom_llm_provider=_custom_llm_provider,
+ )
+
if _custom_llm_provider == "azure":
api_base = (
dynamic_api_base
@@ -63,6 +87,7 @@ async def _arealtime(
azure_ad_token=None,
client=None,
timeout=timeout,
+ logging_obj=litellm_logging_obj,
)
elif _custom_llm_provider == "openai":
api_base = (
@@ -82,6 +107,7 @@ async def _arealtime(
await openai_realtime.async_realtime(
model=model,
websocket=websocket,
+ logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
client=None,
diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py
index f735c965a..6c4ec8e6b 100644
--- a/litellm/router_utils/pattern_match_deployments.py
+++ b/litellm/router_utils/pattern_match_deployments.py
@@ -60,7 +60,7 @@ class PatternMatchRouter:
regex = re.escape(regex).replace(r"\.\*", ".*")
return f"^{regex}$"
- def route(self, request: str) -> Optional[List[Dict]]:
+ def route(self, request: Optional[str]) -> Optional[List[Dict]]:
"""
Route a requested model to the corresponding llm deployments based on the regex pattern
@@ -69,17 +69,20 @@ class PatternMatchRouter:
if no pattern is found, return None
Args:
- request: str
+ request: Optional[str]
Returns:
Optional[List[Deployment]]: llm deployments
"""
try:
+ if request is None:
+ return None
for pattern, llm_deployments in self.patterns.items():
if re.match(pattern, request):
return llm_deployments
except Exception as e:
- verbose_router_logger.error(f"Error in PatternMatchRouter.route: {str(e)}")
+ verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
+
return None # No matching pattern found
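Since `route` now accepts an optional request string, callers no longer need to guard against a missing model name themselves. A small behavioral sketch, assuming the router can be constructed with no arguments:

from litellm.router_utils.pattern_match_deployments import PatternMatchRouter

router = PatternMatchRouter()

assert router.route(None) is None              # no model name -> no match, no exception
assert router.route("openai/gpt-4o") is None   # nothing registered yet, so still no match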
diff --git a/litellm/types/integrations/s3.py b/litellm/types/integrations/s3.py
new file mode 100644
index 000000000..d66e2c59d
--- /dev/null
+++ b/litellm/types/integrations/s3.py
@@ -0,0 +1,14 @@
+from typing import Dict
+
+from pydantic import BaseModel
+
+
+class s3BatchLoggingElement(BaseModel):
+ """
+ Type of element stored in self.log_queue in S3Logger
+ """
+
+ payload: Dict
+ s3_object_key: str
+ s3_object_download_filename: str
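Each queued S3 upload is represented by one of these elements; a minimal construction sketch (key and filename values are illustrative):

from litellm.types.integrations.s3 import s3BatchLoggingElement

element = s3BatchLoggingElement(
    payload={"id": "chatcmpl-123", "model": "gpt-3.5-turbo"},  # the logged event as a dict
    s3_object_key="2024-10-01/chatcmpl-123.json",              # object key within the bucket
    s3_object_download_filename="chatcmpl-123.json",           # filename attached to the object on download
)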
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 6019b0e6f..c3118b453 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -129,6 +129,7 @@ class CallTypes(Enum):
speech = "speech"
rerank = "rerank"
arerank = "arerank"
+ arealtime = "_arealtime"
class PassthroughCallTypes(Enum):
diff --git a/litellm/utils.py b/litellm/utils.py
index 5afeab58e..9efde1be7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -197,7 +197,6 @@ lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
-s3Logger = None
genericAPILogger = None
clickHouseLogger = None
greenscaleLogger = None
@@ -1410,6 +1409,8 @@ def client(original_function):
)
else:
return result
+ elif call_type == CallTypes.arealtime.value:
+ return result
# ADD HIDDEN PARAMS - additional call metadata
if hasattr(result, "_hidden_params"):
@@ -1799,8 +1800,9 @@ def calculate_tiles_needed(
total_tiles = tiles_across * tiles_down
return total_tiles
+
def get_image_type(image_data: bytes) -> Union[str, None]:
- """ take an image (really only the first ~100 bytes max are needed)
+ """take an image (really only the first ~100 bytes max are needed)
and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
allow deprecation of imghdr in 3.13"""
@@ -4336,16 +4338,18 @@ def get_api_base(
_optional_params.vertex_location is not None
and _optional_params.vertex_project is not None
):
- from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
- create_vertex_anthropic_url,
+ from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
+ VertexPartnerProvider,
+ create_vertex_url,
)
if "claude" in model:
- _api_base = create_vertex_anthropic_url(
+ _api_base = create_vertex_url(
vertex_location=_optional_params.vertex_location,
vertex_project=_optional_params.vertex_project,
model=model,
stream=stream,
+ partner=VertexPartnerProvider.claude,
)
else:
@@ -4442,19 +4446,7 @@ def get_supported_openai_params(
elif custom_llm_provider == "volcengine":
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "groq":
- return [
- "temperature",
- "max_tokens",
- "top_p",
- "stream",
- "stop",
- "tools",
- "tool_choice",
- "response_format",
- "seed",
- "extra_headers",
- "extra_body",
- ]
+ return litellm.GroqChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "hosted_vllm":
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "deepseek":
@@ -4599,6 +4591,8 @@ def get_supported_openai_params(
return (
litellm.MistralTextCompletionConfig().get_supported_openai_params()
)
+ if model.startswith("claude"):
+ return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
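With the Groq config reuse and the Vertex Anthropic branch above, supported-parameter lookups are delegated to the provider config classes. A quick sketch (model ids are illustrative):

import litellm

groq_params = litellm.get_supported_openai_params(
    model="llama3-8b-8192", custom_llm_provider="groq"
)
vertex_claude_params = litellm.get_supported_openai_params(
    model="claude-3-sonnet@20240229", custom_llm_provider="vertex_ai"
)
print(groq_params)           # now sourced from GroqChatConfig
print(vertex_claude_params)  # sourced from VertexAIAnthropicConfig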
diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py
index 9ec148e47..3ce57380c 100644
--- a/tests/llm_translation/test_max_completion_tokens.py
+++ b/tests/llm_translation/test_max_completion_tokens.py
@@ -296,7 +296,7 @@ def test_all_model_configs():
optional_params={},
) == {"max_tokens": 10}
- from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+ from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
VertexAIAnthropicConfig,
)
diff --git a/tests/local_testing/test_amazing_s3_logs.py b/tests/local_testing/test_amazing_s3_logs.py
index c3e8a61db..5459647c1 100644
--- a/tests/local_testing/test_amazing_s3_logs.py
+++ b/tests/local_testing/test_amazing_s3_logs.py
@@ -12,7 +12,70 @@ import litellm
litellm.num_retries = 3
import time, random
+from litellm._logging import verbose_logger
+import logging
import pytest
+import boto3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_s3_logging(sync_mode):
+ verbose_logger.setLevel(level=logging.DEBUG)
+ litellm.success_callback = ["s3"]
+ litellm.s3_callback_params = {
+ "s3_bucket_name": "load-testing-oct",
+ "s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
+ "s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
+ "s3_region_name": "us-west-2",
+ }
+ litellm.set_verbose = True
+
+ if sync_mode is True:
+ response = litellm.completion(
+ model="gpt-3.5-turbo",
+ messages=[{"role": "user", "content": "This is a test"}],
+ mock_response="It's simple to use and easy to get started",
+ )
+ else:
+ response = await litellm.acompletion(
+ model="gpt-3.5-turbo",
+ messages=[{"role": "user", "content": "This is a test"}],
+ mock_response="It's simple to use and easy to get started",
+ )
+ print(f"response: {response}")
+
+ await asyncio.sleep(12)
+
+ total_objects, all_s3_keys = list_all_s3_objects("load-testing-oct")
+
+ # assert that at least one key has response.id in it
+ assert any(response.id in key for key in all_s3_keys)
+ s3 = boto3.client("s3")
+ # delete all objects
+ for key in all_s3_keys:
+ s3.delete_object(Bucket="load-testing-oct", Key=key)
+
+
+def list_all_s3_objects(bucket_name):
+ s3 = boto3.client("s3")
+
+ all_s3_keys = []
+
+ paginator = s3.get_paginator("list_objects_v2")
+ total_objects = 0
+
+ for page in paginator.paginate(Bucket=bucket_name):
+ if "Contents" in page:
+ total_objects += len(page["Contents"])
+ all_s3_keys.extend([obj["Key"] for obj in page["Contents"]])
+
+ print(f"Total number of objects in {bucket_name}: {total_objects}")
+ print(all_s3_keys)
+ return total_objects, all_s3_keys
+
+
+# list_all_s3_objects("load-testing-oct")  # debugging helper; keep commented so import/collection does not hit S3
@pytest.mark.skip(reason="AWS Suspended Account")
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 604c2eed8..d8c448976 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -1616,9 +1616,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
)
-@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
+@pytest.mark.parametrize(
+ "model", ["gemini-1.5-flash", "claude-3-sonnet@20240229"]
+)
@pytest.mark.asyncio
-async def test_gemini_pro_httpx_custom_api_base(provider):
+async def test_gemini_pro_httpx_custom_api_base(model):
load_vertex_ai_credentials()
litellm.set_verbose = True
messages = [
@@ -1634,7 +1636,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
with patch.object(client, "post", new=MagicMock()) as mock_call:
try:
response = completion(
- model="vertex_ai_beta/gemini-1.5-flash",
+ model="vertex_ai/{}".format(model),
messages=messages,
response_format={"type": "json_object"},
client=client,
@@ -1647,8 +1649,17 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
mock_call.assert_called_once()
- assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
- assert "hello" in mock_call.call_args.kwargs["headers"]
+ print(f"mock_call.call_args: {mock_call.call_args}")
+ print(f"mock_call.call_args.kwargs: {mock_call.call_args.kwargs}")
+ if "url" in mock_call.call_args.kwargs:
+ assert (
+ "my-custom-api-base:generateContent"
+ == mock_call.call_args.kwargs["url"]
+ )
+ else:
+ assert "my-custom-api-base:rawPredict" == mock_call.call_args[0][0]
+ if "headers" in mock_call.call_args.kwargs:
+ assert "hello" in mock_call.call_args.kwargs["headers"]
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
diff --git a/tests/local_testing/test_custom_llm.py b/tests/local_testing/test_custom_llm.py
index c9edde4a8..29daef481 100644
--- a/tests/local_testing/test_custom_llm.py
+++ b/tests/local_testing/test_custom_llm.py
@@ -28,7 +28,6 @@ from typing import (
Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
-
import httpx
from dotenv import load_dotenv
@@ -226,6 +225,8 @@ class MyCustomLLM(CustomLLM):
self,
model: str,
prompt: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
@@ -242,6 +243,8 @@ class MyCustomLLM(CustomLLM):
self,
model: str,
prompt: str,
+ api_key: Optional[str],
+ api_base: Optional[str],
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
@@ -362,3 +365,31 @@ async def test_simple_image_generation_async():
)
print(resp)
+
+
+@pytest.mark.asyncio
+async def test_image_generation_async_with_api_key_and_api_base():
+ my_custom_llm = MyCustomLLM()
+ litellm.custom_provider_map = [
+ {"provider": "custom_llm", "custom_handler": my_custom_llm}
+ ]
+
+ with patch.object(
+ my_custom_llm, "aimage_generation", new=AsyncMock()
+ ) as mock_client:
+ try:
+ resp = await litellm.aimage_generation(
+ model="custom_llm/my-fake-model",
+ prompt="Hello world",
+ api_key="my-api-key",
+ api_base="my-api-base",
+ )
+
+ print(resp)
+ except Exception as e:
+ print(e)
+
+ mock_client.assert_awaited_once()
+
+ assert mock_client.call_args.kwargs["api_key"] == "my-api-key"
+ assert mock_client.call_args.kwargs["api_base"] == "my-api-base"