LiteLLM Minor Fixes & Improvements (10/10/2024) (#6158)

* refactor(vertex_ai_partner_models/anthropic): refactor anthropic to use partner model logic

* fix(vertex_ai/): support passing custom api base to partner models

Fixes https://github.com/BerriAI/litellm/issues/4317

* fix(proxy_server.py): Fix prometheus premium user check logic

* docs(prometheus.md): update quick start docs

* fix(custom_llm.py): support passing dynamic api key + api base

* fix(realtime_api/main.py): Add request/response logging for realtime api endpoints

Closes https://github.com/BerriAI/litellm/issues/6081

* feat(openai/realtime): add openai realtime api logging

Closes https://github.com/BerriAI/litellm/issues/6081

* fix(realtime_streaming.py): fix linting errors

* fix(realtime_streaming.py): fix linting errors

* fix: fix linting errors

* fix pattern match router

* Add literalai in the sidebar observability category (#6163)

* fix: add literalai in the sidebar

* fix: typo

* update (#6160)

* Feat: Add Langtrace integration (#5341)

* Feat: Add Langtrace integration

* add langtrace service name

* fix timestamps for traces

* add tests

* Discard Callback + use existing otel logger

* cleanup

* remove print statments

* remove callback

* add docs

* docs

* add logging docs

* format logging

* remove emoji and add litellm proxy example

* format logging

* format `logging.md`

* add langtrace docs to logging.md

* sync conflict

* docs fix

* (perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165)

* fix move s3 to use customLogger

* add basic s3 logging test

* add s3 to custom logger compatible

* use batch logger for s3

* s3 set flush interval and batch size

* fix s3 logging

* add notes on s3 logging

* fix s3 logging

* add basic s3 logging test

* fix s3 type errors

* add test for sync logging on s3

* fix: fix to debug log

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
Co-authored-by: Ali Waleed <ali@scale3labs.com>
This commit is contained in:
Krish Dholakia 2024-10-11 23:04:36 -07:00 committed by GitHub
parent 9db4ccca9f
commit 11f9df923a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 966 additions and 760 deletions

View file

@ -27,8 +27,7 @@ model_list:
litellm_params:
model: gpt-3.5-turbo
litellm_settings:
success_callback: ["prometheus"]
failure_callback: ["prometheus"]
callbacks: ["prometheus"]
```
Start the proxy

View file

@ -53,6 +53,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
"arize",
"langtrace",
"gcs_bucket",
"s3",
"opik",
]
_known_custom_logger_compatible_callbacks: List = list(
@ -931,7 +932,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation impor
vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
VertexAIAnthropicConfig,
)
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (

View file

@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
self,
flush_lock: Optional[asyncio.Lock] = None,
batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
**kwargs,
) -> None:
"""
@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
"""
self.log_queue: List = []
self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds
self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
self.last_flush_time = time.time()
self.flush_lock = flush_lock

View file

@ -235,6 +235,14 @@ class LangFuseLogger:
):
input = prompt
output = response_obj.results
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "_arealtime"
and response_obj is not None
and isinstance(response_obj, list)
):
input = kwargs.get("input")
output = response_obj
elif (
kwargs.get("call_type") is not None
and kwargs.get("call_type") == "pass_through_endpoint"

View file

@ -1,43 +1,67 @@
#### What this does ####
# On success + failure, log events to Supabase
"""
s3 Bucket Logging Integration
import datetime
import os
import subprocess
import sys
import traceback
import uuid
from typing import Optional
async_log_success_event: Processes the event, stores it in memory for 10 seconds or until MAX_BATCH_SIZE and then flushes to s3
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
NOTE 2: We create a httpx client with a concurrent limit of 1 to upload to s3. Files should get uploaded BUT they should not impact latency of LLM calling logic
"""
import asyncio
import json
from datetime import datetime
from typing import Dict, List, Optional
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.llms.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.integrations.s3 import s3BatchLoggingElement
from litellm.types.utils import StandardLoggingPayload
from .custom_batch_logger import CustomBatchLogger
class S3Logger:
# Default Flush interval and batch size for s3
# Flush to s3 every 10 seconds OR every 1K requests in memory
DEFAULT_S3_FLUSH_INTERVAL_SECONDS = 10
DEFAULT_S3_BATCH_SIZE = 1000
class S3Logger(CustomBatchLogger, BaseAWSLLM):
# Class variables or attributes
def __init__(
self,
s3_bucket_name=None,
s3_path=None,
s3_region_name=None,
s3_api_version=None,
s3_use_ssl=True,
s3_verify=None,
s3_endpoint_url=None,
s3_aws_access_key_id=None,
s3_aws_secret_access_key=None,
s3_aws_session_token=None,
s3_bucket_name: Optional[str] = None,
s3_path: Optional[str] = None,
s3_region_name: Optional[str] = None,
s3_api_version: Optional[str] = None,
s3_use_ssl: bool = True,
s3_verify: Optional[bool] = None,
s3_endpoint_url: Optional[str] = None,
s3_aws_access_key_id: Optional[str] = None,
s3_aws_secret_access_key: Optional[str] = None,
s3_aws_session_token: Optional[str] = None,
s3_flush_interval: Optional[int] = DEFAULT_S3_FLUSH_INTERVAL_SECONDS,
s3_batch_size: Optional[int] = DEFAULT_S3_BATCH_SIZE,
s3_config=None,
**kwargs,
):
import boto3
try:
verbose_logger.debug(
f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
)
# IMPORTANT: We use a concurrent limit of 1 to upload to s3
# Files should get uploaded BUT they should not impact latency of LLM calling logic
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback,
params={"concurrent_limit": 1},
)
if litellm.s3_callback_params is not None:
# read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items():
@ -63,107 +87,282 @@ class S3Logger:
s3_path = litellm.s3_callback_params.get("s3_path")
# done reading litellm.s3_callback_params
s3_flush_interval = litellm.s3_callback_params.get(
"s3_flush_interval", DEFAULT_S3_FLUSH_INTERVAL_SECONDS
)
s3_batch_size = litellm.s3_callback_params.get(
"s3_batch_size", DEFAULT_S3_BATCH_SIZE
)
self.bucket_name = s3_bucket_name
self.s3_path = s3_path
verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
# Create an S3 client with custom endpoint URL
self.s3_client = boto3.client(
"s3",
region_name=s3_region_name,
endpoint_url=s3_endpoint_url,
api_version=s3_api_version,
use_ssl=s3_use_ssl,
verify=s3_verify,
aws_access_key_id=s3_aws_access_key_id,
aws_secret_access_key=s3_aws_secret_access_key,
aws_session_token=s3_aws_session_token,
config=s3_config,
**kwargs,
self.s3_bucket_name = s3_bucket_name
self.s3_region_name = s3_region_name
self.s3_api_version = s3_api_version
self.s3_use_ssl = s3_use_ssl
self.s3_verify = s3_verify
self.s3_endpoint_url = s3_endpoint_url
self.s3_aws_access_key_id = s3_aws_access_key_id
self.s3_aws_secret_access_key = s3_aws_secret_access_key
self.s3_aws_session_token = s3_aws_session_token
self.s3_config = s3_config
self.init_kwargs = kwargs
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
verbose_logger.debug(
f"s3 flush interval: {s3_flush_interval}, s3 batch size: {s3_batch_size}"
)
# Call CustomLogger's __init__
CustomBatchLogger.__init__(
self,
flush_lock=self.flush_lock,
flush_interval=s3_flush_interval,
batch_size=s3_batch_size,
)
self.log_queue: List[s3BatchLoggingElement] = []
# Call BaseAWSLLM's __init__
BaseAWSLLM.__init__(self)
except Exception as e:
print_verbose(f"Got exception on init s3 client {str(e)}")
raise e
async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, print_verbose
):
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
verbose_logger.debug(
f"s3 Logging - Enters logging function for model {kwargs}"
)
# construct payload to send to s3
# follows the same params as langfuse.py
litellm_params = kwargs.get("litellm_params", {})
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
# Clean Metadata before logging - never log raw metadata
# the raw metadata can contain circular references which leads to infinite recursion
# we clean out all extra litellm metadata params before logging
clean_metadata = {}
if isinstance(metadata, dict):
for key, value in metadata.items():
# clean litellm metadata before logging
if key in [
"headers",
"endpoint",
"caching_groups",
"previous_models",
]:
continue
else:
clean_metadata[key] = value
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
s3_batch_logging_element = self.create_s3_batch_logging_element(
start_time=start_time,
standard_logging_payload=kwargs.get("standard_logging_object", None),
s3_path=self.s3_path,
)
if payload is None:
return
if s3_batch_logging_element is None:
raise ValueError("s3_batch_logging_element is None")
s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
s3_object_key = (
(self.s3_path.rstrip("/") + "/" if self.s3_path else "")
+ start_time.strftime("%Y-%m-%d")
+ "/"
+ s3_file_name
) # we need the s3 key to include the time, so we log cache hits too
s3_object_key += ".json"
s3_object_download_filename = (
"time-"
+ start_time.strftime("%Y-%m-%dT%H-%M-%S-%f")
+ "_"
+ payload["id"]
+ ".json"
verbose_logger.debug(
"\ns3 Logger - Logging payload = %s", s3_batch_logging_element
)
import json
payload_str = json.dumps(payload)
print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
response = self.s3_client.put_object(
Bucket=self.bucket_name,
Key=s3_object_key,
Body=payload_str,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{s3_object_download_filename}"',
CacheControl="private, immutable, max-age=31536000, s-maxage=0",
self.log_queue.append(s3_batch_logging_element)
verbose_logger.debug(
"s3 logging: queue length %s, batch size %s",
len(self.log_queue),
self.batch_size,
)
print_verbose(f"Response from s3:{str(response)}")
print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
return response
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except Exception as e:
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
pass
def log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Synchronous logging function to log to s3
Does not batch logging requests, instantly logs on s3 Bucket
"""
try:
s3_batch_logging_element = self.create_s3_batch_logging_element(
start_time=start_time,
standard_logging_payload=kwargs.get("standard_logging_object", None),
s3_path=self.s3_path,
)
if s3_batch_logging_element is None:
raise ValueError("s3_batch_logging_element is None")
verbose_logger.debug(
"\ns3 Logger - Logging payload = %s", s3_batch_logging_element
)
# log the element sync httpx client
self.upload_data_to_s3(s3_batch_logging_element)
except Exception as e:
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
pass
async def async_upload_data_to_s3(
self, batch_logging_element: s3BatchLoggingElement
):
try:
import hashlib
import boto3
import requests
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
try:
credentials: Credentials = self.get_credentials(
aws_access_key_id=self.s3_aws_access_key_id,
aws_secret_access_key=self.s3_aws_secret_access_key,
aws_session_token=self.s3_aws_session_token,
aws_region_name=self.s3_region_name,
)
# Prepare the URL
url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
if self.s3_endpoint_url:
url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
# Convert JSON to string
json_string = json.dumps(batch_logging_element.payload)
# Calculate SHA256 hash of the content
content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
# Prepare the request
headers = {
"Content-Type": "application/json",
"x-amz-content-sha256": content_hash,
"Content-Language": "en",
"Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
"Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
}
req = requests.Request("PUT", url, data=json_string, headers=headers)
prepped = req.prepare()
# Sign the request
aws_request = AWSRequest(
method=prepped.method,
url=prepped.url,
data=prepped.body,
headers=prepped.headers,
)
SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
# Prepare the signed headers
signed_headers = dict(aws_request.headers.items())
# Make the request
response = await self.async_httpx_client.put(
url, data=json_string, headers=signed_headers
)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error uploading to s3: {str(e)}")
async def async_send_batch(self):
"""
Sends runs from self.log_queue
Returns: None
Raises: Does not raise an exception, will only verbose_logger.exception()
"""
if not self.log_queue:
return
for payload in self.log_queue:
asyncio.create_task(self.async_upload_data_to_s3(payload))
def create_s3_batch_logging_element(
self,
start_time: datetime,
standard_logging_payload: Optional[StandardLoggingPayload],
s3_path: Optional[str],
) -> Optional[s3BatchLoggingElement]:
"""
Helper function to create an s3BatchLoggingElement.
Args:
start_time (datetime): The start time of the logging event.
standard_logging_payload (Optional[StandardLoggingPayload]): The payload to be logged.
s3_path (Optional[str]): The S3 path prefix.
Returns:
Optional[s3BatchLoggingElement]: The created s3BatchLoggingElement, or None if payload is None.
"""
if standard_logging_payload is None:
return None
s3_file_name = (
litellm.utils.get_logging_id(start_time, standard_logging_payload) or ""
)
s3_object_key = (
(s3_path.rstrip("/") + "/" if s3_path else "")
+ start_time.strftime("%Y-%m-%d")
+ "/"
+ s3_file_name
+ ".json"
)
s3_object_download_filename = f"time-{start_time.strftime('%Y-%m-%dT%H-%M-%S-%f')}_{standard_logging_payload['id']}.json"
return s3BatchLoggingElement(
payload=standard_logging_payload, # type: ignore
s3_object_key=s3_object_key,
s3_object_download_filename=s3_object_download_filename,
)
def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement):
try:
import hashlib
import boto3
import requests
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
try:
credentials: Credentials = self.get_credentials(
aws_access_key_id=self.s3_aws_access_key_id,
aws_secret_access_key=self.s3_aws_secret_access_key,
aws_session_token=self.s3_aws_session_token,
aws_region_name=self.s3_region_name,
)
# Prepare the URL
url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
if self.s3_endpoint_url:
url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
# Convert JSON to string
json_string = json.dumps(batch_logging_element.payload)
# Calculate SHA256 hash of the content
content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
# Prepare the request
headers = {
"Content-Type": "application/json",
"x-amz-content-sha256": content_hash,
"Content-Language": "en",
"Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
"Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
}
req = requests.Request("PUT", url, data=json_string, headers=headers)
prepped = req.prepare()
# Sign the request
aws_request = AWSRequest(
method=prepped.method,
url=prepped.url,
data=prepped.body,
headers=prepped.headers,
)
SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
# Prepare the signed headers
signed_headers = dict(aws_request.headers.items())
httpx_client = _get_httpx_client()
# Make the request
response = httpx_client.put(url, data=json_string, headers=signed_headers)
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error uploading to s3: {str(e)}")

View file

@ -116,7 +116,6 @@ lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
s3Logger = None
genericAPILogger = None
clickHouseLogger = None
greenscaleLogger = None
@ -1346,36 +1345,6 @@ class Logging:
user_id=kwargs.get("user", None),
print_verbose=print_verbose,
)
if callback == "s3":
global s3Logger
if s3Logger is None:
s3Logger = S3Logger()
if self.stream:
if "complete_streaming_response" in self.model_call_details:
print_verbose(
"S3Logger Logger: Got Stream Event - Completed Stream Response"
)
s3Logger.log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
"complete_streaming_response"
],
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
)
else:
print_verbose(
"S3Logger Logger: Got Stream Event - No complete stream response as yet"
)
else:
s3Logger.log_event(
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
)
if (
callback == "openmeter"
and self.model_call_details.get("litellm_params", {}).get(
@ -2245,7 +2214,7 @@ def set_callbacks(callback_list, function_id=None):
"""
Globally sets the callback client
"""
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
try:
for callback in callback_list:
@ -2319,8 +2288,6 @@ def set_callbacks(callback_list, function_id=None):
dataDogLogger = DataDogLogger()
elif callback == "dynamodb":
dynamoLogger = DyanmoDBLogger()
elif callback == "s3":
s3Logger = S3Logger()
elif callback == "wandb":
weightsBiasesLogger = WeightsBiasesLogger()
elif callback == "logfire":
@ -2357,7 +2324,6 @@ def _init_custom_logger_compatible_class(
llm_router: Optional[
Any
], # expect litellm.Router, but typing errors due to circular import
premium_user: Optional[bool] = None,
) -> Optional[CustomLogger]:
if logging_integration == "lago":
for callback in _in_memory_loggers:
@ -2404,17 +2370,9 @@ def _init_custom_logger_compatible_class(
if isinstance(callback, PrometheusLogger):
return callback # type: ignore
if premium_user:
_prometheus_logger = PrometheusLogger()
_in_memory_loggers.append(_prometheus_logger)
return _prometheus_logger # type: ignore
elif premium_user is False:
verbose_logger.warning(
f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
)
return None
else:
return None
_prometheus_logger = PrometheusLogger()
_in_memory_loggers.append(_prometheus_logger)
return _prometheus_logger # type: ignore
elif logging_integration == "datadog":
for callback in _in_memory_loggers:
if isinstance(callback, DataDogLogger):
@ -2423,6 +2381,14 @@ def _init_custom_logger_compatible_class(
_datadog_logger = DataDogLogger()
_in_memory_loggers.append(_datadog_logger)
return _datadog_logger # type: ignore
elif logging_integration == "s3":
for callback in _in_memory_loggers:
if isinstance(callback, S3Logger):
return callback # type: ignore
_s3_logger = S3Logger()
_in_memory_loggers.append(_s3_logger)
return _s3_logger # type: ignore
elif logging_integration == "gcs_bucket":
for callback in _in_memory_loggers:
if isinstance(callback, GCSBucketLogger):
@ -2589,6 +2555,10 @@ def get_custom_logger_compatible_class(
for callback in _in_memory_loggers:
if isinstance(callback, PrometheusLogger):
return callback
elif logging_integration == "s3":
for callback in _in_memory_loggers:
if isinstance(callback, S3Logger):
return callback
elif logging_integration == "datadog":
for callback in _in_memory_loggers:
if isinstance(callback, DataDogLogger):

View file

@ -0,0 +1,112 @@
"""
async with websockets.connect( # type: ignore
url,
extra_headers={
"api-key": api_key, # type: ignore
},
) as backend_ws:
forward_task = asyncio.create_task(
forward_messages(websocket, backend_ws)
)
try:
while True:
message = await websocket.receive_text()
await backend_ws.send(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
forward_task.cancel()
finally:
if not forward_task.done():
forward_task.cancel()
try:
await forward_task
except asyncio.CancelledError:
pass
"""
import asyncio
import concurrent.futures
import traceback
from asyncio import Task
from typing import Any, Dict, List, Optional, Union
from .litellm_logging import Logging as LiteLLMLogging
# Create a thread pool with a maximum of 10 threads
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
class RealTimeStreaming:
def __init__(
self,
websocket: Any,
backend_ws: Any,
logging_obj: Optional[LiteLLMLogging] = None,
):
self.websocket = websocket
self.backend_ws = backend_ws
self.logging_obj = logging_obj
self.messages: List = []
self.input_message: Dict = {}
def store_message(self, message: Union[str, bytes]):
"""Store message in list"""
self.messages.append(message)
def store_input(self, message: dict):
"""Store input message"""
self.input_message = message
if self.logging_obj:
self.logging_obj.pre_call(input=message, api_key="")
async def log_messages(self):
"""Log messages in list"""
if self.logging_obj:
## ASYNC LOGGING
# Create an event loop for the new thread
asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
## SYNC LOGGING
executor.submit(self.logging_obj.success_handler(self.messages))
async def backend_to_client_send_messages(self):
import websockets
try:
while True:
message = await self.backend_ws.recv()
await self.websocket.send_text(message)
## LOGGING
self.store_message(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
pass
except Exception:
pass
finally:
await self.log_messages()
async def client_ack_messages(self):
try:
while True:
message = await self.websocket.receive_text()
## LOGGING
self.store_input(message=message)
## FORWARD TO BACKEND
await self.backend_ws.send(message)
except self.websockets.exceptions.ConnectionClosed: # type: ignore
pass
async def bidirectional_forward(self):
forward_task = asyncio.create_task(self.backend_to_client_send_messages())
try:
await self.client_ack_messages()
except self.websockets.exceptions.ConnectionClosed: # type: ignore
forward_task.cancel()
finally:
if not forward_task.done():
forward_task.cancel()
try:
await forward_task
except asyncio.CancelledError:
pass

View file

@ -7,6 +7,8 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
import asyncio
from typing import Any, Optional
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ..azure import AzureChatCompletion
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
@ -44,6 +46,7 @@ class AzureOpenAIRealtime(AzureChatCompletion):
api_version: Optional[str] = None,
azure_ad_token: Optional[str] = None,
client: Optional[Any] = None,
logging_obj: Optional[LiteLLMLogging] = None,
timeout: Optional[float] = None,
):
import websockets
@ -62,23 +65,10 @@ class AzureOpenAIRealtime(AzureChatCompletion):
"api-key": api_key, # type: ignore
},
) as backend_ws:
forward_task = asyncio.create_task(
forward_messages(websocket, backend_ws)
realtime_streaming = RealTimeStreaming(
websocket, backend_ws, logging_obj
)
try:
while True:
message = await websocket.receive_text()
await backend_ws.send(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
forward_task.cancel()
finally:
if not forward_task.done():
forward_task.cancel()
try:
await forward_task
except asyncio.CancelledError:
pass
await realtime_streaming.bidirectional_forward()
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
await websocket.close(code=e.status_code, reason=str(e))

View file

@ -7,20 +7,11 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
import asyncio
from typing import Any, Optional
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ..openai import OpenAIChatCompletion
async def forward_messages(client_ws: Any, backend_ws: Any):
import websockets
try:
while True:
message = await backend_ws.recv()
await client_ws.send_text(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
pass
class OpenAIRealtime(OpenAIChatCompletion):
def _construct_url(self, api_base: str, model: str) -> str:
"""
@ -35,6 +26,7 @@ class OpenAIRealtime(OpenAIChatCompletion):
self,
model: str,
websocket: Any,
logging_obj: LiteLLMLogging,
api_base: Optional[str] = None,
api_key: Optional[str] = None,
client: Optional[Any] = None,
@ -57,25 +49,26 @@ class OpenAIRealtime(OpenAIChatCompletion):
"OpenAI-Beta": "realtime=v1",
},
) as backend_ws:
forward_task = asyncio.create_task(
forward_messages(websocket, backend_ws)
realtime_streaming = RealTimeStreaming(
websocket, backend_ws, logging_obj
)
try:
while True:
message = await websocket.receive_text()
await backend_ws.send(message)
except websockets.exceptions.ConnectionClosed: # type: ignore
forward_task.cancel()
finally:
if not forward_task.done():
forward_task.cancel()
try:
await forward_task
except asyncio.CancelledError:
pass
await realtime_streaming.bidirectional_forward()
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
await websocket.close(code=e.status_code, reason=str(e))
except Exception as e:
await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
try:
await websocket.close(
code=1011, reason=f"Internal server error: {str(e)}"
)
except RuntimeError as close_error:
if "already completed" in str(close_error) or "websocket.close" in str(
close_error
):
# The WebSocket is already closed or the response is completed, so we can ignore this error
pass
else:
# If it's a different RuntimeError, we might want to log it or handle it differently
raise Exception(
f"Unexpected error while closing WebSocket: {close_error}"
)

View file

@ -153,6 +153,8 @@ class CustomLLM(BaseLLM):
self,
model: str,
prompt: str,
api_key: Optional[str],
api_base: Optional[str],
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
@ -166,6 +168,12 @@ class CustomLLM(BaseLLM):
model: str,
prompt: str,
model_response: ImageResponse,
api_key: Optional[
str
], # dynamically set api_key - https://docs.litellm.ai/docs/set_keys#api_key
api_base: Optional[
str
], # dynamically set api_base - https://docs.litellm.ai/docs/set_keys#api_base
optional_params: dict,
logging_obj: Any,
timeout: Optional[Union[float, httpx.Timeout]] = None,

View file

@ -1,464 +0,0 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ..prompt_templates.factory import (
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIAnthropicConfig:
"""
Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
_tool_choice = None
_tool_choice = {"name": "json_tool_call", "type": "tool"}
_tool = AnthropicMessagesTool(
name="json_tool_call",
input_schema={
"type": "object",
"properties": {"values": json_schema}, # type: ignore
},
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
return optional_params
"""
- Run client init
- Support async completion, streaming
"""
def refresh_auth(
credentials,
) -> str: # used when user passes in credentials as json string
from google.auth.transport.requests import Request # type: ignore[import-untyped]
if credentials.token is None:
credentials.refresh(Request())
if not credentials.token:
raise RuntimeError("Could not resolve API token from the credentials")
return credentials.token
def get_vertex_client(
client: Any,
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
) -> Tuple[Any, Optional[str]]:
args = locals()
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
try:
from anthropic import AnthropicVertex
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
access_token: Optional[str] = None
if client is None:
_credentials, cred_project_id = VertexLLM().load_auth(
credentials=vertex_credentials, project_id=vertex_project
)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project or cred_project_id,
region=vertex_location or "us-central1",
access_token=_credentials.token,
)
access_token = _credentials.token
else:
vertex_ai_client = client
access_token = client.access_token
return vertex_ai_client, access_token
def create_vertex_anthropic_url(
vertex_location: str, vertex_project: str, model: str, stream: bool
) -> str:
if stream is True:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
def completion(
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
):
try:
import vertexai
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
from anthropic import AnthropicVertex
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
if not (
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
):
raise VertexAIError(
status_code=400,
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials,
project_id=vertex_project,
custom_llm_provider="vertex_ai",
)
anthropic_chat_completions = AnthropicChatCompletion()
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
## CONSTRUCT API BASE
stream = optional_params.get("stream", False)
api_base = create_vertex_anthropic_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
model=model,
stream=stream,
)
if headers is not None:
vertex_headers = headers
else:
vertex_headers = {}
vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
optional_params.update(
{"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
)
return anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
encoding=encoding,
api_key=access_token,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
headers=vertex_headers,
client=client,
timeout=timeout,
)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_completion(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
message = await vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
async def async_streaming(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore
logging_obj.post_call(input=messages, api_key=None, original_response=response)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
return streamwrapper

View file

@ -0,0 +1,179 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from ....prompt_templates.factory import (
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIAnthropicConfig:
"""
Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"max_completion_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and isinstance(value, dict):
json_schema: Optional[dict] = None
if "response_schema" in value:
json_schema = value["response_schema"]
elif "json_schema" in value:
json_schema = value["json_schema"]["schema"]
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
_tool_choice = None
_tool_choice = {"name": "json_tool_call", "type": "tool"}
_tool = AnthropicMessagesTool(
name="json_tool_call",
input_schema={
"type": "object",
"properties": {"values": json_schema}, # type: ignore
},
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
return optional_params

View file

@ -9,13 +9,14 @@ import httpx # type: ignore
import litellm
from litellm.utils import ModelResponse
from ...base import BaseLLM
from ..vertex_llm_base import VertexBase
class VertexPartnerProvider(str, Enum):
mistralai = "mistralai"
llama = "llama"
ai21 = "ai21"
claude = "claude"
class VertexAIError(Exception):
@ -31,31 +32,38 @@ class VertexAIError(Exception):
) # Call the base class constructor with the parameters it needs
class VertexAIPartnerModels(BaseLLM):
def create_vertex_url(
vertex_location: str,
vertex_project: str,
partner: VertexPartnerProvider,
stream: Optional[bool],
model: str,
api_base: Optional[str] = None,
) -> str:
"""Return the base url for the vertex partner models"""
if partner == VertexPartnerProvider.llama:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
elif partner == VertexPartnerProvider.mistralai:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.ai21:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.claude:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
class VertexAIPartnerModels(VertexBase):
def __init__(self) -> None:
pass
def create_vertex_url(
self,
vertex_location: str,
vertex_project: str,
partner: VertexPartnerProvider,
stream: Optional[bool],
model: str,
) -> str:
if partner == VertexPartnerProvider.llama:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
elif partner == VertexPartnerProvider.mistralai:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
elif partner == VertexPartnerProvider.ai21:
if stream:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
else:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
def completion(
self,
model: str,
@ -64,6 +72,7 @@ class VertexAIPartnerModels(BaseLLM):
print_verbose: Callable,
encoding,
logging_obj,
api_base: Optional[str],
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
@ -80,6 +89,7 @@ class VertexAIPartnerModels(BaseLLM):
import vertexai
from google.cloud import aiplatform
from litellm.llms.anthropic.chat import AnthropicChatCompletion
from litellm.llms.databricks.chat import DatabricksChatCompletion
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
from litellm.llms.text_completion_codestral import CodestralTextCompletion
@ -112,6 +122,7 @@ class VertexAIPartnerModels(BaseLLM):
openai_like_chat_completions = DatabricksChatCompletion()
codestral_fim_completions = CodestralTextCompletion()
anthropic_chat_completions = AnthropicChatCompletion()
## CONSTRUCT API BASE
stream: bool = optional_params.get("stream", False) or False
@ -126,8 +137,10 @@ class VertexAIPartnerModels(BaseLLM):
elif "jamba" in model:
partner = VertexPartnerProvider.ai21
optional_params["custom_endpoint"] = True
elif "claude" in model:
partner = VertexPartnerProvider.claude
api_base = self.create_vertex_url(
default_api_base = create_vertex_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
partner=partner, # type: ignore
@ -135,6 +148,21 @@ class VertexAIPartnerModels(BaseLLM):
model=model,
)
if len(default_api_base.split(":")) > 1:
endpoint = default_api_base.split(":")[-1]
else:
endpoint = ""
_, api_base = self._check_custom_proxy(
api_base=api_base,
custom_llm_provider="vertex_ai",
gemini_api_key=None,
endpoint=endpoint,
stream=stream,
auth_header=None,
url=default_api_base,
)
model = model.split("@")[0]
if "codestral" in model and litellm_params.get("text_completion") is True:
@ -158,6 +186,35 @@ class VertexAIPartnerModels(BaseLLM):
timeout=timeout,
encoding=encoding,
)
elif "claude" in model:
if headers is None:
headers = {}
headers.update({"Authorization": "Bearer {}".format(access_token)})
optional_params.update(
{
"anthropic_version": "vertex-2023-10-16",
"is_vertex_request": True,
}
)
return anthropic_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
acompletion=acompletion,
custom_prompt_dict=litellm.custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding, # for calculating input/output tokens
api_key=access_token,
logging_obj=logging_obj,
headers=headers,
timeout=timeout,
client=client,
)
return openai_like_chat_completions.completion(
model=model,

View file

@ -117,10 +117,7 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_and_google_ai_studio import (
vertex_ai_anthropic,
vertex_ai_non_gemini,
)
from .llms.vertex_ai_and_google_ai_studio import vertex_ai_non_gemini
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
VertexLLM,
)
@ -747,6 +744,11 @@ def completion( # type: ignore
proxy_server_request = kwargs.get("proxy_server_request", None)
fallbacks = kwargs.get("fallbacks", None)
headers = kwargs.get("headers", None) or extra_headers
if headers is None:
headers = {}
if extra_headers is not None:
headers.update(extra_headers)
num_retries = kwargs.get(
"num_retries", None
) ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
@ -964,7 +966,6 @@ def completion( # type: ignore
max_retries=max_retries,
logprobs=logprobs,
top_logprobs=top_logprobs,
extra_headers=extra_headers,
api_version=api_version,
parallel_tool_calls=parallel_tool_calls,
messages=messages,
@ -1067,6 +1068,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
if (
litellm.enable_preview_features
and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
@ -1166,6 +1170,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
## LOAD CONFIG - if set
config = litellm.AzureOpenAIConfig.get_config()
for k, v in config.items():
@ -1223,6 +1230,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
## LOAD CONFIG - if set
config = litellm.AzureAIStudioConfig.get_config()
for k, v in config.items():
@ -1304,6 +1314,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
## LOAD CONFIG - if set
config = litellm.OpenAITextCompletionConfig.get_config()
for k, v in config.items():
@ -1466,6 +1479,9 @@ def completion( # type: ignore
headers = headers or litellm.headers
if extra_headers is not None:
optional_params["extra_headers"] = extra_headers
## LOAD CONFIG - if set
config = litellm.OpenAIConfig.get_config()
for k, v in config.items():
@ -2228,31 +2244,12 @@ def completion( # type: ignore
)
new_params = deepcopy(optional_params)
if "claude-3" in model:
model_response = vertex_ai_anthropic.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
client=client,
)
elif (
if (
model.startswith("meta/")
or model.startswith("mistral")
or model.startswith("codestral")
or model.startswith("jamba")
or model.startswith("claude")
):
model_response = vertex_partner_models_chat_completion.completion(
model=model,
@ -2263,6 +2260,7 @@ def completion( # type: ignore
litellm_params=litellm_params, # type: ignore
logger_fn=logger_fn,
encoding=encoding,
api_base=api_base,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
@ -4848,6 +4846,8 @@ def image_generation(
model_response = custom_handler.aimage_generation( # type: ignore
model=model,
prompt=prompt,
api_key=api_key,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,
@ -4863,6 +4863,8 @@ def image_generation(
model_response = custom_handler.image_generation(
model=model,
prompt=prompt,
api_key=api_key,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=litellm_logging_obj,

View file

@ -1,12 +1,15 @@
model_list:
- model_name: gpt-4o-mini
litellm_params:
model: azure/my-gpt-4o-mini
api_key: os.environ/AZURE_API_KEY
api_base: os.environ/AZURE_API_BASE
# - model_name: openai-gpt-4o-realtime-audio
# litellm_params:
# model: azure/gpt-4o-realtime-preview
# api_key: os.environ/AZURE_SWEDEN_API_KEY
# api_base: os.environ/AZURE_SWEDEN_API_BASE
- model_name: openai-gpt-4o-realtime-audio
litellm_params:
model: openai/gpt-4o-realtime-preview-2024-10-01
api_key: os.environ/OPENAI_API_KEY
api_base: http://localhost:8080
litellm_settings:
turn_off_message_logging: true
cache: True
cache_params:
type: local
success_callback: ["langfuse"]

View file

@ -241,6 +241,8 @@ def initialize_callbacks_on_proxy(
litellm.callbacks = imported_list # type: ignore
if "prometheus" in value:
if premium_user is not True:
raise Exception(CommonProxyErrors.not_premium_user.value)
from litellm.proxy.proxy_server import app
verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")

View file

@ -1,16 +1,17 @@
model_list:
- model_name: db-openai-endpoint
litellm_params:
model: openai/gpt-5
model: openai/gpt-4
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
- model_name: db-openai-endpoint
litellm_params:
model: openai/gpt-5
api_key: fake-key
api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/
api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
callbacks: ["prometheus"]
success_callback: ["s3"]
turn_off_message_logging: true
s3_callback_params:
s3_bucket_name: load-testing-oct # AWS Bucket Name for S3
s3_region_name: us-west-2 # AWS Region Name for S3
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3

View file

@ -1692,6 +1692,10 @@ class ProxyConfig:
else:
litellm.success_callback.append(callback)
if "prometheus" in callback:
if not premium_user:
raise Exception(
CommonProxyErrors.not_premium_user.value
)
verbose_proxy_logger.debug(
"Starting Prometheus Metrics on /metrics"
)
@ -5433,12 +5437,11 @@ async def moderations(
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e)),

View file

@ -408,7 +408,6 @@ class ProxyLogging:
callback,
internal_usage_cache=self.internal_usage_cache.dual_cache,
llm_router=llm_router,
premium_user=self.premium_user,
)
if callback is None:
continue

View file

@ -8,13 +8,16 @@ from litellm import get_llm_provider
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
from ..llms.OpenAI.realtime.handler import OpenAIRealtime
from ..utils import client as wrapper_client
azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
@wrapper_client
async def _arealtime(
model: str,
websocket: Any, # fastapi websocket
@ -31,6 +34,12 @@ async def _arealtime(
For PROXY use only.
"""
litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
user = kwargs.get("user", None)
litellm_params = GenericLiteLLMParams(**kwargs)
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
@ -39,6 +48,21 @@ async def _arealtime(
api_key=api_key,
)
litellm_logging_obj.update_environment_variables(
model=model,
user=user,
optional_params={},
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
},
custom_llm_provider=_custom_llm_provider,
)
if _custom_llm_provider == "azure":
api_base = (
dynamic_api_base
@ -63,6 +87,7 @@ async def _arealtime(
azure_ad_token=None,
client=None,
timeout=timeout,
logging_obj=litellm_logging_obj,
)
elif _custom_llm_provider == "openai":
api_base = (
@ -82,6 +107,7 @@ async def _arealtime(
await openai_realtime.async_realtime(
model=model,
websocket=websocket,
logging_obj=litellm_logging_obj,
api_base=api_base,
api_key=api_key,
client=None,

View file

@ -60,7 +60,7 @@ class PatternMatchRouter:
regex = re.escape(regex).replace(r"\.\*", ".*")
return f"^{regex}$"
def route(self, request: str) -> Optional[List[Dict]]:
def route(self, request: Optional[str]) -> Optional[List[Dict]]:
"""
Route a requested model to the corresponding llm deployments based on the regex pattern
@ -69,17 +69,20 @@ class PatternMatchRouter:
if no pattern is found, return None
Args:
request: str
request: Optional[str]
Returns:
Optional[List[Deployment]]: llm deployments
"""
try:
if request is None:
return None
for pattern, llm_deployments in self.patterns.items():
if re.match(pattern, request):
return llm_deployments
except Exception as e:
verbose_router_logger.error(f"Error in PatternMatchRouter.route: {str(e)}")
verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
return None # No matching pattern found

View file

@ -0,0 +1,14 @@
from typing import Dict
from pydantic import BaseModel
class s3BatchLoggingElement(BaseModel):
"""
Type of element stored in self.log_queue in S3Logger
"""
payload: Dict
s3_object_key: str
s3_object_download_filename: str

View file

@ -129,6 +129,7 @@ class CallTypes(Enum):
speech = "speech"
rerank = "rerank"
arerank = "arerank"
arealtime = "_arealtime"
class PassthroughCallTypes(Enum):

View file

@ -197,7 +197,6 @@ lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
s3Logger = None
genericAPILogger = None
clickHouseLogger = None
greenscaleLogger = None
@ -1410,6 +1409,8 @@ def client(original_function):
)
else:
return result
elif call_type == CallTypes.arealtime.value:
return result
# ADD HIDDEN PARAMS - additional call metadata
if hasattr(result, "_hidden_params"):
@ -1799,8 +1800,9 @@ def calculate_tiles_needed(
total_tiles = tiles_across * tiles_down
return total_tiles
def get_image_type(image_data: bytes) -> Union[str, None]:
""" take an image (really only the first ~100 bytes max are needed)
"""take an image (really only the first ~100 bytes max are needed)
and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
allow deprecation of imghdr in 3.13"""
@ -4336,16 +4338,18 @@ def get_api_base(
_optional_params.vertex_location is not None
and _optional_params.vertex_project is not None
):
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
create_vertex_anthropic_url,
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
VertexPartnerProvider,
create_vertex_url,
)
if "claude" in model:
_api_base = create_vertex_anthropic_url(
_api_base = create_vertex_url(
vertex_location=_optional_params.vertex_location,
vertex_project=_optional_params.vertex_project,
model=model,
stream=stream,
partner=VertexPartnerProvider.claude,
)
else:
@ -4442,19 +4446,7 @@ def get_supported_openai_params(
elif custom_llm_provider == "volcengine":
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "groq":
return [
"temperature",
"max_tokens",
"top_p",
"stream",
"stop",
"tools",
"tool_choice",
"response_format",
"seed",
"extra_headers",
"extra_body",
]
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "hosted_vllm":
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
elif custom_llm_provider == "deepseek":
@ -4599,6 +4591,8 @@ def get_supported_openai_params(
return (
litellm.MistralTextCompletionConfig().get_supported_openai_params()
)
if model.startswith("claude"):
return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()

View file

@ -296,7 +296,7 @@ def test_all_model_configs():
optional_params={},
) == {"max_tokens": 10}
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
VertexAIAnthropicConfig,
)

View file

@ -12,7 +12,70 @@ import litellm
litellm.num_retries = 3
import time, random
from litellm._logging import verbose_logger
import logging
import pytest
import boto3
@pytest.mark.asyncio
@pytest.mark.parametrize("sync_mode", [True, False])
async def test_basic_s3_logging(sync_mode):
verbose_logger.setLevel(level=logging.DEBUG)
litellm.success_callback = ["s3"]
litellm.s3_callback_params = {
"s3_bucket_name": "load-testing-oct",
"s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
"s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
"s3_region_name": "us-west-2",
}
litellm.set_verbose = True
if sync_mode is True:
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "This is a test"}],
mock_response="It's simple to use and easy to get started",
)
else:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "This is a test"}],
mock_response="It's simple to use and easy to get started",
)
print(f"response: {response}")
await asyncio.sleep(12)
total_objects, all_s3_keys = list_all_s3_objects("load-testing-oct")
# assert that atlest one key has response.id in it
assert any(response.id in key for key in all_s3_keys)
s3 = boto3.client("s3")
# delete all objects
for key in all_s3_keys:
s3.delete_object(Bucket="load-testing-oct", Key=key)
def list_all_s3_objects(bucket_name):
s3 = boto3.client("s3")
all_s3_keys = []
paginator = s3.get_paginator("list_objects_v2")
total_objects = 0
for page in paginator.paginate(Bucket=bucket_name):
if "Contents" in page:
total_objects += len(page["Contents"])
all_s3_keys.extend([obj["Key"] for obj in page["Contents"]])
print(f"Total number of objects in {bucket_name}: {total_objects}")
print(all_s3_keys)
return total_objects, all_s3_keys
list_all_s3_objects("load-testing-oct")
@pytest.mark.skip(reason="AWS Suspended Account")

View file

@ -1616,9 +1616,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
)
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
@pytest.mark.parametrize(
"model", ["gemini-1.5-flash", "claude-3-sonnet@20240229"]
) # "vertex_ai",
@pytest.mark.asyncio
async def test_gemini_pro_httpx_custom_api_base(provider):
async def test_gemini_pro_httpx_custom_api_base(model):
load_vertex_ai_credentials()
litellm.set_verbose = True
messages = [
@ -1634,7 +1636,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
with patch.object(client, "post", new=MagicMock()) as mock_call:
try:
response = completion(
model="vertex_ai_beta/gemini-1.5-flash",
model="vertex_ai/{}".format(model),
messages=messages,
response_format={"type": "json_object"},
client=client,
@ -1647,8 +1649,17 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
mock_call.assert_called_once()
assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
assert "hello" in mock_call.call_args.kwargs["headers"]
print(f"mock_call.call_args: {mock_call.call_args}")
print(f"mock_call.call_args.kwargs: {mock_call.call_args.kwargs}")
if "url" in mock_call.call_args.kwargs:
assert (
"my-custom-api-base:generateContent"
== mock_call.call_args.kwargs["url"]
)
else:
assert "my-custom-api-base:rawPredict" == mock_call.call_args[0][0]
if "headers" in mock_call.call_args.kwargs:
assert "hello" in mock_call.call_args.kwargs["headers"]
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")

View file

@ -28,7 +28,6 @@ from typing import (
Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
from dotenv import load_dotenv
@ -226,6 +225,8 @@ class MyCustomLLM(CustomLLM):
self,
model: str,
prompt: str,
api_key: Optional[str],
api_base: Optional[str],
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
@ -242,6 +243,8 @@ class MyCustomLLM(CustomLLM):
self,
model: str,
prompt: str,
api_key: Optional[str],
api_base: Optional[str],
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
@ -362,3 +365,31 @@ async def test_simple_image_generation_async():
)
print(resp)
@pytest.mark.asyncio
async def test_image_generation_async_with_api_key_and_api_base():
my_custom_llm = MyCustomLLM()
litellm.custom_provider_map = [
{"provider": "custom_llm", "custom_handler": my_custom_llm}
]
with patch.object(
my_custom_llm, "aimage_generation", new=AsyncMock()
) as mock_client:
try:
resp = await litellm.aimage_generation(
model="custom_llm/my-fake-model",
prompt="Hello world",
api_key="my-api-key",
api_base="my-api-base",
)
print(resp)
except Exception as e:
print(e)
mock_client.assert_awaited_once()
mock_client.call_args.kwargs["api_key"] == "my-api-key"
mock_client.call_args.kwargs["api_base"] == "my-api-base"