diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md index 207b1abe1..ef3a7b940 100644 --- a/docs/my-website/docs/proxy/prometheus.md +++ b/docs/my-website/docs/proxy/prometheus.md @@ -27,8 +27,7 @@ model_list: litellm_params: model: gpt-3.5-turbo litellm_settings: - success_callback: ["prometheus"] - failure_callback: ["prometheus"] + callbacks: ["prometheus"] ``` Start the proxy diff --git a/litellm/__init__.py b/litellm/__init__.py index 9379caac1..4a9b9c312 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -53,6 +53,7 @@ _custom_logger_compatible_callbacks_literal = Literal[ "arize", "langtrace", "gcs_bucket", + "s3", "opik", ] _known_custom_logger_compatible_callbacks: List = list( @@ -931,7 +932,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation impor vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig() -from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import ( +from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import ( VertexAIAnthropicConfig, ) from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import ( diff --git a/litellm/integrations/custom_batch_logger.py b/litellm/integrations/custom_batch_logger.py index aa7f0bba2..7ef63d25c 100644 --- a/litellm/integrations/custom_batch_logger.py +++ b/litellm/integrations/custom_batch_logger.py @@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger): self, flush_lock: Optional[asyncio.Lock] = None, batch_size: Optional[int] = DEFAULT_BATCH_SIZE, + flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS, **kwargs, ) -> None: """ @@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger): flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching """ self.log_queue: List = [] - self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds + self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE self.last_flush_time = time.time() self.flush_lock = flush_lock diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 19740c741..177ff735c 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -235,6 +235,14 @@ class LangFuseLogger: ): input = prompt output = response_obj.results + elif ( + kwargs.get("call_type") is not None + and kwargs.get("call_type") == "_arealtime" + and response_obj is not None + and isinstance(response_obj, list) + ): + input = kwargs.get("input") + output = response_obj elif ( kwargs.get("call_type") is not None and kwargs.get("call_type") == "pass_through_endpoint" diff --git a/litellm/integrations/s3.py b/litellm/integrations/s3.py index 1f82406e1..397b53cd5 100644 --- a/litellm/integrations/s3.py +++ b/litellm/integrations/s3.py @@ -1,43 +1,67 @@ -#### What this does #### -# On success + failure, log events to Supabase +""" +s3 Bucket Logging Integration -import datetime -import os -import subprocess -import sys -import traceback -import uuid -from typing import Optional +async_log_success_event: Processes the event, stores it in memory for 10 seconds or until MAX_BATCH_SIZE and then flushes to s3 + +NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually +NOTE 2: We create a httpx client with a concurrent limit of 1 to upload to s3. 
Files should get uploaded BUT they should not impact latency of LLM calling logic +""" + +import asyncio +import json +from datetime import datetime +from typing import Dict, List, Optional import litellm from litellm._logging import print_verbose, verbose_logger +from litellm.llms.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import ( + _get_httpx_client, + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.types.integrations.s3 import s3BatchLoggingElement from litellm.types.utils import StandardLoggingPayload +from .custom_batch_logger import CustomBatchLogger -class S3Logger: +# Default Flush interval and batch size for s3 +# Flush to s3 every 10 seconds OR every 1K requests in memory +DEFAULT_S3_FLUSH_INTERVAL_SECONDS = 10 +DEFAULT_S3_BATCH_SIZE = 1000 + + +class S3Logger(CustomBatchLogger, BaseAWSLLM): # Class variables or attributes def __init__( self, - s3_bucket_name=None, - s3_path=None, - s3_region_name=None, - s3_api_version=None, - s3_use_ssl=True, - s3_verify=None, - s3_endpoint_url=None, - s3_aws_access_key_id=None, - s3_aws_secret_access_key=None, - s3_aws_session_token=None, + s3_bucket_name: Optional[str] = None, + s3_path: Optional[str] = None, + s3_region_name: Optional[str] = None, + s3_api_version: Optional[str] = None, + s3_use_ssl: bool = True, + s3_verify: Optional[bool] = None, + s3_endpoint_url: Optional[str] = None, + s3_aws_access_key_id: Optional[str] = None, + s3_aws_secret_access_key: Optional[str] = None, + s3_aws_session_token: Optional[str] = None, + s3_flush_interval: Optional[int] = DEFAULT_S3_FLUSH_INTERVAL_SECONDS, + s3_batch_size: Optional[int] = DEFAULT_S3_BATCH_SIZE, s3_config=None, **kwargs, ): - import boto3 - try: verbose_logger.debug( f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}" ) + # IMPORTANT: We use a concurrent limit of 1 to upload to s3 + # Files should get uploaded BUT they should not impact latency of LLM calling logic + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback, + params={"concurrent_limit": 1}, + ) + if litellm.s3_callback_params is not None: # read in .env variables - example os.environ/AWS_BUCKET_NAME for key, value in litellm.s3_callback_params.items(): @@ -63,107 +87,282 @@ class S3Logger: s3_path = litellm.s3_callback_params.get("s3_path") # done reading litellm.s3_callback_params + s3_flush_interval = litellm.s3_callback_params.get( + "s3_flush_interval", DEFAULT_S3_FLUSH_INTERVAL_SECONDS + ) + s3_batch_size = litellm.s3_callback_params.get( + "s3_batch_size", DEFAULT_S3_BATCH_SIZE + ) + self.bucket_name = s3_bucket_name self.s3_path = s3_path verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}") - # Create an S3 client with custom endpoint URL - self.s3_client = boto3.client( - "s3", - region_name=s3_region_name, - endpoint_url=s3_endpoint_url, - api_version=s3_api_version, - use_ssl=s3_use_ssl, - verify=s3_verify, - aws_access_key_id=s3_aws_access_key_id, - aws_secret_access_key=s3_aws_secret_access_key, - aws_session_token=s3_aws_session_token, - config=s3_config, - **kwargs, + self.s3_bucket_name = s3_bucket_name + self.s3_region_name = s3_region_name + self.s3_api_version = s3_api_version + self.s3_use_ssl = s3_use_ssl + self.s3_verify = s3_verify + self.s3_endpoint_url = s3_endpoint_url + self.s3_aws_access_key_id = s3_aws_access_key_id + self.s3_aws_secret_access_key = s3_aws_secret_access_key + self.s3_aws_session_token = s3_aws_session_token + self.s3_config = s3_config + 
self.init_kwargs = kwargs + + asyncio.create_task(self.periodic_flush()) + self.flush_lock = asyncio.Lock() + + verbose_logger.debug( + f"s3 flush interval: {s3_flush_interval}, s3 batch size: {s3_batch_size}" ) + # Call CustomLogger's __init__ + CustomBatchLogger.__init__( + self, + flush_lock=self.flush_lock, + flush_interval=s3_flush_interval, + batch_size=s3_batch_size, + ) + self.log_queue: List[s3BatchLoggingElement] = [] + + # Call BaseAWSLLM's __init__ + BaseAWSLLM.__init__(self) + except Exception as e: print_verbose(f"Got exception on init s3 client {str(e)}") raise e - async def _async_log_event( - self, kwargs, response_obj, start_time, end_time, print_verbose - ): - self.log_event(kwargs, response_obj, start_time, end_time, print_verbose) - - def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): try: verbose_logger.debug( f"s3 Logging - Enters logging function for model {kwargs}" ) - # construct payload to send to s3 - # follows the same params as langfuse.py - litellm_params = kwargs.get("litellm_params", {}) - metadata = ( - litellm_params.get("metadata", {}) or {} - ) # if litellm_params['metadata'] == None - - # Clean Metadata before logging - never log raw metadata - # the raw metadata can contain circular references which leads to infinite recursion - # we clean out all extra litellm metadata params before logging - clean_metadata = {} - if isinstance(metadata, dict): - for key, value in metadata.items(): - # clean litellm metadata before logging - if key in [ - "headers", - "endpoint", - "caching_groups", - "previous_models", - ]: - continue - else: - clean_metadata[key] = value - - # Ensure everything in the payload is converted to str - payload: Optional[StandardLoggingPayload] = kwargs.get( - "standard_logging_object", None + s3_batch_logging_element = self.create_s3_batch_logging_element( + start_time=start_time, + standard_logging_payload=kwargs.get("standard_logging_object", None), + s3_path=self.s3_path, ) - if payload is None: - return + if s3_batch_logging_element is None: + raise ValueError("s3_batch_logging_element is None") - s3_file_name = litellm.utils.get_logging_id(start_time, payload) or "" - s3_object_key = ( - (self.s3_path.rstrip("/") + "/" if self.s3_path else "") - + start_time.strftime("%Y-%m-%d") - + "/" - + s3_file_name - ) # we need the s3 key to include the time, so we log cache hits too - s3_object_key += ".json" - - s3_object_download_filename = ( - "time-" - + start_time.strftime("%Y-%m-%dT%H-%M-%S-%f") - + "_" - + payload["id"] - + ".json" + verbose_logger.debug( + "\ns3 Logger - Logging payload = %s", s3_batch_logging_element ) - import json - - payload_str = json.dumps(payload) - - print_verbose(f"\ns3 Logger - Logging payload = {payload_str}") - - response = self.s3_client.put_object( - Bucket=self.bucket_name, - Key=s3_object_key, - Body=payload_str, - ContentType="application/json", - ContentLanguage="en", - ContentDisposition=f'inline; filename="{s3_object_download_filename}"', - CacheControl="private, immutable, max-age=31536000, s-maxage=0", + self.log_queue.append(s3_batch_logging_element) + verbose_logger.debug( + "s3 logging: queue length %s, batch size %s", + len(self.log_queue), + self.batch_size, ) - - print_verbose(f"Response from s3:{str(response)}") - - print_verbose(f"s3 Layer Logging - final response object: {response_obj}") - return response + if len(self.log_queue) >= self.batch_size: + await 
self.flush_queue() except Exception as e: verbose_logger.exception(f"s3 Layer Error - {str(e)}") pass + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Synchronous logging function to log to s3 + + Does not batch logging requests, instantly logs on s3 Bucket + """ + try: + s3_batch_logging_element = self.create_s3_batch_logging_element( + start_time=start_time, + standard_logging_payload=kwargs.get("standard_logging_object", None), + s3_path=self.s3_path, + ) + + if s3_batch_logging_element is None: + raise ValueError("s3_batch_logging_element is None") + + verbose_logger.debug( + "\ns3 Logger - Logging payload = %s", s3_batch_logging_element + ) + + # log the element sync httpx client + self.upload_data_to_s3(s3_batch_logging_element) + except Exception as e: + verbose_logger.exception(f"s3 Layer Error - {str(e)}") + pass + + async def async_upload_data_to_s3( + self, batch_logging_element: s3BatchLoggingElement + ): + try: + import hashlib + + import boto3 + import requests + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + try: + credentials: Credentials = self.get_credentials( + aws_access_key_id=self.s3_aws_access_key_id, + aws_secret_access_key=self.s3_aws_secret_access_key, + aws_session_token=self.s3_aws_session_token, + aws_region_name=self.s3_region_name, + ) + + # Prepare the URL + url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}" + + if self.s3_endpoint_url: + url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key + + # Convert JSON to string + json_string = json.dumps(batch_logging_element.payload) + + # Calculate SHA256 hash of the content + content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest() + + # Prepare the request + headers = { + "Content-Type": "application/json", + "x-amz-content-sha256": content_hash, + "Content-Language": "en", + "Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"', + "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0", + } + req = requests.Request("PUT", url, data=json_string, headers=headers) + prepped = req.prepare() + + # Sign the request + aws_request = AWSRequest( + method=prepped.method, + url=prepped.url, + data=prepped.body, + headers=prepped.headers, + ) + SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request) + + # Prepare the signed headers + signed_headers = dict(aws_request.headers.items()) + + # Make the request + response = await self.async_httpx_client.put( + url, data=json_string, headers=signed_headers + ) + response.raise_for_status() + except Exception as e: + verbose_logger.exception(f"Error uploading to s3: {str(e)}") + + async def async_send_batch(self): + """ + + Sends runs from self.log_queue + + Returns: None + + Raises: Does not raise an exception, will only verbose_logger.exception() + """ + if not self.log_queue: + return + + for payload in self.log_queue: + asyncio.create_task(self.async_upload_data_to_s3(payload)) + + def create_s3_batch_logging_element( + self, + start_time: datetime, + standard_logging_payload: Optional[StandardLoggingPayload], + s3_path: Optional[str], + ) -> Optional[s3BatchLoggingElement]: + """ + Helper function to create an s3BatchLoggingElement. 
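+ The generated object key follows <s3_path>/<YYYY-MM-DD>/<logging_id>.json (the s3_path prefix is optional).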
+ + Args: + start_time (datetime): The start time of the logging event. + standard_logging_payload (Optional[StandardLoggingPayload]): The payload to be logged. + s3_path (Optional[str]): The S3 path prefix. + + Returns: + Optional[s3BatchLoggingElement]: The created s3BatchLoggingElement, or None if payload is None. + """ + if standard_logging_payload is None: + return None + + s3_file_name = ( + litellm.utils.get_logging_id(start_time, standard_logging_payload) or "" + ) + s3_object_key = ( + (s3_path.rstrip("/") + "/" if s3_path else "") + + start_time.strftime("%Y-%m-%d") + + "/" + + s3_file_name + + ".json" + ) + + s3_object_download_filename = f"time-{start_time.strftime('%Y-%m-%dT%H-%M-%S-%f')}_{standard_logging_payload['id']}.json" + + return s3BatchLoggingElement( + payload=standard_logging_payload, # type: ignore + s3_object_key=s3_object_key, + s3_object_download_filename=s3_object_download_filename, + ) + + def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement): + try: + import hashlib + + import boto3 + import requests + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + try: + credentials: Credentials = self.get_credentials( + aws_access_key_id=self.s3_aws_access_key_id, + aws_secret_access_key=self.s3_aws_secret_access_key, + aws_session_token=self.s3_aws_session_token, + aws_region_name=self.s3_region_name, + ) + + # Prepare the URL + url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}" + + if self.s3_endpoint_url: + url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key + + # Convert JSON to string + json_string = json.dumps(batch_logging_element.payload) + + # Calculate SHA256 hash of the content + content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest() + + # Prepare the request + headers = { + "Content-Type": "application/json", + "x-amz-content-sha256": content_hash, + "Content-Language": "en", + "Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"', + "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0", + } + req = requests.Request("PUT", url, data=json_string, headers=headers) + prepped = req.prepare() + + # Sign the request + aws_request = AWSRequest( + method=prepped.method, + url=prepped.url, + data=prepped.body, + headers=prepped.headers, + ) + SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request) + + # Prepare the signed headers + signed_headers = dict(aws_request.headers.items()) + + httpx_client = _get_httpx_client() + # Make the request + response = httpx_client.put(url, data=json_string, headers=signed_headers) + response.raise_for_status() + except Exception as e: + verbose_logger.exception(f"Error uploading to s3: {str(e)}") diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index ce97f1c6f..a641be019 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -116,7 +116,6 @@ lagoLogger = None dataDogLogger = None prometheusLogger = None dynamoLogger = None -s3Logger = None genericAPILogger = None clickHouseLogger = None greenscaleLogger = None @@ -1346,36 +1345,6 @@ class Logging: user_id=kwargs.get("user", None), print_verbose=print_verbose, ) - if callback == "s3": - global s3Logger - 
if s3Logger is None: - s3Logger = S3Logger() - if self.stream: - if "complete_streaming_response" in self.model_call_details: - print_verbose( - "S3Logger Logger: Got Stream Event - Completed Stream Response" - ) - s3Logger.log_event( - kwargs=self.model_call_details, - response_obj=self.model_call_details[ - "complete_streaming_response" - ], - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - else: - print_verbose( - "S3Logger Logger: Got Stream Event - No complete stream response as yet" - ) - else: - s3Logger.log_event( - kwargs=self.model_call_details, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) if ( callback == "openmeter" and self.model_call_details.get("litellm_params", {}).get( @@ -2245,7 +2214,7 @@ def set_callbacks(callback_list, function_id=None): """ Globally sets the callback client """ - global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger + global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger try: for callback in callback_list: @@ -2319,8 +2288,6 @@ def set_callbacks(callback_list, function_id=None): dataDogLogger = DataDogLogger() elif callback == "dynamodb": dynamoLogger = DyanmoDBLogger() - elif callback == "s3": - s3Logger = S3Logger() elif callback == "wandb": weightsBiasesLogger = WeightsBiasesLogger() elif callback == "logfire": @@ -2357,7 +2324,6 @@ def _init_custom_logger_compatible_class( llm_router: Optional[ Any ], # expect litellm.Router, but typing errors due to circular import - premium_user: Optional[bool] = None, ) -> Optional[CustomLogger]: if logging_integration == "lago": for callback in _in_memory_loggers: @@ -2404,17 +2370,9 @@ def _init_custom_logger_compatible_class( if isinstance(callback, PrometheusLogger): return callback # type: ignore - if premium_user: - _prometheus_logger = PrometheusLogger() - _in_memory_loggers.append(_prometheus_logger) - return _prometheus_logger # type: ignore - elif premium_user is False: - verbose_logger.warning( - f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}" - ) - return None - else: - return None + _prometheus_logger = PrometheusLogger() + _in_memory_loggers.append(_prometheus_logger) + return _prometheus_logger # type: ignore elif logging_integration == "datadog": for callback in _in_memory_loggers: if isinstance(callback, DataDogLogger): @@ -2423,6 +2381,14 @@ def _init_custom_logger_compatible_class( _datadog_logger = DataDogLogger() _in_memory_loggers.append(_datadog_logger) return _datadog_logger # type: ignore + elif logging_integration == "s3": + for callback in _in_memory_loggers: + if isinstance(callback, S3Logger): + return callback # type: ignore + + _s3_logger = S3Logger() + _in_memory_loggers.append(_s3_logger) + return _s3_logger # type: ignore elif logging_integration == 
"gcs_bucket": for callback in _in_memory_loggers: if isinstance(callback, GCSBucketLogger): @@ -2589,6 +2555,10 @@ def get_custom_logger_compatible_class( for callback in _in_memory_loggers: if isinstance(callback, PrometheusLogger): return callback + elif logging_integration == "s3": + for callback in _in_memory_loggers: + if isinstance(callback, S3Logger): + return callback elif logging_integration == "datadog": for callback in _in_memory_loggers: if isinstance(callback, DataDogLogger): diff --git a/litellm/litellm_core_utils/realtime_streaming.py b/litellm/litellm_core_utils/realtime_streaming.py new file mode 100644 index 000000000..922f90e36 --- /dev/null +++ b/litellm/litellm_core_utils/realtime_streaming.py @@ -0,0 +1,112 @@ +""" +async with websockets.connect( # type: ignore + url, + extra_headers={ + "api-key": api_key, # type: ignore + }, + ) as backend_ws: + forward_task = asyncio.create_task( + forward_messages(websocket, backend_ws) + ) + + try: + while True: + message = await websocket.receive_text() + await backend_ws.send(message) + except websockets.exceptions.ConnectionClosed: # type: ignore + forward_task.cancel() + finally: + if not forward_task.done(): + forward_task.cancel() + try: + await forward_task + except asyncio.CancelledError: + pass +""" + +import asyncio +import concurrent.futures +import traceback +from asyncio import Task +from typing import Any, Dict, List, Optional, Union + +from .litellm_logging import Logging as LiteLLMLogging + +# Create a thread pool with a maximum of 10 threads +executor = concurrent.futures.ThreadPoolExecutor(max_workers=10) + + +class RealTimeStreaming: + def __init__( + self, + websocket: Any, + backend_ws: Any, + logging_obj: Optional[LiteLLMLogging] = None, + ): + self.websocket = websocket + self.backend_ws = backend_ws + self.logging_obj = logging_obj + self.messages: List = [] + self.input_message: Dict = {} + + def store_message(self, message: Union[str, bytes]): + """Store message in list""" + self.messages.append(message) + + def store_input(self, message: dict): + """Store input message""" + self.input_message = message + if self.logging_obj: + self.logging_obj.pre_call(input=message, api_key="") + + async def log_messages(self): + """Log messages in list""" + if self.logging_obj: + ## ASYNC LOGGING + # Create an event loop for the new thread + asyncio.create_task(self.logging_obj.async_success_handler(self.messages)) + ## SYNC LOGGING + executor.submit(self.logging_obj.success_handler(self.messages)) + + async def backend_to_client_send_messages(self): + import websockets + + try: + while True: + message = await self.backend_ws.recv() + await self.websocket.send_text(message) + + ## LOGGING + self.store_message(message) + except websockets.exceptions.ConnectionClosed: # type: ignore + pass + except Exception: + pass + finally: + await self.log_messages() + + async def client_ack_messages(self): + try: + while True: + message = await self.websocket.receive_text() + ## LOGGING + self.store_input(message=message) + ## FORWARD TO BACKEND + await self.backend_ws.send(message) + except self.websockets.exceptions.ConnectionClosed: # type: ignore + pass + + async def bidirectional_forward(self): + + forward_task = asyncio.create_task(self.backend_to_client_send_messages()) + try: + await self.client_ack_messages() + except self.websockets.exceptions.ConnectionClosed: # type: ignore + forward_task.cancel() + finally: + if not forward_task.done(): + forward_task.cancel() + try: + await forward_task + except 
asyncio.CancelledError: + pass diff --git a/litellm/llms/AzureOpenAI/realtime/handler.py b/litellm/llms/AzureOpenAI/realtime/handler.py index 7d58ee78f..bf45c53fb 100644 --- a/litellm/llms/AzureOpenAI/realtime/handler.py +++ b/litellm/llms/AzureOpenAI/realtime/handler.py @@ -7,6 +7,8 @@ This requires websockets, and is currently only supported on LiteLLM Proxy. import asyncio from typing import Any, Optional +from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging +from ....litellm_core_utils.realtime_streaming import RealTimeStreaming from ..azure import AzureChatCompletion # BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01" @@ -44,6 +46,7 @@ class AzureOpenAIRealtime(AzureChatCompletion): api_version: Optional[str] = None, azure_ad_token: Optional[str] = None, client: Optional[Any] = None, + logging_obj: Optional[LiteLLMLogging] = None, timeout: Optional[float] = None, ): import websockets @@ -62,23 +65,10 @@ class AzureOpenAIRealtime(AzureChatCompletion): "api-key": api_key, # type: ignore }, ) as backend_ws: - forward_task = asyncio.create_task( - forward_messages(websocket, backend_ws) + realtime_streaming = RealTimeStreaming( + websocket, backend_ws, logging_obj ) - - try: - while True: - message = await websocket.receive_text() - await backend_ws.send(message) - except websockets.exceptions.ConnectionClosed: # type: ignore - forward_task.cancel() - finally: - if not forward_task.done(): - forward_task.cancel() - try: - await forward_task - except asyncio.CancelledError: - pass + await realtime_streaming.bidirectional_forward() except websockets.exceptions.InvalidStatusCode as e: # type: ignore await websocket.close(code=e.status_code, reason=str(e)) diff --git a/litellm/llms/OpenAI/realtime/handler.py b/litellm/llms/OpenAI/realtime/handler.py index 08e5fa0b9..a790b1800 100644 --- a/litellm/llms/OpenAI/realtime/handler.py +++ b/litellm/llms/OpenAI/realtime/handler.py @@ -7,20 +7,11 @@ This requires websockets, and is currently only supported on LiteLLM Proxy. 
import asyncio from typing import Any, Optional +from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging +from ....litellm_core_utils.realtime_streaming import RealTimeStreaming from ..openai import OpenAIChatCompletion -async def forward_messages(client_ws: Any, backend_ws: Any): - import websockets - - try: - while True: - message = await backend_ws.recv() - await client_ws.send_text(message) - except websockets.exceptions.ConnectionClosed: # type: ignore - pass - - class OpenAIRealtime(OpenAIChatCompletion): def _construct_url(self, api_base: str, model: str) -> str: """ @@ -35,6 +26,7 @@ class OpenAIRealtime(OpenAIChatCompletion): self, model: str, websocket: Any, + logging_obj: LiteLLMLogging, api_base: Optional[str] = None, api_key: Optional[str] = None, client: Optional[Any] = None, @@ -57,25 +49,26 @@ class OpenAIRealtime(OpenAIChatCompletion): "OpenAI-Beta": "realtime=v1", }, ) as backend_ws: - forward_task = asyncio.create_task( - forward_messages(websocket, backend_ws) + realtime_streaming = RealTimeStreaming( + websocket, backend_ws, logging_obj ) - - try: - while True: - message = await websocket.receive_text() - await backend_ws.send(message) - except websockets.exceptions.ConnectionClosed: # type: ignore - forward_task.cancel() - finally: - if not forward_task.done(): - forward_task.cancel() - try: - await forward_task - except asyncio.CancelledError: - pass + await realtime_streaming.bidirectional_forward() except websockets.exceptions.InvalidStatusCode as e: # type: ignore await websocket.close(code=e.status_code, reason=str(e)) except Exception as e: - await websocket.close(code=1011, reason=f"Internal server error: {str(e)}") + try: + await websocket.close( + code=1011, reason=f"Internal server error: {str(e)}" + ) + except RuntimeError as close_error: + if "already completed" in str(close_error) or "websocket.close" in str( + close_error + ): + # The WebSocket is already closed or the response is completed, so we can ignore this error + pass + else: + # If it's a different RuntimeError, we might want to log it or handle it differently + raise Exception( + f"Unexpected error while closing WebSocket: {close_error}" + ) diff --git a/litellm/llms/custom_llm.py b/litellm/llms/custom_llm.py index 89798eef5..de09df19c 100644 --- a/litellm/llms/custom_llm.py +++ b/litellm/llms/custom_llm.py @@ -153,6 +153,8 @@ class CustomLLM(BaseLLM): self, model: str, prompt: str, + api_key: Optional[str], + api_base: Optional[str], model_response: ImageResponse, optional_params: dict, logging_obj: Any, @@ -166,6 +168,12 @@ class CustomLLM(BaseLLM): model: str, prompt: str, model_response: ImageResponse, + api_key: Optional[ + str + ], # dynamically set api_key - https://docs.litellm.ai/docs/set_keys#api_key + api_base: Optional[ + str + ], # dynamically set api_base - https://docs.litellm.ai/docs/set_keys#api_base optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py deleted file mode 100644 index 0f98c34c2..000000000 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_anthropic.py +++ /dev/null @@ -1,464 +0,0 @@ -# What is this? 
-## Handler file for calling claude-3 on vertex ai -import copy -import json -import os -import time -import types -import uuid -from enum import Enum -from typing import Any, Callable, List, Optional, Tuple, Union - -import httpx # type: ignore -import requests # type: ignore - -import litellm -from litellm.litellm_core_utils.core_helpers import map_finish_reason -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler -from litellm.types.llms.anthropic import ( - AnthropicMessagesTool, - AnthropicMessagesToolChoice, -) -from litellm.types.llms.openai import ( - ChatCompletionToolParam, - ChatCompletionToolParamFunctionChunk, -) -from litellm.types.utils import ResponseFormatChunk -from litellm.utils import CustomStreamWrapper, ModelResponse, Usage - -from ..prompt_templates.factory import ( - construct_tool_use_system_prompt, - contains_tag, - custom_prompt, - extract_between_tags, - parse_xml_params, - prompt_factory, - response_schema_prompt, -) - - -class VertexAIError(Exception): - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", url=" https://cloud.google.com/vertex-ai/" - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs - - -class VertexAIAnthropicConfig: - """ - Reference:https://docs.anthropic.com/claude/reference/messages_post - - Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways: - - - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL. - - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16". - - The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters: - - - `max_tokens` Required (integer) max tokens, - - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" - - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py - - `temperature` Optional (float) The amount of randomness injected into the response - - `top_p` Optional (float) Use nucleus sampling. - - `top_k` Optional (int) Only sample from the top K options for each subsequent token - - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating - - Note: Please make sure to modify the default parameters as required for your use case. - """ - - max_tokens: Optional[int] = ( - 4096 # anthropic max - setting this doesn't impact response, but is required by anthropic. 
- ) - system: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - top_k: Optional[int] = None - stop_sequences: Optional[List[str]] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - anthropic_version: Optional[str] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key == "max_tokens" and value is None: - value = self.max_tokens - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params(self): - return [ - "max_tokens", - "max_completion_tokens", - "tools", - "tool_choice", - "stream", - "stop", - "temperature", - "top_p", - "response_format", - ] - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - for param, value in non_default_params.items(): - if param == "max_tokens" or param == "max_completion_tokens": - optional_params["max_tokens"] = value - if param == "tools": - optional_params["tools"] = value - if param == "tool_choice": - _tool_choice: Optional[AnthropicMessagesToolChoice] = None - if value == "auto": - _tool_choice = {"type": "auto"} - elif value == "required": - _tool_choice = {"type": "any"} - elif isinstance(value, dict): - _tool_choice = {"type": "tool", "name": value["function"]["name"]} - - if _tool_choice is not None: - optional_params["tool_choice"] = _tool_choice - if param == "stream": - optional_params["stream"] = value - if param == "stop": - optional_params["stop_sequences"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["top_p"] = value - if param == "response_format" and isinstance(value, dict): - json_schema: Optional[dict] = None - if "response_schema" in value: - json_schema = value["response_schema"] - elif "json_schema" in value: - json_schema = value["json_schema"]["schema"] - """ - When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode - - You usually want to provide a single tool - - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool - - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. 
- """ - _tool_choice = None - _tool_choice = {"name": "json_tool_call", "type": "tool"} - - _tool = AnthropicMessagesTool( - name="json_tool_call", - input_schema={ - "type": "object", - "properties": {"values": json_schema}, # type: ignore - }, - ) - - optional_params["tools"] = [_tool] - optional_params["tool_choice"] = _tool_choice - optional_params["json_mode"] = True - - return optional_params - - -""" -- Run client init -- Support async completion, streaming -""" - - -def refresh_auth( - credentials, -) -> str: # used when user passes in credentials as json string - from google.auth.transport.requests import Request # type: ignore[import-untyped] - - if credentials.token is None: - credentials.refresh(Request()) - - if not credentials.token: - raise RuntimeError("Could not resolve API token from the credentials") - - return credentials.token - - -def get_vertex_client( - client: Any, - vertex_project: Optional[str], - vertex_location: Optional[str], - vertex_credentials: Optional[str], -) -> Tuple[Any, Optional[str]]: - args = locals() - from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( - VertexLLM, - ) - - try: - from anthropic import AnthropicVertex - except Exception: - raise VertexAIError( - status_code=400, - message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", - ) - - access_token: Optional[str] = None - - if client is None: - _credentials, cred_project_id = VertexLLM().load_auth( - credentials=vertex_credentials, project_id=vertex_project - ) - - vertex_ai_client = AnthropicVertex( - project_id=vertex_project or cred_project_id, - region=vertex_location or "us-central1", - access_token=_credentials.token, - ) - access_token = _credentials.token - else: - vertex_ai_client = client - access_token = client.access_token - - return vertex_ai_client, access_token - - -def create_vertex_anthropic_url( - vertex_location: str, vertex_project: str, model: str, stream: bool -) -> str: - if stream is True: - return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict" - else: - return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict" - - -def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params: dict, - custom_prompt_dict: dict, - headers: Optional[dict], - timeout: Union[float, httpx.Timeout], - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, - client=None, -): - try: - import vertexai - except Exception: - raise VertexAIError( - status_code=400, - message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", - ) - - from anthropic import AnthropicVertex - - from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion - from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( - VertexLLM, - ) - - if not ( - hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") - ): - raise VertexAIError( - status_code=400, - message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", - ) - try: - - vertex_httpx_logic = VertexLLM() - - access_token, project_id = vertex_httpx_logic._ensure_access_token( - credentials=vertex_credentials, - project_id=vertex_project, - custom_llm_provider="vertex_ai", - ) - - anthropic_chat_completions = AnthropicChatCompletion() - - ## Load Config - config = litellm.VertexAIAnthropicConfig.get_config() - for k, v in config.items(): - if k not in optional_params: - optional_params[k] = v - - ## CONSTRUCT API BASE - stream = optional_params.get("stream", False) - - api_base = create_vertex_anthropic_url( - vertex_location=vertex_location or "us-central1", - vertex_project=vertex_project or project_id, - model=model, - stream=stream, - ) - - if headers is not None: - vertex_headers = headers - else: - vertex_headers = {} - - vertex_headers.update({"Authorization": "Bearer {}".format(access_token)}) - - optional_params.update( - {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True} - ) - - return anthropic_chat_completions.completion( - model=model, - messages=messages, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=access_token, - logging_obj=logging_obj, - optional_params=optional_params, - acompletion=acompletion, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=vertex_headers, - client=client, - timeout=timeout, - ) - - except Exception as e: - raise VertexAIError(status_code=500, message=str(e)) - - -async def async_completion( - model: str, - messages: list, - data: dict, - model_response: ModelResponse, - print_verbose: Callable, - logging_obj, - vertex_project=None, - vertex_location=None, - optional_params=None, - client=None, - access_token=None, -): - from anthropic import AsyncAnthropicVertex - - if client is None: - vertex_ai_client = AsyncAnthropicVertex( - project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore - ) - else: - vertex_ai_client = client - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - }, - ) - message = await vertex_ai_client.messages.create(**data) # type: ignore - text_content = message.content[0].text - ## TOOL CALLING - OUTPUT PARSE - if text_content is not None and contains_tag("invoke", text_content): - function_name = extract_between_tags("tool_name", text_content)[0] - function_arguments_str = extract_between_tags("invoke", text_content)[0].strip() - function_arguments_str = f"{function_arguments_str}" - function_arguments = parse_xml_params(function_arguments_str) - _message = litellm.Message( - tool_calls=[ - { - "id": f"call_{uuid.uuid4()}", - "type": "function", - "function": { - "name": function_name, - "arguments": json.dumps(function_arguments), - }, - } - ], - content=None, - ) - model_response.choices[0].message = _message # type: ignore - else: - model_response.choices[0].message.content = text_content # type: ignore - model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) - - ## CALCULATING USAGE - prompt_tokens = message.usage.input_tokens - completion_tokens = message.usage.output_tokens - - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - return 
model_response - - -async def async_streaming( - model: str, - messages: list, - data: dict, - model_response: ModelResponse, - print_verbose: Callable, - logging_obj, - vertex_project=None, - vertex_location=None, - optional_params=None, - client=None, - access_token=None, -): - from anthropic import AsyncAnthropicVertex - - if client is None: - vertex_ai_client = AsyncAnthropicVertex( - project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore - ) - else: - vertex_ai_client = client - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - }, - ) - response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore - logging_obj.post_call(input=messages, api_key=None, original_response=response) - - streamwrapper = CustomStreamWrapper( - completion_stream=response, - model=model, - custom_llm_provider="vertex_ai", - logging_obj=logging_obj, - ) - - return streamwrapper diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py new file mode 100644 index 000000000..44b8af279 --- /dev/null +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/anthropic/transformation.py @@ -0,0 +1,179 @@ +# What is this? +## Handler file for calling claude-3 on vertex ai +import copy +import json +import os +import time +import types +import uuid +from enum import Enum +from typing import Any, Callable, List, Optional, Tuple, Union + +import httpx # type: ignore +import requests # type: ignore + +import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.llms.anthropic import ( + AnthropicMessagesTool, + AnthropicMessagesToolChoice, +) +from litellm.types.llms.openai import ( + ChatCompletionToolParam, + ChatCompletionToolParamFunctionChunk, +) +from litellm.types.utils import ResponseFormatChunk +from litellm.utils import CustomStreamWrapper, ModelResponse, Usage + +from ....prompt_templates.factory import ( + construct_tool_use_system_prompt, + contains_tag, + custom_prompt, + extract_between_tags, + parse_xml_params, + prompt_factory, + response_schema_prompt, +) + + +class VertexAIError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url=" https://cloud.google.com/vertex-ai/" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class VertexAIAnthropicConfig: + """ + Reference:https://docs.anthropic.com/claude/reference/messages_post + + Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways: + + - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL. + - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16". + + The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters: + + - `max_tokens` Required (integer) max tokens, + - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. 
"bedrock-2023-05-31" + - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py + - `temperature` Optional (float) The amount of randomness injected into the response + - `top_p` Optional (float) Use nucleus sampling. + - `top_k` Optional (int) Only sample from the top K options for each subsequent token + - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating + + Note: Please make sure to modify the default parameters as required for your use case. + """ + + max_tokens: Optional[int] = ( + 4096 # anthropic max - setting this doesn't impact response, but is required by anthropic. + ) + system: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + stop_sequences: Optional[List[str]] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + anthropic_version: Optional[str] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key == "max_tokens" and value is None: + value = self.max_tokens + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "max_tokens", + "max_completion_tokens", + "tools", + "tool_choice", + "stream", + "stop", + "temperature", + "top_p", + "response_format", + ] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens" or param == "max_completion_tokens": + optional_params["max_tokens"] = value + if param == "tools": + optional_params["tools"] = value + if param == "tool_choice": + _tool_choice: Optional[AnthropicMessagesToolChoice] = None + if value == "auto": + _tool_choice = {"type": "auto"} + elif value == "required": + _tool_choice = {"type": "any"} + elif isinstance(value, dict): + _tool_choice = {"type": "tool", "name": value["function"]["name"]} + + if _tool_choice is not None: + optional_params["tool_choice"] = _tool_choice + if param == "stream": + optional_params["stream"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "response_format" and isinstance(value, dict): + json_schema: Optional[dict] = None + if "response_schema" in value: + json_schema = value["response_schema"] + elif "json_schema" in value: + json_schema = value["json_schema"]["schema"] + """ + When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode + - You usually want to provide a single tool + - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool + - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. 
+ """ + _tool_choice = None + _tool_choice = {"name": "json_tool_call", "type": "tool"} + + _tool = AnthropicMessagesTool( + name="json_tool_call", + input_schema={ + "type": "object", + "properties": {"values": json_schema}, # type: ignore + }, + ) + + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = _tool_choice + optional_params["json_mode"] = True + + return optional_params diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py index da54f6e1b..e8443e6f6 100644 --- a/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py +++ b/litellm/llms/vertex_ai_and_google_ai_studio/vertex_ai_partner_models/main.py @@ -9,13 +9,14 @@ import httpx # type: ignore import litellm from litellm.utils import ModelResponse -from ...base import BaseLLM +from ..vertex_llm_base import VertexBase class VertexPartnerProvider(str, Enum): mistralai = "mistralai" llama = "llama" ai21 = "ai21" + claude = "claude" class VertexAIError(Exception): @@ -31,31 +32,38 @@ class VertexAIError(Exception): ) # Call the base class constructor with the parameters it needs -class VertexAIPartnerModels(BaseLLM): +def create_vertex_url( + vertex_location: str, + vertex_project: str, + partner: VertexPartnerProvider, + stream: Optional[bool], + model: str, + api_base: Optional[str] = None, +) -> str: + """Return the base url for the vertex partner models""" + if partner == VertexPartnerProvider.llama: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi" + elif partner == VertexPartnerProvider.mistralai: + if stream: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict" + else: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict" + elif partner == VertexPartnerProvider.ai21: + if stream: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict" + else: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict" + elif partner == VertexPartnerProvider.claude: + if stream: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict" + else: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict" + + +class VertexAIPartnerModels(VertexBase): def __init__(self) -> None: pass - def create_vertex_url( - self, - vertex_location: str, - vertex_project: str, - partner: VertexPartnerProvider, - stream: Optional[bool], - model: str, - ) -> str: - if partner == VertexPartnerProvider.llama: - return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi" - elif partner == VertexPartnerProvider.mistralai: - if stream: - return 
f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict" - else: - return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict" - elif partner == VertexPartnerProvider.ai21: - if stream: - return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict" - else: - return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict" - def completion( self, model: str, @@ -64,6 +72,7 @@ class VertexAIPartnerModels(BaseLLM): print_verbose: Callable, encoding, logging_obj, + api_base: Optional[str], optional_params: dict, custom_prompt_dict: dict, headers: Optional[dict], @@ -80,6 +89,7 @@ class VertexAIPartnerModels(BaseLLM): import vertexai from google.cloud import aiplatform + from litellm.llms.anthropic.chat import AnthropicChatCompletion from litellm.llms.databricks.chat import DatabricksChatCompletion from litellm.llms.OpenAI.openai import OpenAIChatCompletion from litellm.llms.text_completion_codestral import CodestralTextCompletion @@ -112,6 +122,7 @@ class VertexAIPartnerModels(BaseLLM): openai_like_chat_completions = DatabricksChatCompletion() codestral_fim_completions = CodestralTextCompletion() + anthropic_chat_completions = AnthropicChatCompletion() ## CONSTRUCT API BASE stream: bool = optional_params.get("stream", False) or False @@ -126,8 +137,10 @@ class VertexAIPartnerModels(BaseLLM): elif "jamba" in model: partner = VertexPartnerProvider.ai21 optional_params["custom_endpoint"] = True + elif "claude" in model: + partner = VertexPartnerProvider.claude - api_base = self.create_vertex_url( + default_api_base = create_vertex_url( vertex_location=vertex_location or "us-central1", vertex_project=vertex_project or project_id, partner=partner, # type: ignore @@ -135,6 +148,21 @@ class VertexAIPartnerModels(BaseLLM): model=model, ) + if len(default_api_base.split(":")) > 1: + endpoint = default_api_base.split(":")[-1] + else: + endpoint = "" + + _, api_base = self._check_custom_proxy( + api_base=api_base, + custom_llm_provider="vertex_ai", + gemini_api_key=None, + endpoint=endpoint, + stream=stream, + auth_header=None, + url=default_api_base, + ) + model = model.split("@")[0] if "codestral" in model and litellm_params.get("text_completion") is True: @@ -158,6 +186,35 @@ class VertexAIPartnerModels(BaseLLM): timeout=timeout, encoding=encoding, ) + elif "claude" in model: + if headers is None: + headers = {} + headers.update({"Authorization": "Bearer {}".format(access_token)}) + + optional_params.update( + { + "anthropic_version": "vertex-2023-10-16", + "is_vertex_request": True, + } + ) + return anthropic_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=access_token, + logging_obj=logging_obj, + headers=headers, + timeout=timeout, + client=client, + ) return openai_like_chat_completions.completion( model=model, diff --git a/litellm/main.py 
b/litellm/main.py index b53db67f4..e5a22b3f5 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -117,10 +117,7 @@ from .llms.sagemaker.sagemaker import SagemakerLLM from .llms.text_completion_codestral import CodestralTextCompletion from .llms.together_ai.completion.handler import TogetherAITextCompletion from .llms.triton import TritonChatCompletion -from .llms.vertex_ai_and_google_ai_studio import ( - vertex_ai_anthropic, - vertex_ai_non_gemini, -) +from .llms.vertex_ai_and_google_ai_studio import vertex_ai_non_gemini from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( VertexLLM, ) @@ -747,6 +744,11 @@ def completion( # type: ignore proxy_server_request = kwargs.get("proxy_server_request", None) fallbacks = kwargs.get("fallbacks", None) headers = kwargs.get("headers", None) or extra_headers + if headers is None: + headers = {} + + if extra_headers is not None: + headers.update(extra_headers) num_retries = kwargs.get( "num_retries", None ) ## alt. param for 'max_retries'. Use this to pass retries w/ instructor. @@ -964,7 +966,6 @@ def completion( # type: ignore max_retries=max_retries, logprobs=logprobs, top_logprobs=top_logprobs, - extra_headers=extra_headers, api_version=api_version, parallel_tool_calls=parallel_tool_calls, messages=messages, @@ -1067,6 +1068,9 @@ def completion( # type: ignore headers = headers or litellm.headers + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + if ( litellm.enable_preview_features and litellm.AzureOpenAIO1Config().is_o1_model(model=model) @@ -1166,6 +1170,9 @@ def completion( # type: ignore headers = headers or litellm.headers + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + ## LOAD CONFIG - if set config = litellm.AzureOpenAIConfig.get_config() for k, v in config.items(): @@ -1223,6 +1230,9 @@ def completion( # type: ignore headers = headers or litellm.headers + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + ## LOAD CONFIG - if set config = litellm.AzureAIStudioConfig.get_config() for k, v in config.items(): @@ -1304,6 +1314,9 @@ def completion( # type: ignore headers = headers or litellm.headers + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + ## LOAD CONFIG - if set config = litellm.OpenAITextCompletionConfig.get_config() for k, v in config.items(): @@ -1466,6 +1479,9 @@ def completion( # type: ignore headers = headers or litellm.headers + if extra_headers is not None: + optional_params["extra_headers"] = extra_headers + ## LOAD CONFIG - if set config = litellm.OpenAIConfig.get_config() for k, v in config.items(): @@ -2228,31 +2244,12 @@ def completion( # type: ignore ) new_params = deepcopy(optional_params) - if "claude-3" in model: - model_response = vertex_ai_anthropic.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - headers=headers, - custom_prompt_dict=custom_prompt_dict, - timeout=timeout, - client=client, - ) - elif ( + if ( model.startswith("meta/") or model.startswith("mistral") or model.startswith("codestral") or model.startswith("jamba") + or model.startswith("claude") ): model_response = 
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 536976ce4..dda23b7b4 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,12 +1,15 @@
 model_list:
-  - model_name: gpt-4o-mini
-    litellm_params:
-      model: azure/my-gpt-4o-mini
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+  # - model_name: openai-gpt-4o-realtime-audio
+  #   litellm_params:
+  #     model: azure/gpt-4o-realtime-preview
+  #     api_key: os.environ/AZURE_SWEDEN_API_KEY
+  #     api_base: os.environ/AZURE_SWEDEN_API_BASE
+
+  - model_name: openai-gpt-4o-realtime-audio
+    litellm_params:
+      model: openai/gpt-4o-realtime-preview-2024-10-01
+      api_key: os.environ/OPENAI_API_KEY
+      api_base: http://localhost:8080

 litellm_settings:
-  turn_off_message_logging: true
-  cache: True
-  cache_params:
-    type: local
\ No newline at end of file
+  success_callback: ["langfuse"]
\ No newline at end of file
diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py
index 59a4edb90..538a1eee1 100644
--- a/litellm/proxy/common_utils/callback_utils.py
+++ b/litellm/proxy/common_utils/callback_utils.py
@@ -241,6 +241,8 @@ def initialize_callbacks_on_proxy(
             litellm.callbacks = imported_list  # type: ignore

         if "prometheus" in value:
+            if premium_user is not True:
+                raise Exception(CommonProxyErrors.not_premium_user.value)
             from litellm.proxy.proxy_server import app

             verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index d611aa87b..11ccc8561 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -1,16 +1,17 @@
 model_list:
   - model_name: db-openai-endpoint
     litellm_params:
-      model: openai/gpt-5
+      model: openai/gpt-4
       api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
-  - model_name: db-openai-endpoint
-    litellm_params:
-      model: openai/gpt-5
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/

 litellm_settings:
-  callbacks: ["prometheus"]
+  success_callback: ["s3"]
+  turn_off_message_logging: true
+  s3_callback_params:
+    s3_bucket_name: load-testing-oct # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/ to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index d75a087da..8407b4e86 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -1692,6 +1692,10 @@ class ProxyConfig:
                         else:
                             litellm.success_callback.append(callback)
                             if "prometheus" in callback:
+                                if not premium_user:
+                                    raise Exception(
+                                        CommonProxyErrors.not_premium_user.value
+                                    )
                                 verbose_proxy_logger.debug(
                                     "Starting Prometheus Metrics on /metrics"
                                 )
@@ -5433,12 +5437,11 @@ async def moderations(
         await proxy_logging_obj.post_call_failure_hook(
             user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
         )
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
                 str(e)
             )
         )
-        verbose_proxy_logger.debug(traceback.format_exc())
         if isinstance(e, HTTPException):
             raise ProxyException(
                 message=getattr(e, "message", str(e)),
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 666118039..25de59c8e 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -408,7 +408,6 @@ class ProxyLogging:
                 callback,
                 internal_usage_cache=self.internal_usage_cache.dual_cache,
                 llm_router=llm_router,
-                premium_user=self.premium_user,
             )
             if callback is None:
                 continue
""" + litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore + litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + model_info = kwargs.get("model_info", None) + metadata = kwargs.get("metadata", {}) + user = kwargs.get("user", None) litellm_params = GenericLiteLLMParams(**kwargs) model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider( @@ -39,6 +48,21 @@ async def _arealtime( api_key=api_key, ) + litellm_logging_obj.update_environment_variables( + model=model, + user=user, + optional_params={}, + litellm_params={ + "litellm_call_id": litellm_call_id, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "preset_cache_key": None, + "stream_response": {}, + }, + custom_llm_provider=_custom_llm_provider, + ) + if _custom_llm_provider == "azure": api_base = ( dynamic_api_base @@ -63,6 +87,7 @@ async def _arealtime( azure_ad_token=None, client=None, timeout=timeout, + logging_obj=litellm_logging_obj, ) elif _custom_llm_provider == "openai": api_base = ( @@ -82,6 +107,7 @@ async def _arealtime( await openai_realtime.async_realtime( model=model, websocket=websocket, + logging_obj=litellm_logging_obj, api_base=api_base, api_key=api_key, client=None, diff --git a/litellm/router_utils/pattern_match_deployments.py b/litellm/router_utils/pattern_match_deployments.py index f735c965a..6c4ec8e6b 100644 --- a/litellm/router_utils/pattern_match_deployments.py +++ b/litellm/router_utils/pattern_match_deployments.py @@ -60,7 +60,7 @@ class PatternMatchRouter: regex = re.escape(regex).replace(r"\.\*", ".*") return f"^{regex}$" - def route(self, request: str) -> Optional[List[Dict]]: + def route(self, request: Optional[str]) -> Optional[List[Dict]]: """ Route a requested model to the corresponding llm deployments based on the regex pattern @@ -69,17 +69,20 @@ class PatternMatchRouter: if no pattern is found, return None Args: - request: str + request: Optional[str] Returns: Optional[List[Deployment]]: llm deployments """ try: + if request is None: + return None for pattern, llm_deployments in self.patterns.items(): if re.match(pattern, request): return llm_deployments except Exception as e: - verbose_router_logger.error(f"Error in PatternMatchRouter.route: {str(e)}") + verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}") + return None # No matching pattern found diff --git a/litellm/types/integrations/s3.py b/litellm/types/integrations/s3.py new file mode 100644 index 000000000..d66e2c59d --- /dev/null +++ b/litellm/types/integrations/s3.py @@ -0,0 +1,14 @@ +from typing import Dict + +from pydantic import BaseModel + + +class s3BatchLoggingElement(BaseModel): + """ + Type of element stored in self.log_queue in S3Logger + + """ + + payload: Dict + s3_object_key: str + s3_object_download_filename: str diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 6019b0e6f..c3118b453 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -129,6 +129,7 @@ class CallTypes(Enum): speech = "speech" rerank = "rerank" arerank = "arerank" + arealtime = "_arealtime" class PassthroughCallTypes(Enum): diff --git a/litellm/utils.py b/litellm/utils.py index 5afeab58e..9efde1be7 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -197,7 +197,6 @@ lagoLogger = None dataDogLogger = None prometheusLogger = None dynamoLogger = None -s3Logger = None genericAPILogger = None clickHouseLogger = None 
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 6019b0e6f..c3118b453 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -129,6 +129,7 @@ class CallTypes(Enum):
     speech = "speech"
     rerank = "rerank"
     arerank = "arerank"
+    arealtime = "_arealtime"


 class PassthroughCallTypes(Enum):
diff --git a/litellm/utils.py b/litellm/utils.py
index 5afeab58e..9efde1be7 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -197,7 +197,6 @@ lagoLogger = None
 dataDogLogger = None
 prometheusLogger = None
 dynamoLogger = None
-s3Logger = None
 genericAPILogger = None
 clickHouseLogger = None
 greenscaleLogger = None
@@ -1410,6 +1409,8 @@ def client(original_function):
                     )
                 else:
                     return result
+            elif call_type == CallTypes.arealtime.value:
+                return result

             # ADD HIDDEN PARAMS - additional call metadata
             if hasattr(result, "_hidden_params"):
@@ -1799,8 +1800,9 @@ def calculate_tiles_needed(
     total_tiles = tiles_across * tiles_down
     return total_tiles

+
 def get_image_type(image_data: bytes) -> Union[str, None]:
-    """ take an image (really only the first ~100 bytes max are needed)
+    """take an image (really only the first ~100 bytes max are needed)
     and return 'png' 'gif' 'jpeg' 'heic' or None. method added to allow deprecation of imghdr in 3.13"""

@@ -4336,16 +4338,18 @@ def get_api_base(
             _optional_params.vertex_location is not None
             and _optional_params.vertex_project is not None
         ):
-            from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
-                create_vertex_anthropic_url,
+            from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
+                VertexPartnerProvider,
+                create_vertex_url,
             )

             if "claude" in model:
-                _api_base = create_vertex_anthropic_url(
+                _api_base = create_vertex_url(
                     vertex_location=_optional_params.vertex_location,
                     vertex_project=_optional_params.vertex_project,
                     model=model,
                     stream=stream,
+                    partner=VertexPartnerProvider.claude,
                 )
             else:
@@ -4442,19 +4446,7 @@ def get_supported_openai_params(
     elif custom_llm_provider == "volcengine":
         return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "groq":
-        return [
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stream",
-            "stop",
-            "tools",
-            "tool_choice",
-            "response_format",
-            "seed",
-            "extra_headers",
-            "extra_body",
-        ]
+        return litellm.GroqChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "hosted_vllm":
         return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "deepseek":
@@ -4599,6 +4591,8 @@ def get_supported_openai_params(
             return (
                 litellm.MistralTextCompletionConfig().get_supported_openai_params()
             )
+        if model.startswith("claude"):
+            return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
         return litellm.VertexAIConfig().get_supported_openai_params()
     elif request_type == "embeddings":
         return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
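The `get_supported_openai_params` changes above delegate Groq to `GroqChatConfig` and route Vertex AI `claude` models to `VertexAIAnthropicConfig`. A quick way to inspect the result, assuming the helper remains importable from the top-level `litellm` package; the model ids are illustrative:

```python
from litellm import get_supported_openai_params

# Groq params now come from GroqChatConfig rather than a hard-coded list.
print(get_supported_openai_params(model="llama3-8b-8192", custom_llm_provider="groq"))

# Vertex AI Claude models fall through to the Anthropic config.
print(
    get_supported_openai_params(
        model="claude-3-sonnet@20240229", custom_llm_provider="vertex_ai"
    )
)
```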
diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py
index 9ec148e47..3ce57380c 100644
--- a/tests/llm_translation/test_max_completion_tokens.py
+++ b/tests/llm_translation/test_max_completion_tokens.py
@@ -296,7 +296,7 @@ def test_all_model_configs():
         optional_params={},
     ) == {"max_tokens": 10}

-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
         VertexAIAnthropicConfig,
     )
diff --git a/tests/local_testing/test_amazing_s3_logs.py b/tests/local_testing/test_amazing_s3_logs.py
index c3e8a61db..5459647c1 100644
--- a/tests/local_testing/test_amazing_s3_logs.py
+++ b/tests/local_testing/test_amazing_s3_logs.py
@@ -12,7 +12,70 @@ import litellm
 litellm.num_retries = 3
 import time, random
+from litellm._logging import verbose_logger
+import logging
 import pytest
+import boto3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_s3_logging(sync_mode):
+    verbose_logger.setLevel(level=logging.DEBUG)
+    litellm.success_callback = ["s3"]
+    litellm.s3_callback_params = {
+        "s3_bucket_name": "load-testing-oct",
+        "s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
+        "s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
+        "s3_region_name": "us-west-2",
+    }
+    litellm.set_verbose = True
+
+    if sync_mode is True:
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "This is a test"}],
+            mock_response="It's simple to use and easy to get started",
+        )
+    else:
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "This is a test"}],
+            mock_response="It's simple to use and easy to get started",
+        )
+    print(f"response: {response}")
+
+    await asyncio.sleep(12)
+
+    total_objects, all_s3_keys = list_all_s3_objects("load-testing-oct")
+
+    # assert that atlest one key has response.id in it
+    assert any(response.id in key for key in all_s3_keys)
+    s3 = boto3.client("s3")
+    # delete all objects
+    for key in all_s3_keys:
+        s3.delete_object(Bucket="load-testing-oct", Key=key)
+
+
+def list_all_s3_objects(bucket_name):
+    s3 = boto3.client("s3")
+
+    all_s3_keys = []
+
+    paginator = s3.get_paginator("list_objects_v2")
+    total_objects = 0
+
+    for page in paginator.paginate(Bucket=bucket_name):
+        if "Contents" in page:
+            total_objects += len(page["Contents"])
+            all_s3_keys.extend([obj["Key"] for obj in page["Contents"]])
+
+    print(f"Total number of objects in {bucket_name}: {total_objects}")
+    print(all_s3_keys)
+    return total_objects, all_s3_keys
+
+
+list_all_s3_objects("load-testing-oct")


 @pytest.mark.skip(reason="AWS Suspended Account")
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 604c2eed8..d8c448976 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -1616,9 +1616,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
     )


-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize(
+    "model", ["gemini-1.5-flash", "claude-3-sonnet@20240229"]
+)  # "vertex_ai",
 @pytest.mark.asyncio
-async def test_gemini_pro_httpx_custom_api_base(provider):
+async def test_gemini_pro_httpx_custom_api_base(model):
     load_vertex_ai_credentials()
     litellm.set_verbose = True
     messages = [
@@ -1634,7 +1636,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
     with patch.object(client, "post", new=MagicMock()) as mock_call:
         try:
             response = completion(
-                model="vertex_ai_beta/gemini-1.5-flash",
+                model="vertex_ai/{}".format(model),
                 messages=messages,
                 response_format={"type": "json_object"},
                 client=client,
@@ -1647,8 +1649,17 @@ async def test_gemini_pro_httpx_custom_api_base(provider):

         mock_call.assert_called_once()

-        assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
-        assert "hello" in mock_call.call_args.kwargs["headers"]
+        print(f"mock_call.call_args: {mock_call.call_args}")
+        print(f"mock_call.call_args.kwargs: {mock_call.call_args.kwargs}")
+        if "url" in mock_call.call_args.kwargs:
+            assert (
+                "my-custom-api-base:generateContent"
+                == mock_call.call_args.kwargs["url"]
+            )
+        else:
+            assert "my-custom-api-base:rawPredict" == mock_call.call_args[0][0]
+        if "headers" in mock_call.call_args.kwargs:
+            assert "hello" in mock_call.call_args.kwargs["headers"]


 # @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
need to refactor to mock the call") diff --git a/tests/local_testing/test_custom_llm.py b/tests/local_testing/test_custom_llm.py index c9edde4a8..29daef481 100644 --- a/tests/local_testing/test_custom_llm.py +++ b/tests/local_testing/test_custom_llm.py @@ -28,7 +28,6 @@ from typing import ( Union, ) from unittest.mock import AsyncMock, MagicMock, patch - import httpx from dotenv import load_dotenv @@ -226,6 +225,8 @@ class MyCustomLLM(CustomLLM): self, model: str, prompt: str, + api_key: Optional[str], + api_base: Optional[str], model_response: ImageResponse, optional_params: dict, logging_obj: Any, @@ -242,6 +243,8 @@ class MyCustomLLM(CustomLLM): self, model: str, prompt: str, + api_key: Optional[str], + api_base: Optional[str], model_response: ImageResponse, optional_params: dict, logging_obj: Any, @@ -362,3 +365,31 @@ async def test_simple_image_generation_async(): ) print(resp) + + +@pytest.mark.asyncio +async def test_image_generation_async_with_api_key_and_api_base(): + my_custom_llm = MyCustomLLM() + litellm.custom_provider_map = [ + {"provider": "custom_llm", "custom_handler": my_custom_llm} + ] + + with patch.object( + my_custom_llm, "aimage_generation", new=AsyncMock() + ) as mock_client: + try: + resp = await litellm.aimage_generation( + model="custom_llm/my-fake-model", + prompt="Hello world", + api_key="my-api-key", + api_base="my-api-base", + ) + + print(resp) + except Exception as e: + print(e) + + mock_client.assert_awaited_once() + + mock_client.call_args.kwargs["api_key"] == "my-api-key" + mock_client.call_args.kwargs["api_base"] == "my-api-base"