LiteLLM Minor Fixes & Improvements (10/10/2024) (#6158)

* refactor(vertex_ai_partner_models/anthropic): refactor anthropic to use partner model logic

* fix(vertex_ai/): support passing custom api base to partner models

Fixes https://github.com/BerriAI/litellm/issues/4317
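
As a rough sketch of what this enables (the model id, project, and proxy URL below are illustrative, not taken from this PR):

```python
import litellm

# Vertex AI partner models (e.g. Anthropic Claude on Vertex) can now be routed
# through a custom api_base instead of the default
# https://{location}-aiplatform.googleapis.com endpoint.
response = litellm.completion(
    model="vertex_ai/claude-3-5-sonnet@20240620",   # illustrative partner model id
    messages=[{"role": "user", "content": "Hello!"}],
    vertex_project="my-gcp-project",                # assumed project
    vertex_location="us-east5",                     # assumed region
    api_base="https://my-vertex-proxy.internal",    # custom base url, now honored for partner models
)
print(response.choices[0].message.content)
```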

* fix(proxy_server.py): Fix prometheus premium user check logic

* docs(prometheus.md): update quick start docs

* fix(custom_llm.py): support passing dynamic api key + api base
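
A minimal sketch of the updated hook (import paths and trailing parameters are abridged/assumed; the key point is that `api_key` / `api_base` are now forwarded per request):

```python
from typing import Any, Optional

from litellm import CustomLLM
from litellm.types.utils import ImageResponse


class MyImageLLM(CustomLLM):
    def image_generation(
        self,
        model: str,
        prompt: str,
        api_key: Optional[str],    # dynamically passed, no longer env-only
        api_base: Optional[str],   # dynamically passed, no longer env-only
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        **kwargs,
    ) -> ImageResponse:
        # call the backend at `api_base` with `api_key` here (omitted)
        return model_response
```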

* fix(realtime_api/main.py): Add request/response logging for realtime api endpoints

Closes https://github.com/BerriAI/litellm/issues/6081

* feat(openai/realtime): add openai realtime api logging

Closes https://github.com/BerriAI/litellm/issues/6081

* fix(realtime_streaming.py): fix linting errors

* fix(realtime_streaming.py): fix linting errors

* fix: fix linting errors

* fix pattern match router

* Add literalai in the sidebar observability category (#6163)

* fix: add literalai in the sidebar

* fix: typo

* update (#6160)

* Feat: Add Langtrace integration (#5341)

* Feat: Add Langtrace integration

* add langtrace service name

* fix timestamps for traces

* add tests

* Discard Callback + use existing otel logger

* cleanup

* remove print statements

* remove callback

* add docs

* docs

* add logging docs

* format logging

* remove emoji and add litellm proxy example

* format logging

* format `logging.md`

* add langtrace docs to logging.md

* sync conflict

* docs fix
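
For reference, enabling the new integration should look roughly like this (the env var name follows Langtrace's convention and is an assumption here, not shown in this diff):

```python
import os

import litellm

os.environ["LANGTRACE_API_KEY"] = "ltr-..."  # assumed env var name

# "langtrace" is registered as a custom-logger-compatible callback
litellm.callbacks = ["langtrace"]

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)
```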

* (perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165)

* fix move s3 to use customLogger

* add basic s3 logging test

* add s3 to custom logger compatible

* use batch logger for s3

* s3 set flush interval and batch size

* fix s3 logging

* add notes on s3 logging

* fix s3 logging

* add basic s3 logging test

* fix s3 type errors

* add test for sync logging on s3

* fix: fix to debug log
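
A hedged sketch of how the new batching knobs are wired up (parameter names come from the s3.py changes below; bucket/region values are placeholders):

```python
import litellm

# s3 logs are now buffered in memory and flushed as a batch, either every
# s3_flush_interval seconds or once s3_batch_size payloads are queued.
litellm.s3_callback_params = {
    "s3_bucket_name": "my-log-bucket",   # placeholder
    "s3_region_name": "us-west-2",       # placeholder
    "s3_path": "litellm-logs",
    "s3_flush_interval": 10,             # DEFAULT_S3_FLUSH_INTERVAL_SECONDS
    "s3_batch_size": 1000,               # DEFAULT_S3_BATCH_SIZE
}
litellm.callbacks = ["s3"]  # "s3" is now custom-logger compatible

litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
)
```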

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
Co-authored-by: Ali Waleed <ali@scale3labs.com>
Author: Krish Dholakia
Date: 2024-10-11 23:04:36 -07:00 (committed by GitHub)
Commit: 11f9df923a
Parent: 9db4ccca9f
28 changed files with 966 additions and 760 deletions


@@ -27,8 +27,7 @@ model_list:
     litellm_params:
       model: gpt-3.5-turbo
 litellm_settings:
-  success_callback: ["prometheus"]
-  failure_callback: ["prometheus"]
+  callbacks: ["prometheus"]
 ```
 Start the proxy


@@ -53,6 +53,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "arize",
     "langtrace",
     "gcs_bucket",
+    "s3",
     "opik",
 ]
 _known_custom_logger_compatible_callbacks: List = list(
@@ -931,7 +932,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import
 vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
     VertexAIAnthropicConfig,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (


@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
         self,
         flush_lock: Optional[asyncio.Lock] = None,
         batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
         **kwargs,
     ) -> None:
         """
@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS  # 10 seconds
+        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
         self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
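
For context, a downstream logger can now pick its own flush cadence; a minimal sketch (class name and sink are made up, import path assumed):

```python
import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchLogger(CustomBatchLogger):
    def __init__(self):
        self.flush_lock = asyncio.Lock()
        # flush_interval is now a constructor argument instead of always
        # falling back to DEFAULT_FLUSH_INTERVAL_SECONDS
        super().__init__(flush_lock=self.flush_lock, flush_interval=5, batch_size=100)
        asyncio.create_task(self.periodic_flush())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.log_queue.append(kwargs.get("standard_logging_object"))
        if len(self.log_queue) >= self.batch_size:
            await self.flush_queue()

    async def async_send_batch(self):
        # ship self.log_queue to your sink here (omitted)
        ...
```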


@@ -235,6 +235,14 @@ class LangFuseLogger:
         ):
             input = prompt
             output = response_obj.results
+        elif (
+            kwargs.get("call_type") is not None
+            and kwargs.get("call_type") == "_arealtime"
+            and response_obj is not None
+            and isinstance(response_obj, list)
+        ):
+            input = kwargs.get("input")
+            output = response_obj
         elif (
             kwargs.get("call_type") is not None
             and kwargs.get("call_type") == "pass_through_endpoint"


@ -1,43 +1,67 @@
#### What this does #### """
# On success + failure, log events to Supabase s3 Bucket Logging Integration
import datetime async_log_success_event: Processes the event, stores it in memory for 10 seconds or until MAX_BATCH_SIZE and then flushes to s3
import os
import subprocess NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
import sys NOTE 2: We create a httpx client with a concurrent limit of 1 to upload to s3. Files should get uploaded BUT they should not impact latency of LLM calling logic
import traceback """
import uuid
from typing import Optional import asyncio
import json
from datetime import datetime
from typing import Dict, List, Optional
import litellm import litellm
from litellm._logging import print_verbose, verbose_logger from litellm._logging import print_verbose, verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.s3 import s3BatchLoggingElement
from litellm.types.utils import StandardLoggingPayload from litellm.types.utils import StandardLoggingPayload
from .custom_batch_logger import CustomBatchLogger
class S3Logger: # Default Flush interval and batch size for s3
# Flush to s3 every 10 seconds OR every 1K requests in memory
DEFAULT_S3_FLUSH_INTERVAL_SECONDS = 10
DEFAULT_S3_BATCH_SIZE = 1000
class S3Logger(CustomBatchLogger, BaseAWSLLM):
# Class variables or attributes # Class variables or attributes
def __init__( def __init__(
self, self,
s3_bucket_name=None, s3_bucket_name: Optional[str] = None,
s3_path=None, s3_path: Optional[str] = None,
s3_region_name=None, s3_region_name: Optional[str] = None,
s3_api_version=None, s3_api_version: Optional[str] = None,
s3_use_ssl=True, s3_use_ssl: bool = True,
s3_verify=None, s3_verify: Optional[bool] = None,
s3_endpoint_url=None, s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id=None, s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key=None, s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token=None, s3_aws_session_token: Optional[str] = None,
s3_flush_interval: Optional[int] = DEFAULT_S3_FLUSH_INTERVAL_SECONDS,
s3_batch_size: Optional[int] = DEFAULT_S3_BATCH_SIZE,
s3_config=None, s3_config=None,
**kwargs, **kwargs,
): ):
import boto3
try: try:
verbose_logger.debug( verbose_logger.debug(
f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}" f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
) )
# IMPORTANT: We use a concurrent limit of 1 to upload to s3
# Files should get uploaded BUT they should not impact latency of LLM calling logic
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback,
params={"concurrent_limit": 1},
)
if litellm.s3_callback_params is not None: if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME # read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items(): for key, value in litellm.s3_callback_params.items():
@ -63,107 +87,282 @@ class S3Logger:
s3_path = litellm.s3_callback_params.get("s3_path") s3_path = litellm.s3_callback_params.get("s3_path")
# done reading litellm.s3_callback_params # done reading litellm.s3_callback_params
s3_flush_interval = litellm.s3_callback_params.get(
"s3_flush_interval", DEFAULT_S3_FLUSH_INTERVAL_SECONDS
)
s3_batch_size = litellm.s3_callback_params.get(
"s3_batch_size", DEFAULT_S3_BATCH_SIZE
)
self.bucket_name = s3_bucket_name self.bucket_name = s3_bucket_name
self.s3_path = s3_path self.s3_path = s3_path
verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}") verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
# Create an S3 client with custom endpoint URL self.s3_bucket_name = s3_bucket_name
self.s3_client = boto3.client( self.s3_region_name = s3_region_name
"s3", self.s3_api_version = s3_api_version
region_name=s3_region_name, self.s3_use_ssl = s3_use_ssl
endpoint_url=s3_endpoint_url, self.s3_verify = s3_verify
api_version=s3_api_version, self.s3_endpoint_url = s3_endpoint_url
use_ssl=s3_use_ssl, self.s3_aws_access_key_id = s3_aws_access_key_id
verify=s3_verify, self.s3_aws_secret_access_key = s3_aws_secret_access_key
aws_access_key_id=s3_aws_access_key_id, self.s3_aws_session_token = s3_aws_session_token
aws_secret_access_key=s3_aws_secret_access_key, self.s3_config = s3_config
aws_session_token=s3_aws_session_token, self.init_kwargs = kwargs
config=s3_config,
**kwargs, asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
verbose_logger.debug(
f"s3 flush interval: {s3_flush_interval}, s3 batch size: {s3_batch_size}"
) )
# Call CustomLogger's __init__
CustomBatchLogger.__init__(
self,
flush_lock=self.flush_lock,
flush_interval=s3_flush_interval,
batch_size=s3_batch_size,
)
self.log_queue: List[s3BatchLoggingElement] = []
# Call BaseAWSLLM's __init__
BaseAWSLLM.__init__(self)
except Exception as e: except Exception as e:
print_verbose(f"Got exception on init s3 client {str(e)}") print_verbose(f"Got exception on init s3 client {str(e)}")
raise e raise e
async def _async_log_event( async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
self, kwargs, response_obj, start_time, end_time, print_verbose
):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
try: try:
verbose_logger.debug( verbose_logger.debug(
f"s3 Logging - Enters logging function for model {kwargs}" f"s3 Logging - Enters logging function for model {kwargs}"
) )
# construct payload to send to s3 s3_batch_logging_element = self.create_s3_batch_logging_element(
# follows the same params as langfuse.py start_time=start_time,
litellm_params = kwargs.get("litellm_params", {}) standard_logging_payload=kwargs.get("standard_logging_object", None),
metadata = ( s3_path=self.s3_path,
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
# Clean Metadata before logging - never log raw metadata
# the raw metadata can contain circular references which leads to infinite recursion
# we clean out all extra litellm metadata params before logging
clean_metadata = {}
if isinstance(metadata, dict):
for key, value in metadata.items():
# clean litellm metadata before logging
if key in [
"headers",
"endpoint",
"caching_groups",
"previous_models",
]:
continue
else:
clean_metadata[key] = value
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
) )
if payload is None: if s3_batch_logging_element is None:
return raise ValueError("s3_batch_logging_element is None")
s3_file_name = litellm.utils.get_logging_id(start_time, payload) or "" verbose_logger.debug(
s3_object_key = ( "\ns3 Logger - Logging payload = %s", s3_batch_logging_element
(self.s3_path.rstrip("/") + "/" if self.s3_path else "")
+ start_time.strftime("%Y-%m-%d")
+ "/"
+ s3_file_name
) # we need the s3 key to include the time, so we log cache hits too
s3_object_key += ".json"
s3_object_download_filename = (
"time-"
+ start_time.strftime("%Y-%m-%dT%H-%M-%S-%f")
+ "_"
+ payload["id"]
+ ".json"
) )
import json self.log_queue.append(s3_batch_logging_element)
verbose_logger.debug(
payload_str = json.dumps(payload) "s3 logging: queue length %s, batch size %s",
len(self.log_queue),
print_verbose(f"\ns3 Logger - Logging payload = {payload_str}") self.batch_size,
response = self.s3_client.put_object(
Bucket=self.bucket_name,
Key=s3_object_key,
Body=payload_str,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{s3_object_download_filename}"',
CacheControl="private, immutable, max-age=31536000, s-maxage=0",
) )
if len(self.log_queue) >= self.batch_size:
print_verbose(f"Response from s3:{str(response)}") await self.flush_queue()
print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
return response
except Exception as e: except Exception as e:
verbose_logger.exception(f"s3 Layer Error - {str(e)}") verbose_logger.exception(f"s3 Layer Error - {str(e)}")
pass pass
def log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Synchronous logging function to log to s3
Does not batch logging requests, instantly logs on s3 Bucket
"""
try:
s3_batch_logging_element = self.create_s3_batch_logging_element(
start_time=start_time,
standard_logging_payload=kwargs.get("standard_logging_object", None),
s3_path=self.s3_path,
)
if s3_batch_logging_element is None:
raise ValueError("s3_batch_logging_element is None")
verbose_logger.debug(
"\ns3 Logger - Logging payload = %s", s3_batch_logging_element
)
# log the element sync httpx client
self.upload_data_to_s3(s3_batch_logging_element)
except Exception as e:
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
pass
async def async_upload_data_to_s3(
self, batch_logging_element: s3BatchLoggingElement
):
try:
import hashlib
import boto3
import requests
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
try:
credentials: Credentials = self.get_credentials(
aws_access_key_id=self.s3_aws_access_key_id,
aws_secret_access_key=self.s3_aws_secret_access_key,
aws_session_token=self.s3_aws_session_token,
aws_region_name=self.s3_region_name,
)
# Prepare the URL
url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
if self.s3_endpoint_url:
url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
# Convert JSON to string
json_string = json.dumps(batch_logging_element.payload)
# Calculate SHA256 hash of the content
content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
# Prepare the request
headers = {
"Content-Type": "application/json",
"x-amz-content-sha256": content_hash,
"Content-Language": "en",
"Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
"Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
}
req = requests.Request("PUT", url, data=json_string, headers=headers)
prepped = req.prepare()
# Sign the request
aws_request = AWSRequest(
method=prepped.method,
url=prepped.url,
data=prepped.body,
headers=prepped.headers,
)
SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
# Prepare the signed headers
signed_headers = dict(aws_request.headers.items())
# Make the request
response = await self.async_httpx_client.put(
url, data=json_string, headers=signed_headers
)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error uploading to s3: {str(e)}")
async def async_send_batch(self):
"""
Sends runs from self.log_queue
Returns: None
Raises: Does not raise an exception, will only verbose_logger.exception()
"""
if not self.log_queue:
return
for payload in self.log_queue:
asyncio.create_task(self.async_upload_data_to_s3(payload))
def create_s3_batch_logging_element(
self,
start_time: datetime,
standard_logging_payload: Optional[StandardLoggingPayload],
s3_path: Optional[str],
) -> Optional[s3BatchLoggingElement]:
"""
Helper function to create an s3BatchLoggingElement.
Args:
start_time (datetime): The start time of the logging event.
standard_logging_payload (Optional[StandardLoggingPayload]): The payload to be logged.
s3_path (Optional[str]): The S3 path prefix.
Returns:
Optional[s3BatchLoggingElement]: The created s3BatchLoggingElement, or None if payload is None.
"""
if standard_logging_payload is None:
return None
s3_file_name = (
litellm.utils.get_logging_id(start_time, standard_logging_payload) or ""
)
s3_object_key = (
(s3_path.rstrip("/") + "/" if s3_path else "")
+ start_time.strftime("%Y-%m-%d")
+ "/"
+ s3_file_name
+ ".json"
)
s3_object_download_filename = f"time-{start_time.strftime('%Y-%m-%dT%H-%M-%S-%f')}_{standard_logging_payload['id']}.json"
return s3BatchLoggingElement(
payload=standard_logging_payload, # type: ignore
s3_object_key=s3_object_key,
s3_object_download_filename=s3_object_download_filename,
)
def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement):
try:
import hashlib
import boto3
import requests
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
try:
credentials: Credentials = self.get_credentials(
aws_access_key_id=self.s3_aws_access_key_id,
aws_secret_access_key=self.s3_aws_secret_access_key,
aws_session_token=self.s3_aws_session_token,
aws_region_name=self.s3_region_name,
)
# Prepare the URL
url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
if self.s3_endpoint_url:
url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
# Convert JSON to string
json_string = json.dumps(batch_logging_element.payload)
# Calculate SHA256 hash of the content
content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
# Prepare the request
headers = {
"Content-Type": "application/json",
"x-amz-content-sha256": content_hash,
"Content-Language": "en",
"Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
"Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
}
req = requests.Request("PUT", url, data=json_string, headers=headers)
prepped = req.prepare()
# Sign the request
aws_request = AWSRequest(
method=prepped.method,
url=prepped.url,
data=prepped.body,
headers=prepped.headers,
)
SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
# Prepare the signed headers
signed_headers = dict(aws_request.headers.items())
httpx_client = _get_httpx_client()
# Make the request
response = httpx_client.put(url, data=json_string, headers=signed_headers)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error uploading to s3: {str(e)}")


@@ -116,7 +116,6 @@ lagoLogger = None
 dataDogLogger = None
 prometheusLogger = None
 dynamoLogger = None
-s3Logger = None
 genericAPILogger = None
 clickHouseLogger = None
 greenscaleLogger = None
@@ -1346,36 +1345,6 @@ class Logging:
                         user_id=kwargs.get("user", None),
                         print_verbose=print_verbose,
                     )
-                if callback == "s3":
-                    global s3Logger
-                    if s3Logger is None:
-                        s3Logger = S3Logger()
-                    if self.stream:
-                        if "complete_streaming_response" in self.model_call_details:
-                            print_verbose(
-                                "S3Logger Logger: Got Stream Event - Completed Stream Response"
-                            )
-                            s3Logger.log_event(
-                                kwargs=self.model_call_details,
-                                response_obj=self.model_call_details[
-                                    "complete_streaming_response"
-                                ],
-                                start_time=start_time,
-                                end_time=end_time,
-                                print_verbose=print_verbose,
-                            )
-                        else:
-                            print_verbose(
-                                "S3Logger Logger: Got Stream Event - No complete stream response as yet"
-                            )
-                    else:
-                        s3Logger.log_event(
-                            kwargs=self.model_call_details,
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            print_verbose=print_verbose,
-                        )
                 if (
                     callback == "openmeter"
                     and self.model_call_details.get("litellm_params", {}).get(
@@ -2245,7 +2214,7 @@ def set_callbacks(callback_list, function_id=None):
     """
     Globally sets the callback client
     """
-    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
+    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger

     try:
         for callback in callback_list:
@@ -2319,8 +2288,6 @@ def set_callbacks(callback_list, function_id=None):
                 dataDogLogger = DataDogLogger()
             elif callback == "dynamodb":
                 dynamoLogger = DyanmoDBLogger()
-            elif callback == "s3":
-                s3Logger = S3Logger()
             elif callback == "wandb":
                 weightsBiasesLogger = WeightsBiasesLogger()
             elif callback == "logfire":
@@ -2357,7 +2324,6 @@ def _init_custom_logger_compatible_class(
     llm_router: Optional[
         Any
     ],  # expect litellm.Router, but typing errors due to circular import
-    premium_user: Optional[bool] = None,
 ) -> Optional[CustomLogger]:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
@@ -2404,17 +2370,9 @@ def _init_custom_logger_compatible_class(
             if isinstance(callback, PrometheusLogger):
                 return callback  # type: ignore
-        if premium_user:
-            _prometheus_logger = PrometheusLogger()
-            _in_memory_loggers.append(_prometheus_logger)
-            return _prometheus_logger  # type: ignore
-        elif premium_user is False:
-            verbose_logger.warning(
-                f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
-            )
-            return None
-        else:
-            return None
+        _prometheus_logger = PrometheusLogger()
+        _in_memory_loggers.append(_prometheus_logger)
+        return _prometheus_logger  # type: ignore
     elif logging_integration == "datadog":
         for callback in _in_memory_loggers:
             if isinstance(callback, DataDogLogger):
@@ -2423,6 +2381,14 @@ def _init_custom_logger_compatible_class(
         _datadog_logger = DataDogLogger()
         _in_memory_loggers.append(_datadog_logger)
         return _datadog_logger  # type: ignore
+    elif logging_integration == "s3":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, S3Logger):
+                return callback  # type: ignore
+        _s3_logger = S3Logger()
+        _in_memory_loggers.append(_s3_logger)
+        return _s3_logger  # type: ignore
     elif logging_integration == "gcs_bucket":
         for callback in _in_memory_loggers:
             if isinstance(callback, GCSBucketLogger):
@@ -2589,6 +2555,10 @@ def get_custom_logger_compatible_class(
         for callback in _in_memory_loggers:
             if isinstance(callback, PrometheusLogger):
                 return callback
+    elif logging_integration == "s3":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, S3Logger):
+                return callback
    elif logging_integration == "datadog":
        for callback in _in_memory_loggers:
            if isinstance(callback, DataDogLogger):


@ -0,0 +1,112 @@
"""
async with websockets.connect( # type: ignore
url,
extra_headers={
"api-key": api_key, # type: ignore
},
) as backend_ws:
forward_task = asyncio.create_task(
forward_messages(websocket, backend_ws)
)
try:
while True:
message = await websocket.receive_text()
await backend_ws.send(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
forward_task.cancel()
finally:
if not forward_task.done():
forward_task.cancel()
try:
await forward_task
except asyncio.CancelledError:
pass
"""
import asyncio
import concurrent.futures
import traceback
from asyncio import Task
from typing import Any, Dict, List, Optional, Union
from .litellm_logging import Logging as LiteLLMLogging
# Create a thread pool with a maximum of 10 threads
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
class RealTimeStreaming:
def __init__(
self,
websocket: Any,
backend_ws: Any,
logging_obj: Optional[LiteLLMLogging] = None,
):
self.websocket = websocket
self.backend_ws = backend_ws
self.logging_obj = logging_obj
self.messages: List = []
self.input_message: Dict = {}
def store_message(self, message: Union[str, bytes]):
"""Store message in list"""
self.messages.append(message)
def store_input(self, message: dict):
"""Store input message"""
self.input_message = message
if self.logging_obj:
self.logging_obj.pre_call(input=message, api_key="")
async def log_messages(self):
"""Log messages in list"""
if self.logging_obj:
## ASYNC LOGGING
# Create an event loop for the new thread
asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
## SYNC LOGGING
executor.submit(self.logging_obj.success_handler(self.messages))
async def backend_to_client_send_messages(self):
import websockets
try:
while True:
message = await self.backend_ws.recv()
await self.websocket.send_text(message)
## LOGGING
self.store_message(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
pass
except Exception:
pass
finally:
await self.log_messages()
async def client_ack_messages(self):
try:
while True:
message = await self.websocket.receive_text()
## LOGGING
self.store_input(message=message)
## FORWARD TO BACKEND
await self.backend_ws.send(message)
except self.websockets.exceptions.ConnectionClosed: # type: ignore
pass
async def bidirectional_forward(self):
forward_task = asyncio.create_task(self.backend_to_client_send_messages())
try:
await self.client_ack_messages()
except self.websockets.exceptions.ConnectionClosed: # type: ignore
forward_task.cancel()
finally:
if not forward_task.done():
forward_task.cancel()
try:
await forward_task
except asyncio.CancelledError:
pass


@@ -7,6 +7,8 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
 import asyncio
 from typing import Any, Optional

+from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
 from ..azure import AzureChatCompletion

 # BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
@@ -44,6 +46,7 @@ class AzureOpenAIRealtime(AzureChatCompletion):
         api_version: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
         client: Optional[Any] = None,
+        logging_obj: Optional[LiteLLMLogging] = None,
         timeout: Optional[float] = None,
     ):
         import websockets
@@ -62,23 +65,10 @@ class AzureOpenAIRealtime(AzureChatCompletion):
                     "api-key": api_key,  # type: ignore
                 },
             ) as backend_ws:
-                forward_task = asyncio.create_task(
-                    forward_messages(websocket, backend_ws)
-                )
-
-                try:
-                    while True:
-                        message = await websocket.receive_text()
-                        await backend_ws.send(message)
-                except websockets.exceptions.ConnectionClosed:  # type: ignore
-                    forward_task.cancel()
-                finally:
-                    if not forward_task.done():
-                        forward_task.cancel()
-                        try:
-                            await forward_task
-                        except asyncio.CancelledError:
-                            pass
+                realtime_streaming = RealTimeStreaming(
+                    websocket, backend_ws, logging_obj
+                )
+                await realtime_streaming.bidirectional_forward()

         except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
             await websocket.close(code=e.status_code, reason=str(e))


@@ -7,20 +7,11 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
 import asyncio
 from typing import Any, Optional

+from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
 from ..openai import OpenAIChatCompletion


-async def forward_messages(client_ws: Any, backend_ws: Any):
-    import websockets
-
-    try:
-        while True:
-            message = await backend_ws.recv()
-            await client_ws.send_text(message)
-    except websockets.exceptions.ConnectionClosed:  # type: ignore
-        pass
-
-
 class OpenAIRealtime(OpenAIChatCompletion):
     def _construct_url(self, api_base: str, model: str) -> str:
         """
@@ -35,6 +26,7 @@ class OpenAIRealtime(OpenAIChatCompletion):
         self,
         model: str,
         websocket: Any,
+        logging_obj: LiteLLMLogging,
         api_base: Optional[str] = None,
         api_key: Optional[str] = None,
         client: Optional[Any] = None,
@@ -57,25 +49,26 @@ class OpenAIRealtime(OpenAIChatCompletion):
                     "OpenAI-Beta": "realtime=v1",
                 },
             ) as backend_ws:
-                forward_task = asyncio.create_task(
-                    forward_messages(websocket, backend_ws)
-                )
-
-                try:
-                    while True:
-                        message = await websocket.receive_text()
-                        await backend_ws.send(message)
-                except websockets.exceptions.ConnectionClosed:  # type: ignore
-                    forward_task.cancel()
-                finally:
-                    if not forward_task.done():
-                        forward_task.cancel()
-                        try:
-                            await forward_task
-                        except asyncio.CancelledError:
-                            pass
+                realtime_streaming = RealTimeStreaming(
+                    websocket, backend_ws, logging_obj
+                )
+                await realtime_streaming.bidirectional_forward()

         except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
             await websocket.close(code=e.status_code, reason=str(e))
         except Exception as e:
-            await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
+            try:
+                await websocket.close(
+                    code=1011, reason=f"Internal server error: {str(e)}"
+                )
+            except RuntimeError as close_error:
+                if "already completed" in str(close_error) or "websocket.close" in str(
+                    close_error
+                ):
+                    # The WebSocket is already closed or the response is completed, so we can ignore this error
+                    pass
+                else:
+                    # If it's a different RuntimeError, we might want to log it or handle it differently
+                    raise Exception(
+                        f"Unexpected error while closing WebSocket: {close_error}"
+                    )


@@ -153,6 +153,8 @@ class CustomLLM(BaseLLM):
         self,
         model: str,
         prompt: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
         model_response: ImageResponse,
         optional_params: dict,
         logging_obj: Any,
@@ -166,6 +168,12 @@ class CustomLLM(BaseLLM):
         model: str,
         prompt: str,
         model_response: ImageResponse,
+        api_key: Optional[
+            str
+        ],  # dynamically set api_key - https://docs.litellm.ai/docs/set_keys#api_key
+        api_base: Optional[
+            str
+        ],  # dynamically set api_base - https://docs.litellm.ai/docs/set_keys#api_base
         optional_params: dict,
         logging_obj: Any,
         timeout: Optional[Union[float, httpx.Timeout]] = None,


@ -1,464 +0,0 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ..prompt_templates.factory import (
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIAnthropicConfig:
"""
Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
_tool_choice = None
_tool_choice = {"name": "json_tool_call", "type": "tool"}
_tool = AnthropicMessagesTool(
name="json_tool_call",
input_schema={
"type": "object",
"properties": {"values": json_schema}, # type: ignore
},
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
return optional_params
"""
- Run client init
- Support async completion, streaming
"""
def refresh_auth(
credentials,
) -> str: # used when user passes in credentials as json string
from google.auth.transport.requests import Request # type: ignore[import-untyped]
if credentials.token is None:
credentials.refresh(Request())
if not credentials.token:
raise RuntimeError("Could not resolve API token from the credentials")
return credentials.token
def get_vertex_client(
client: Any,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
) -> Tuple[Any, Optional[str]]:
args = locals()
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
try:
from anthropic import AnthropicVertex
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
access_token: Optional[str] = None
if client is None:
_credentials, cred_project_id = VertexLLM().load_auth(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project or cred_project_id,
region=vertex_location or "us-central1",
access_token=_credentials.token,
)
access_token = _credentials.token
else:
vertex_ai_client = client
access_token = client.access_token
return vertex_ai_client, access_token
def create_vertex_anthropic_url(
vertex_location: str, vertex_project: str, model: str, stream: bool
) -> str:
if stream is True:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
def completion(
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
):
try:
import vertexai
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
from anthropic import AnthropicVertex
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
if not (
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
):
raise VertexAIError(
status_code=400,
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
anthropic_chat_completions = AnthropicChatCompletion()
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
## CONSTRUCT API BASE
stream = optional_params.get("stream", False)
api_base = create_vertex_anthropic_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
model=model,
stream=stream,
)
if headers is not None:
vertex_headers = headers
else:
vertex_headers = {}
vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
optional_params.update(
{"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
)
return anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=access_token,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=vertex_headers,
client=client,
timeout=timeout,
)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_completion(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
message = await vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
async def async_streaming(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore
logging_obj.post_call(input=messages, api_key=None, original_response=response)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
return streamwrapper


@ -0,0 +1,179 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ....prompt_templates.factory import (
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIAnthropicConfig:
"""
Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
_tool_choice = None
_tool_choice = {"name": "json_tool_call", "type": "tool"}
_tool = AnthropicMessagesTool(
name="json_tool_call",
input_schema={
"type": "object",
"properties": {"values": json_schema}, # type: ignore
},
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
return optional_params


@ -9,13 +9,14 @@ import httpx # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse from litellm.utils import ModelResponse
from ...base import BaseLLM from ..vertex_llm_base import VertexBase
class VertexPartnerProvider(str, Enum): class VertexPartnerProvider(str, Enum):
mistralai = "mistralai" mistralai = "mistralai"
llama = "llama" llama = "llama"
ai21 = "ai21" ai21 = "ai21"
claude = "claude"
class VertexAIError(Exception): class VertexAIError(Exception):
@ -31,31 +32,38 @@ class VertexAIError(Exception):
) # Call the base class constructor with the parameters it needs ) # Call the base class constructor with the parameters it needs
class VertexAIPartnerModels(BaseLLM): def create_vertex_url(
vertex_location: str,
vertex_project: str,
partner: VertexPartnerProvider,
stream: Optional[bool],
model: str,
api_base: Optional[str] = None,
) -> str:
"""Return the base url for the vertex partner models"""
if partner == VertexPartnerProvider.llama:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
elif partner == VertexPartnerProvider.mistralai:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.ai21:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.claude:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
class VertexAIPartnerModels(VertexBase):
def __init__(self) -> None: def __init__(self) -> None:
pass pass
def create_vertex_url(
self,
vertex_location: str,
vertex_project: str,
partner: VertexPartnerProvider,
stream: Optional[bool],
model: str,
) -> str:
if partner == VertexPartnerProvider.llama:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
elif partner == VertexPartnerProvider.mistralai:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.ai21:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
def completion( def completion(
self, self,
model: str, model: str,
@@ -64,6 +72,7 @@ class VertexAIPartnerModels(BaseLLM):
        print_verbose: Callable,
        encoding,
        logging_obj,
+       api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
@@ -80,6 +89,7 @@ class VertexAIPartnerModels(BaseLLM):
            import vertexai
            from google.cloud import aiplatform

+           from litellm.llms.anthropic.chat import AnthropicChatCompletion
            from litellm.llms.databricks.chat import DatabricksChatCompletion
            from litellm.llms.OpenAI.openai import OpenAIChatCompletion
            from litellm.llms.text_completion_codestral import CodestralTextCompletion
@@ -112,6 +122,7 @@ class VertexAIPartnerModels(BaseLLM):
            openai_like_chat_completions = DatabricksChatCompletion()
            codestral_fim_completions = CodestralTextCompletion()
+           anthropic_chat_completions = AnthropicChatCompletion()

            ## CONSTRUCT API BASE ##
            stream: bool = optional_params.get("stream", False) or False
@@ -126,8 +137,10 @@ class VertexAIPartnerModels(BaseLLM):
            elif "jamba" in model:
                partner = VertexPartnerProvider.ai21
                optional_params["custom_endpoint"] = True
+           elif "claude" in model:
+               partner = VertexPartnerProvider.claude

-           api_base = self.create_vertex_url(
+           default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                partner=partner,  # type: ignore
@@ -135,6 +148,21 @@ class VertexAIPartnerModels(BaseLLM):
                model=model,
            )

+           if len(default_api_base.split(":")) > 1:
+               endpoint = default_api_base.split(":")[-1]
+           else:
+               endpoint = ""
+
+           _, api_base = self._check_custom_proxy(
+               api_base=api_base,
+               custom_llm_provider="vertex_ai",
+               gemini_api_key=None,
+               endpoint=endpoint,
+               stream=stream,
+               auth_header=None,
+               url=default_api_base,
+           )

            model = model.split("@")[0]

            if "codestral" in model and litellm_params.get("text_completion") is True:
@@ -158,6 +186,35 @@ class VertexAIPartnerModels(BaseLLM):
                    timeout=timeout,
                    encoding=encoding,
                )
+           elif "claude" in model:
+               if headers is None:
+                   headers = {}
+               headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+               optional_params.update(
+                   {
+                       "anthropic_version": "vertex-2023-10-16",
+                       "is_vertex_request": True,
+                   }
+               )
+
+               return anthropic_chat_completions.completion(
+                   model=model,
+                   messages=messages,
+                   api_base=api_base,
+                   acompletion=acompletion,
+                   custom_prompt_dict=litellm.custom_prompt_dict,
+                   model_response=model_response,
+                   print_verbose=print_verbose,
+                   optional_params=optional_params,
+                   litellm_params=litellm_params,
+                   logger_fn=logger_fn,
+                   encoding=encoding,  # for calculating input/output tokens
+                   api_key=access_token,
+                   logging_obj=logging_obj,
+                   headers=headers,
+                   timeout=timeout,
+                   client=client,
+               )

            return openai_like_chat_completions.completion(
                model=model,
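With the claude branch routed through the partner-model handler, a Vertex-hosted Claude call can carry a custom api_base end to end. A hedged sketch of the caller-side shape; project, location, and base URL are placeholders, and GCP application-default credentials are assumed:

```python
import litellm

# Hedged sketch: Vertex-hosted Claude through the partner-model path,
# with a custom api_base (e.g. a forwarding proxy). Values are placeholders.
litellm.vertex_project = "my-gcp-project"
litellm.vertex_location = "us-east5"

response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "Hello from Vertex"}],
    api_base="https://my-vertex-proxy.internal",  # placeholder custom base
)
print(response.choices[0].message.content)
```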

View file

@@ -117,10 +117,7 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
-from .llms.vertex_ai_and_google_ai_studio import (
-    vertex_ai_anthropic,
-    vertex_ai_non_gemini,
-)
+from .llms.vertex_ai_and_google_ai_studio import vertex_ai_non_gemini
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
@@ -747,6 +744,11 @@ def completion(  # type: ignore
    proxy_server_request = kwargs.get("proxy_server_request", None)
    fallbacks = kwargs.get("fallbacks", None)
    headers = kwargs.get("headers", None) or extra_headers
+   if headers is None:
+       headers = {}
+
+   if extra_headers is not None:
+       headers.update(extra_headers)
    num_retries = kwargs.get(
        "num_retries", None
    )  ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
@@ -964,7 +966,6 @@ def completion(  # type: ignore
            max_retries=max_retries,
            logprobs=logprobs,
            top_logprobs=top_logprobs,
-           extra_headers=extra_headers,
            api_version=api_version,
            parallel_tool_calls=parallel_tool_calls,
            messages=messages,
@@ -1067,6 +1068,9 @@ def completion(  # type: ignore
        headers = headers or litellm.headers

+       if extra_headers is not None:
+           optional_params["extra_headers"] = extra_headers
+
        if (
            litellm.enable_preview_features
            and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
@@ -1166,6 +1170,9 @@ def completion(  # type: ignore
        headers = headers or litellm.headers

+       if extra_headers is not None:
+           optional_params["extra_headers"] = extra_headers
+
        ## LOAD CONFIG - if set
        config = litellm.AzureOpenAIConfig.get_config()
        for k, v in config.items():
@@ -1223,6 +1230,9 @@ def completion(  # type: ignore
        headers = headers or litellm.headers

+       if extra_headers is not None:
+           optional_params["extra_headers"] = extra_headers
+
        ## LOAD CONFIG - if set
        config = litellm.AzureAIStudioConfig.get_config()
        for k, v in config.items():
@@ -1304,6 +1314,9 @@ def completion(  # type: ignore
        headers = headers or litellm.headers

+       if extra_headers is not None:
+           optional_params["extra_headers"] = extra_headers
+
        ## LOAD CONFIG - if set
        config = litellm.OpenAITextCompletionConfig.get_config()
        for k, v in config.items():
@@ -1466,6 +1479,9 @@ def completion(  # type: ignore
        headers = headers or litellm.headers

+       if extra_headers is not None:
+           optional_params["extra_headers"] = extra_headers
+
        ## LOAD CONFIG - if set
        config = litellm.OpenAIConfig.get_config()
        for k, v in config.items():
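Each of the OpenAI-compatible branches above now copies `extra_headers` into `optional_params`, so caller-supplied headers are forwarded to the provider request. A minimal sketch of the caller side; the header name is a placeholder, and `mock_response` keeps the example offline (drop it to exercise a real request):

```python
import litellm

# Hedged sketch: extra_headers supplied by the caller should reach the
# provider request for the branches patched above.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    extra_headers={"X-Request-Source": "docs-example"},  # placeholder header
    mock_response="pong",  # offline; remove to send a real request
)
print(response.choices[0].message.content)
```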
@@ -2228,31 +2244,12 @@ def completion(  # type: ignore
            )
            new_params = deepcopy(optional_params)
-           if "claude-3" in model:
-               model_response = vertex_ai_anthropic.completion(
-                   model=model,
-                   messages=messages,
-                   model_response=model_response,
-                   print_verbose=print_verbose,
-                   optional_params=new_params,
-                   litellm_params=litellm_params,
-                   logger_fn=logger_fn,
-                   encoding=encoding,
-                   vertex_location=vertex_ai_location,
-                   vertex_project=vertex_ai_project,
-                   vertex_credentials=vertex_credentials,
-                   logging_obj=logging,
-                   acompletion=acompletion,
-                   headers=headers,
-                   custom_prompt_dict=custom_prompt_dict,
-                   timeout=timeout,
-                   client=client,
-               )
-           elif (
+           if (
                model.startswith("meta/")
                or model.startswith("mistral")
                or model.startswith("codestral")
                or model.startswith("jamba")
+               or model.startswith("claude")
            ):
                model_response = vertex_partner_models_chat_completion.completion(
                    model=model,
@@ -2263,6 +2260,7 @@ def completion(  # type: ignore
                    litellm_params=litellm_params,  # type: ignore
                    logger_fn=logger_fn,
                    encoding=encoding,
+                   api_base=api_base,
                    vertex_location=vertex_ai_location,
                    vertex_project=vertex_ai_project,
                    vertex_credentials=vertex_credentials,
@@ -4848,6 +4846,8 @@ def image_generation(
            model_response = custom_handler.aimage_generation(  # type: ignore
                model=model,
                prompt=prompt,
+               api_key=api_key,
+               api_base=api_base,
                model_response=model_response,
                optional_params=optional_params,
                logging_obj=litellm_logging_obj,
@@ -4863,6 +4863,8 @@ def image_generation(
            model_response = custom_handler.image_generation(
                model=model,
                prompt=prompt,
+               api_key=api_key,
+               api_base=api_base,
                model_response=model_response,
                optional_params=optional_params,
                logging_obj=litellm_logging_obj,
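`image_generation` now forwards `api_key` and `api_base` to custom handlers. A hedged sketch of a handler that accepts them, mirroring the test updated later in this commit; the provider name and all values are placeholders:

```python
import litellm
from litellm.llms.custom_llm import CustomLLM

# Hedged sketch: a custom image handler whose hook now receives the caller's
# api_key / api_base. Parameter names mirror the updated test signature.
class MyImageLLM(CustomLLM):
    def image_generation(
        self, model, prompt, api_key, api_base, model_response,
        optional_params, logging_obj, timeout=None, client=None, **kwargs
    ):
        print(f"would call {api_base} with key {api_key!r} for prompt: {prompt}")
        return model_response  # already an ImageResponse-shaped object

litellm.custom_provider_map = [
    {"provider": "my_image_llm", "custom_handler": MyImageLLM()}
]

resp = litellm.image_generation(
    model="my_image_llm/fake-model",
    prompt="a lighthouse at dusk",
    api_key="my-api-key",     # placeholder
    api_base="my-api-base",   # placeholder
)
print(resp)
```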

View file

@@ -1,12 +1,15 @@
model_list:
-  - model_name: gpt-4o-mini
-    litellm_params:
-      model: azure/my-gpt-4o-mini
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+  # - model_name: openai-gpt-4o-realtime-audio
+  #   litellm_params:
+  #     model: azure/gpt-4o-realtime-preview
+  #     api_key: os.environ/AZURE_SWEDEN_API_KEY
+  #     api_base: os.environ/AZURE_SWEDEN_API_BASE
+  - model_name: openai-gpt-4o-realtime-audio
+    litellm_params:
+      model: openai/gpt-4o-realtime-preview-2024-10-01
+      api_key: os.environ/OPENAI_API_KEY
+      api_base: http://localhost:8080

litellm_settings:
-  turn_off_message_logging: true
+  success_callback: ["langfuse"]
+  cache: True
+  cache_params:
+    type: local

View file

@@ -241,6 +241,8 @@ def initialize_callbacks_on_proxy(
        litellm.callbacks = imported_list  # type: ignore

    if "prometheus" in value:
+       if premium_user is not True:
+           raise Exception(CommonProxyErrors.not_premium_user.value)
        from litellm.proxy.proxy_server import app

        verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")
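Prometheus is now gated to premium users at callback-initialization time (and again in `ProxyConfig` below). A standalone sketch of the gating pattern, not litellm's actual helper:

```python
# Standalone sketch of the guard added above (illustrative only):
# enterprise-only callbacks raise before they are initialized.
PREMIUM_ONLY_CALLBACKS = {"prometheus"}

def check_callback_allowed(callback_name: str, premium_user: bool) -> None:
    """Raise if a premium-only callback is configured without a premium license."""
    if callback_name in PREMIUM_ONLY_CALLBACKS and not premium_user:
        raise Exception(
            f"'{callback_name}' is an enterprise-only callback; premium license required"
        )

check_callback_allowed("prometheus", premium_user=True)    # ok
# check_callback_allowed("prometheus", premium_user=False) # would raise
```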

View file

@@ -1,16 +1,17 @@
model_list:
  - model_name: db-openai-endpoint
    litellm_params:
-      model: openai/gpt-5
+      model: openai/gpt-4
      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: db-openai-endpoint
+    litellm_params:
+      model: openai/gpt-5
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/

litellm_settings:
-  callbacks: ["prometheus"]
-  turn_off_message_logging: true
+  success_callback: ["s3"]
+  s3_callback_params:
+    s3_bucket_name: load-testing-oct  # AWS Bucket Name for S3
+    s3_region_name: us-west-2  # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID  # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY  # AWS Secret Access Key for S3
View file

@@ -1692,6 +1692,10 @@ class ProxyConfig:
                    else:
                        litellm.success_callback.append(callback)
                    if "prometheus" in callback:
+                       if not premium_user:
+                           raise Exception(
+                               CommonProxyErrors.not_premium_user.value
+                           )
                        verbose_proxy_logger.debug(
                            "Starting Prometheus Metrics on /metrics"
                        )
@@ -5433,12 +5437,11 @@ async def moderations(
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
-       verbose_proxy_logger.error(
+       verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
                str(e)
            )
        )
-       verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),

View file

@@ -408,7 +408,6 @@ class ProxyLogging:
                callback,
                internal_usage_cache=self.internal_usage_cache.dual_cache,
                llm_router=llm_router,
-               premium_user=self.premium_user,
            )
            if callback is None:
                continue

View file

@@ -8,13 +8,16 @@ from litellm import get_llm_provider
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams

+from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
from ..llms.OpenAI.realtime.handler import OpenAIRealtime
+from ..utils import client as wrapper_client

azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()


+@wrapper_client
async def _arealtime(
    model: str,
    websocket: Any,  # fastapi websocket
@@ -31,6 +34,12 @@ async def _arealtime(

    For PROXY use only.
    """
+   litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
+   litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+   proxy_server_request = kwargs.get("proxy_server_request", None)
+   model_info = kwargs.get("model_info", None)
+   metadata = kwargs.get("metadata", {})
+   user = kwargs.get("user", None)
    litellm_params = GenericLiteLLMParams(**kwargs)

    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
@@ -39,6 +48,21 @@ async def _arealtime(
        api_key=api_key,
    )

+   litellm_logging_obj.update_environment_variables(
+       model=model,
+       user=user,
+       optional_params={},
+       litellm_params={
+           "litellm_call_id": litellm_call_id,
+           "proxy_server_request": proxy_server_request,
+           "model_info": model_info,
+           "metadata": metadata,
+           "preset_cache_key": None,
+           "stream_response": {},
+       },
+       custom_llm_provider=_custom_llm_provider,
+   )
+
    if _custom_llm_provider == "azure":
        api_base = (
            dynamic_api_base
@@ -63,6 +87,7 @@ async def _arealtime(
            azure_ad_token=None,
            client=None,
            timeout=timeout,
+           logging_obj=litellm_logging_obj,
        )
    elif _custom_llm_provider == "openai":
        api_base = (
@@ -82,6 +107,7 @@ async def _arealtime(
        await openai_realtime.async_realtime(
            model=model,
            websocket=websocket,
+           logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            client=None,
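Because `_arealtime` now builds a `Logging` object and is wrapped by the `client` decorator, callbacks registered on the SDK or proxy should also observe realtime traffic once this wiring is in place. A hedged sketch of a custom callback that would surface it:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

# Hedged sketch: a custom success callback that prints realtime calls.
class RealtimeAuditLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # call_type for realtime requests should be "_arealtime" (see CallTypes below)
        print("call_type:", kwargs.get("call_type"), "model:", kwargs.get("model"))

litellm.callbacks = [RealtimeAuditLogger()]
```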

View file

@@ -60,7 +60,7 @@ class PatternMatchRouter:
        regex = re.escape(regex).replace(r"\.\*", ".*")
        return f"^{regex}$"

-   def route(self, request: str) -> Optional[List[Dict]]:
+   def route(self, request: Optional[str]) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern
@@ -69,17 +69,20 @@ class PatternMatchRouter:
        if no pattern is found, return None

        Args:
-           request: str
+           request: Optional[str]

        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        try:
+           if request is None:
+               return None
            for pattern, llm_deployments in self.patterns.items():
                if re.match(pattern, request):
                    return llm_deployments
        except Exception as e:
-           verbose_router_logger.error(f"Error in PatternMatchRouter.route: {str(e)}")
+           verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")

        return None  # No matching pattern found
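`route()` now tolerates a `None` model name instead of raising inside `re.match`, and pattern errors are downgraded to debug logs. A standalone sketch of the same guard (not the litellm class itself):

```python
import re
from typing import Dict, List, Optional

# Standalone sketch of the guard added above: a None request short-circuits.
patterns: Dict[str, List[Dict]] = {r"^openai/.*$": [{"model_name": "openai/*"}]}

def route(request: Optional[str]) -> Optional[List[Dict]]:
    if request is None:
        return None
    for pattern, deployments in patterns.items():
        if re.match(pattern, request):
            return deployments
    return None

print(route("openai/gpt-4o-mini"))  # [{'model_name': 'openai/*'}]
print(route(None))                  # None
```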

View file

@@ -0,0 +1,14 @@
+from typing import Dict
+
+from pydantic import BaseModel
+
+
+class s3BatchLoggingElement(BaseModel):
+    """
+    Type of element stored in self.log_queue in S3Logger
+    """
+
+    payload: Dict
+    s3_object_key: str
+    s3_object_download_filename: str
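A hedged sketch of constructing one queue element the way the batch S3 logger might, assuming the `s3BatchLoggingElement` model above is in scope (its module path is not shown in this diff); all field values are placeholders:

```python
# Hedged sketch: building one log-queue element with placeholder values.
element = s3BatchLoggingElement(
    payload={"id": "chatcmpl-123", "model": "gpt-3.5-turbo"},
    s3_object_key="2024-10-11/chatcmpl-123.json",
    s3_object_download_filename="chatcmpl-123.json",
)
print(element.s3_object_key)
```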

View file

@@ -129,6 +129,7 @@ class CallTypes(Enum):
    speech = "speech"
    rerank = "rerank"
    arerank = "arerank"
+   arealtime = "_arealtime"


class PassthroughCallTypes(Enum):

View file

@@ -197,7 +197,6 @@ lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
-s3Logger = None
genericAPILogger = None
clickHouseLogger = None
greenscaleLogger = None
@@ -1410,6 +1409,8 @@ def client(original_function):
                    )
                else:
                    return result
+           elif call_type == CallTypes.arealtime.value:
+               return result

            # ADD HIDDEN PARAMS - additional call metadata
            if hasattr(result, "_hidden_params"):
@@ -1799,8 +1800,9 @@ def calculate_tiles_needed(
    total_tiles = tiles_across * tiles_down
    return total_tiles


def get_image_type(image_data: bytes) -> Union[str, None]:
-    """ take an image (really only the first ~100 bytes max are needed)
+    """take an image (really only the first ~100 bytes max are needed)
    and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
    allow deprecation of imghdr in 3.13"""
@@ -4336,16 +4338,18 @@ def get_api_base(
            _optional_params.vertex_location is not None
            and _optional_params.vertex_project is not None
        ):
-           from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
-               create_vertex_anthropic_url,
+           from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
+               VertexPartnerProvider,
+               create_vertex_url,
            )

            if "claude" in model:
-               _api_base = create_vertex_anthropic_url(
+               _api_base = create_vertex_url(
                    vertex_location=_optional_params.vertex_location,
                    vertex_project=_optional_params.vertex_project,
                    model=model,
                    stream=stream,
+                   partner=VertexPartnerProvider.claude,
                )
            else:
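A hedged sketch of what the shared URL builder means for `get_api_base` on a Vertex-hosted Claude model; the exact `optional_params` keys accepted here and the printed URL shape are assumptions, and project/location are placeholders:

```python
from litellm.utils import get_api_base

# Hedged sketch: the resolved base for a Vertex-hosted Claude model should be
# an anthropic rawPredict endpoint built by create_vertex_url.
api_base = get_api_base(
    model="vertex_ai/claude-3-sonnet@20240229",
    optional_params={
        "vertex_project": "my-gcp-project",  # placeholder
        "vertex_location": "us-east5",       # placeholder
    },
)
print(api_base)
# roughly: https://us-east5-aiplatform.googleapis.com/v1/projects/my-gcp-project/
#          locations/us-east5/publishers/anthropic/models/...:rawPredict
```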
@@ -4442,19 +4446,7 @@ def get_supported_openai_params(
    elif custom_llm_provider == "volcengine":
        return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "groq":
-       return [
-           "temperature",
-           "max_tokens",
-           "top_p",
-           "stream",
-           "stop",
-           "tools",
-           "tool_choice",
-           "response_format",
-           "seed",
-           "extra_headers",
-           "extra_body",
-       ]
+       return litellm.GroqChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "hosted_vllm":
        return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "deepseek":
@@ -4599,6 +4591,8 @@ def get_supported_openai_params(
            return (
                litellm.MistralTextCompletionConfig().get_supported_openai_params()
            )
+       if model.startswith("claude"):
+           return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
        return litellm.VertexAIConfig().get_supported_openai_params()
    elif request_type == "embeddings":
        return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
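Both lookups now come from provider config classes rather than hard-coded lists. A quick hedged check via the public helper:

```python
from litellm import get_supported_openai_params

# Hedged sketch: Groq params now come from GroqChatConfig, and Vertex-hosted
# Claude models resolve through VertexAIAnthropicConfig.
groq_params = get_supported_openai_params(
    model="llama3-8b-8192", custom_llm_provider="groq"
)
vertex_claude_params = get_supported_openai_params(
    model="claude-3-sonnet@20240229", custom_llm_provider="vertex_ai"
)
print("groq:", groq_params)
print("vertex claude:", vertex_claude_params)
```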

View file

@@ -296,7 +296,7 @@ def test_all_model_configs():
        optional_params={},
    ) == {"max_tokens": 10}

-   from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+   from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
        VertexAIAnthropicConfig,
    )

View file

@@ -12,7 +12,70 @@ import litellm

litellm.num_retries = 3
import time, random

+from litellm._logging import verbose_logger
+import logging
import pytest
+import boto3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_s3_logging(sync_mode):
+    verbose_logger.setLevel(level=logging.DEBUG)
+    litellm.success_callback = ["s3"]
+    litellm.s3_callback_params = {
+        "s3_bucket_name": "load-testing-oct",
+        "s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
+        "s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
+        "s3_region_name": "us-west-2",
+    }
+    litellm.set_verbose = True
+
+    if sync_mode is True:
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "This is a test"}],
+            mock_response="It's simple to use and easy to get started",
+        )
+    else:
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "This is a test"}],
+            mock_response="It's simple to use and easy to get started",
+        )
+
+    print(f"response: {response}")
+
+    await asyncio.sleep(12)
+
+    total_objects, all_s3_keys = list_all_s3_objects("load-testing-oct")
+
+    # assert that at least one key has response.id in it
+    assert any(response.id in key for key in all_s3_keys)
+    s3 = boto3.client("s3")
+    # delete all objects
+    for key in all_s3_keys:
+        s3.delete_object(Bucket="load-testing-oct", Key=key)
+
+
+def list_all_s3_objects(bucket_name):
+    s3 = boto3.client("s3")
+
+    all_s3_keys = []
+
+    paginator = s3.get_paginator("list_objects_v2")
+    total_objects = 0
+
+    for page in paginator.paginate(Bucket=bucket_name):
+        if "Contents" in page:
+            total_objects += len(page["Contents"])
+            all_s3_keys.extend([obj["Key"] for obj in page["Contents"]])
+
+    print(f"Total number of objects in {bucket_name}: {total_objects}")
+    print(all_s3_keys)
+
+    return total_objects, all_s3_keys
+
+
+list_all_s3_objects("load-testing-oct")
+

@pytest.mark.skip(reason="AWS Suspended Account")

View file

@@ -1616,9 +1616,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
)


-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize(
+    "model", ["gemini-1.5-flash", "claude-3-sonnet@20240229"]
+)  # "vertex_ai",
@pytest.mark.asyncio
-async def test_gemini_pro_httpx_custom_api_base(provider):
+async def test_gemini_pro_httpx_custom_api_base(model):
    load_vertex_ai_credentials()
    litellm.set_verbose = True
    messages = [
@@ -1634,7 +1636,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
    with patch.object(client, "post", new=MagicMock()) as mock_call:
        try:
            response = completion(
-               model="vertex_ai_beta/gemini-1.5-flash",
+               model="vertex_ai/{}".format(model),
                messages=messages,
                response_format={"type": "json_object"},
                client=client,
@@ -1647,8 +1649,17 @@ async def test_gemini_pro_httpx_custom_api_base(provider):

        mock_call.assert_called_once()

-       assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
-       assert "hello" in mock_call.call_args.kwargs["headers"]
+       print(f"mock_call.call_args: {mock_call.call_args}")
+       print(f"mock_call.call_args.kwargs: {mock_call.call_args.kwargs}")
+
+       if "url" in mock_call.call_args.kwargs:
+           assert (
+               "my-custom-api-base:generateContent"
+               == mock_call.call_args.kwargs["url"]
+           )
+       else:
+           assert "my-custom-api-base:rawPredict" == mock_call.call_args[0][0]
+       if "headers" in mock_call.call_args.kwargs:
+           assert "hello" in mock_call.call_args.kwargs["headers"]


# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")

View file

@@ -28,7 +28,6 @@ from typing import (
    Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
-
import httpx
from dotenv import load_dotenv
@@ -226,6 +225,8 @@ class MyCustomLLM(CustomLLM):
        self,
        model: str,
        prompt: str,
+       api_key: Optional[str],
+       api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
@@ -242,6 +243,8 @@ class MyCustomLLM(CustomLLM):
        self,
        model: str,
        prompt: str,
+       api_key: Optional[str],
+       api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
@@ -362,3 +365,31 @@ async def test_simple_image_generation_async():
    )

    print(resp)
+
+
+@pytest.mark.asyncio
+async def test_image_generation_async_with_api_key_and_api_base():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+
+    with patch.object(
+        my_custom_llm, "aimage_generation", new=AsyncMock()
+    ) as mock_client:
+        try:
+            resp = await litellm.aimage_generation(
+                model="custom_llm/my-fake-model",
+                prompt="Hello world",
+                api_key="my-api-key",
+                api_base="my-api-base",
+            )
+
+            print(resp)
+        except Exception as e:
+            print(e)
+
+        mock_client.assert_awaited_once()
+        assert mock_client.call_args.kwargs["api_key"] == "my-api-key"
+        assert mock_client.call_args.kwargs["api_base"] == "my-api-base"