forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (10/10/2024) (#6158)
* refactor(vertex_ai_partner_models/anthropic): refactor anthropic to use partner model logic * fix(vertex_ai/): support passing custom api base to partner models Fixes https://github.com/BerriAI/litellm/issues/4317 * fix(proxy_server.py): Fix prometheus premium user check logic * docs(prometheus.md): update quick start docs * fix(custom_llm.py): support passing dynamic api key + api base * fix(realtime_api/main.py): Add request/response logging for realtime api endpoints Closes https://github.com/BerriAI/litellm/issues/6081 * feat(openai/realtime): add openai realtime api logging Closes https://github.com/BerriAI/litellm/issues/6081 * fix(realtime_streaming.py): fix linting errors * fix(realtime_streaming.py): fix linting errors * fix: fix linting errors * fix pattern match router * Add literalai in the sidebar observability category (#6163) * fix: add literalai in the sidebar * fix: typo * update (#6160) * Feat: Add Langtrace integration (#5341) * Feat: Add Langtrace integration * add langtrace service name * fix timestamps for traces * add tests * Discard Callback + use existing otel logger * cleanup * remove print statments * remove callback * add docs * docs * add logging docs * format logging * remove emoji and add litellm proxy example * format logging * format `logging.md` * add langtrace docs to logging.md * sync conflict * docs fix * (perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165) * fix move s3 to use customLogger * add basic s3 logging test * add s3 to custom logger compatible * use batch logger for s3 * s3 set flush interval and batch size * fix s3 logging * add notes on s3 logging * fix s3 logging * add basic s3 logging test * fix s3 type errors * add test for sync logging on s3 * fix: fix to debug log --------- Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com> Co-authored-by: Willy Douhard <willy.douhard@gmail.com> Co-authored-by: yujonglee <yujonglee.dev@gmail.com> Co-authored-by: Ali Waleed <ali@scale3labs.com>
This commit is contained in:
parent
9db4ccca9f
commit
11f9df923a
28 changed files with 966 additions and 760 deletions
|
@ -27,8 +27,7 @@ model_list:
|
|||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
success_callback: ["prometheus"]
|
||||
failure_callback: ["prometheus"]
|
||||
callbacks: ["prometheus"]
|
||||
```
|
||||
|
||||
Start the proxy
|
||||
|
|
|
@ -53,6 +53,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
|
|||
"arize",
|
||||
"langtrace",
|
||||
"gcs_bucket",
|
||||
"s3",
|
||||
"opik",
|
||||
]
|
||||
_known_custom_logger_compatible_callbacks: List = list(
|
||||
|
@ -931,7 +932,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation impor
|
|||
|
||||
vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()
|
||||
|
||||
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
|
||||
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
|
||||
VertexAIAnthropicConfig,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
|
||||
|
|
|
@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
|
|||
self,
|
||||
flush_lock: Optional[asyncio.Lock] = None,
|
||||
batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
|
||||
flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
|
|||
flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
|
||||
"""
|
||||
self.log_queue: List = []
|
||||
self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds
|
||||
self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
|
||||
self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
|
||||
self.last_flush_time = time.time()
|
||||
self.flush_lock = flush_lock
|
||||
|
|
|
@ -235,6 +235,14 @@ class LangFuseLogger:
|
|||
):
|
||||
input = prompt
|
||||
output = response_obj.results
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "_arealtime"
|
||||
and response_obj is not None
|
||||
and isinstance(response_obj, list)
|
||||
):
|
||||
input = kwargs.get("input")
|
||||
output = response_obj
|
||||
elif (
|
||||
kwargs.get("call_type") is not None
|
||||
and kwargs.get("call_type") == "pass_through_endpoint"
|
||||
|
|
|
@ -1,43 +1,67 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
"""
|
||||
s3 Bucket Logging Integration
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Optional
|
||||
async_log_success_event: Processes the event, stores it in memory for 10 seconds or until MAX_BATCH_SIZE and then flushes to s3
|
||||
|
||||
NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
|
||||
NOTE 2: We create a httpx client with a concurrent limit of 1 to upload to s3. Files should get uploaded BUT they should not impact latency of LLM calling logic
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.llms.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.integrations.s3 import s3BatchLoggingElement
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
from .custom_batch_logger import CustomBatchLogger
|
||||
|
||||
class S3Logger:
|
||||
# Default Flush interval and batch size for s3
|
||||
# Flush to s3 every 10 seconds OR every 1K requests in memory
|
||||
DEFAULT_S3_FLUSH_INTERVAL_SECONDS = 10
|
||||
DEFAULT_S3_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
class S3Logger(CustomBatchLogger, BaseAWSLLM):
|
||||
# Class variables or attributes
|
||||
def __init__(
|
||||
self,
|
||||
s3_bucket_name=None,
|
||||
s3_path=None,
|
||||
s3_region_name=None,
|
||||
s3_api_version=None,
|
||||
s3_use_ssl=True,
|
||||
s3_verify=None,
|
||||
s3_endpoint_url=None,
|
||||
s3_aws_access_key_id=None,
|
||||
s3_aws_secret_access_key=None,
|
||||
s3_aws_session_token=None,
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_path: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_use_ssl: bool = True,
|
||||
s3_verify: Optional[bool] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
s3_aws_access_key_id: Optional[str] = None,
|
||||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_flush_interval: Optional[int] = DEFAULT_S3_FLUSH_INTERVAL_SECONDS,
|
||||
s3_batch_size: Optional[int] = DEFAULT_S3_BATCH_SIZE,
|
||||
s3_config=None,
|
||||
**kwargs,
|
||||
):
|
||||
import boto3
|
||||
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
|
||||
)
|
||||
|
||||
# IMPORTANT: We use a concurrent limit of 1 to upload to s3
|
||||
# Files should get uploaded BUT they should not impact latency of LLM calling logic
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback,
|
||||
params={"concurrent_limit": 1},
|
||||
)
|
||||
|
||||
if litellm.s3_callback_params is not None:
|
||||
# read in .env variables - example os.environ/AWS_BUCKET_NAME
|
||||
for key, value in litellm.s3_callback_params.items():
|
||||
|
@ -63,107 +87,282 @@ class S3Logger:
|
|||
s3_path = litellm.s3_callback_params.get("s3_path")
|
||||
# done reading litellm.s3_callback_params
|
||||
|
||||
s3_flush_interval = litellm.s3_callback_params.get(
|
||||
"s3_flush_interval", DEFAULT_S3_FLUSH_INTERVAL_SECONDS
|
||||
)
|
||||
s3_batch_size = litellm.s3_callback_params.get(
|
||||
"s3_batch_size", DEFAULT_S3_BATCH_SIZE
|
||||
)
|
||||
|
||||
self.bucket_name = s3_bucket_name
|
||||
self.s3_path = s3_path
|
||||
verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
|
||||
# Create an S3 client with custom endpoint URL
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=s3_region_name,
|
||||
endpoint_url=s3_endpoint_url,
|
||||
api_version=s3_api_version,
|
||||
use_ssl=s3_use_ssl,
|
||||
verify=s3_verify,
|
||||
aws_access_key_id=s3_aws_access_key_id,
|
||||
aws_secret_access_key=s3_aws_secret_access_key,
|
||||
aws_session_token=s3_aws_session_token,
|
||||
config=s3_config,
|
||||
**kwargs,
|
||||
self.s3_bucket_name = s3_bucket_name
|
||||
self.s3_region_name = s3_region_name
|
||||
self.s3_api_version = s3_api_version
|
||||
self.s3_use_ssl = s3_use_ssl
|
||||
self.s3_verify = s3_verify
|
||||
self.s3_endpoint_url = s3_endpoint_url
|
||||
self.s3_aws_access_key_id = s3_aws_access_key_id
|
||||
self.s3_aws_secret_access_key = s3_aws_secret_access_key
|
||||
self.s3_aws_session_token = s3_aws_session_token
|
||||
self.s3_config = s3_config
|
||||
self.init_kwargs = kwargs
|
||||
|
||||
asyncio.create_task(self.periodic_flush())
|
||||
self.flush_lock = asyncio.Lock()
|
||||
|
||||
verbose_logger.debug(
|
||||
f"s3 flush interval: {s3_flush_interval}, s3 batch size: {s3_batch_size}"
|
||||
)
|
||||
# Call CustomLogger's __init__
|
||||
CustomBatchLogger.__init__(
|
||||
self,
|
||||
flush_lock=self.flush_lock,
|
||||
flush_interval=s3_flush_interval,
|
||||
batch_size=s3_batch_size,
|
||||
)
|
||||
self.log_queue: List[s3BatchLoggingElement] = []
|
||||
|
||||
# Call BaseAWSLLM's __init__
|
||||
BaseAWSLLM.__init__(self)
|
||||
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception on init s3 client {str(e)}")
|
||||
raise e
|
||||
|
||||
async def _async_log_event(
|
||||
self, kwargs, response_obj, start_time, end_time, print_verbose
|
||||
):
|
||||
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
verbose_logger.debug(
|
||||
f"s3 Logging - Enters logging function for model {kwargs}"
|
||||
)
|
||||
|
||||
# construct payload to send to s3
|
||||
# follows the same params as langfuse.py
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
# the raw metadata can contain circular references which leads to infinite recursion
|
||||
# we clean out all extra litellm metadata params before logging
|
||||
clean_metadata = {}
|
||||
if isinstance(metadata, dict):
|
||||
for key, value in metadata.items():
|
||||
# clean litellm metadata before logging
|
||||
if key in [
|
||||
"headers",
|
||||
"endpoint",
|
||||
"caching_groups",
|
||||
"previous_models",
|
||||
]:
|
||||
continue
|
||||
else:
|
||||
clean_metadata[key] = value
|
||||
|
||||
# Ensure everything in the payload is converted to str
|
||||
payload: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object", None
|
||||
s3_batch_logging_element = self.create_s3_batch_logging_element(
|
||||
start_time=start_time,
|
||||
standard_logging_payload=kwargs.get("standard_logging_object", None),
|
||||
s3_path=self.s3_path,
|
||||
)
|
||||
|
||||
if payload is None:
|
||||
return
|
||||
if s3_batch_logging_element is None:
|
||||
raise ValueError("s3_batch_logging_element is None")
|
||||
|
||||
s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
|
||||
s3_object_key = (
|
||||
(self.s3_path.rstrip("/") + "/" if self.s3_path else "")
|
||||
+ start_time.strftime("%Y-%m-%d")
|
||||
+ "/"
|
||||
+ s3_file_name
|
||||
) # we need the s3 key to include the time, so we log cache hits too
|
||||
s3_object_key += ".json"
|
||||
|
||||
s3_object_download_filename = (
|
||||
"time-"
|
||||
+ start_time.strftime("%Y-%m-%dT%H-%M-%S-%f")
|
||||
+ "_"
|
||||
+ payload["id"]
|
||||
+ ".json"
|
||||
verbose_logger.debug(
|
||||
"\ns3 Logger - Logging payload = %s", s3_batch_logging_element
|
||||
)
|
||||
|
||||
import json
|
||||
|
||||
payload_str = json.dumps(payload)
|
||||
|
||||
print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
|
||||
|
||||
response = self.s3_client.put_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=s3_object_key,
|
||||
Body=payload_str,
|
||||
ContentType="application/json",
|
||||
ContentLanguage="en",
|
||||
ContentDisposition=f'inline; filename="{s3_object_download_filename}"',
|
||||
CacheControl="private, immutable, max-age=31536000, s-maxage=0",
|
||||
self.log_queue.append(s3_batch_logging_element)
|
||||
verbose_logger.debug(
|
||||
"s3 logging: queue length %s, batch size %s",
|
||||
len(self.log_queue),
|
||||
self.batch_size,
|
||||
)
|
||||
|
||||
print_verbose(f"Response from s3:{str(response)}")
|
||||
|
||||
print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
|
||||
return response
|
||||
if len(self.log_queue) >= self.batch_size:
|
||||
await self.flush_queue()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
|
||||
pass
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Synchronous logging function to log to s3
|
||||
|
||||
Does not batch logging requests, instantly logs on s3 Bucket
|
||||
"""
|
||||
try:
|
||||
s3_batch_logging_element = self.create_s3_batch_logging_element(
|
||||
start_time=start_time,
|
||||
standard_logging_payload=kwargs.get("standard_logging_object", None),
|
||||
s3_path=self.s3_path,
|
||||
)
|
||||
|
||||
if s3_batch_logging_element is None:
|
||||
raise ValueError("s3_batch_logging_element is None")
|
||||
|
||||
verbose_logger.debug(
|
||||
"\ns3 Logger - Logging payload = %s", s3_batch_logging_element
|
||||
)
|
||||
|
||||
# log the element sync httpx client
|
||||
self.upload_data_to_s3(s3_batch_logging_element)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"s3 Layer Error - {str(e)}")
|
||||
pass
|
||||
|
||||
async def async_upload_data_to_s3(
|
||||
self, batch_logging_element: s3BatchLoggingElement
|
||||
):
|
||||
try:
|
||||
import hashlib
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
try:
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=self.s3_aws_access_key_id,
|
||||
aws_secret_access_key=self.s3_aws_secret_access_key,
|
||||
aws_session_token=self.s3_aws_session_token,
|
||||
aws_region_name=self.s3_region_name,
|
||||
)
|
||||
|
||||
# Prepare the URL
|
||||
url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
|
||||
|
||||
if self.s3_endpoint_url:
|
||||
url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
|
||||
|
||||
# Convert JSON to string
|
||||
json_string = json.dumps(batch_logging_element.payload)
|
||||
|
||||
# Calculate SHA256 hash of the content
|
||||
content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
|
||||
|
||||
# Prepare the request
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-amz-content-sha256": content_hash,
|
||||
"Content-Language": "en",
|
||||
"Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
|
||||
"Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
|
||||
}
|
||||
req = requests.Request("PUT", url, data=json_string, headers=headers)
|
||||
prepped = req.prepare()
|
||||
|
||||
# Sign the request
|
||||
aws_request = AWSRequest(
|
||||
method=prepped.method,
|
||||
url=prepped.url,
|
||||
data=prepped.body,
|
||||
headers=prepped.headers,
|
||||
)
|
||||
SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
|
||||
|
||||
# Prepare the signed headers
|
||||
signed_headers = dict(aws_request.headers.items())
|
||||
|
||||
# Make the request
|
||||
response = await self.async_httpx_client.put(
|
||||
url, data=json_string, headers=signed_headers
|
||||
)
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error uploading to s3: {str(e)}")
|
||||
|
||||
async def async_send_batch(self):
|
||||
"""
|
||||
|
||||
Sends runs from self.log_queue
|
||||
|
||||
Returns: None
|
||||
|
||||
Raises: Does not raise an exception, will only verbose_logger.exception()
|
||||
"""
|
||||
if not self.log_queue:
|
||||
return
|
||||
|
||||
for payload in self.log_queue:
|
||||
asyncio.create_task(self.async_upload_data_to_s3(payload))
|
||||
|
||||
def create_s3_batch_logging_element(
|
||||
self,
|
||||
start_time: datetime,
|
||||
standard_logging_payload: Optional[StandardLoggingPayload],
|
||||
s3_path: Optional[str],
|
||||
) -> Optional[s3BatchLoggingElement]:
|
||||
"""
|
||||
Helper function to create an s3BatchLoggingElement.
|
||||
|
||||
Args:
|
||||
start_time (datetime): The start time of the logging event.
|
||||
standard_logging_payload (Optional[StandardLoggingPayload]): The payload to be logged.
|
||||
s3_path (Optional[str]): The S3 path prefix.
|
||||
|
||||
Returns:
|
||||
Optional[s3BatchLoggingElement]: The created s3BatchLoggingElement, or None if payload is None.
|
||||
"""
|
||||
if standard_logging_payload is None:
|
||||
return None
|
||||
|
||||
s3_file_name = (
|
||||
litellm.utils.get_logging_id(start_time, standard_logging_payload) or ""
|
||||
)
|
||||
s3_object_key = (
|
||||
(s3_path.rstrip("/") + "/" if s3_path else "")
|
||||
+ start_time.strftime("%Y-%m-%d")
|
||||
+ "/"
|
||||
+ s3_file_name
|
||||
+ ".json"
|
||||
)
|
||||
|
||||
s3_object_download_filename = f"time-{start_time.strftime('%Y-%m-%dT%H-%M-%S-%f')}_{standard_logging_payload['id']}.json"
|
||||
|
||||
return s3BatchLoggingElement(
|
||||
payload=standard_logging_payload, # type: ignore
|
||||
s3_object_key=s3_object_key,
|
||||
s3_object_download_filename=s3_object_download_filename,
|
||||
)
|
||||
|
||||
def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement):
|
||||
try:
|
||||
import hashlib
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
try:
|
||||
credentials: Credentials = self.get_credentials(
|
||||
aws_access_key_id=self.s3_aws_access_key_id,
|
||||
aws_secret_access_key=self.s3_aws_secret_access_key,
|
||||
aws_session_token=self.s3_aws_session_token,
|
||||
aws_region_name=self.s3_region_name,
|
||||
)
|
||||
|
||||
# Prepare the URL
|
||||
url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
|
||||
|
||||
if self.s3_endpoint_url:
|
||||
url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
|
||||
|
||||
# Convert JSON to string
|
||||
json_string = json.dumps(batch_logging_element.payload)
|
||||
|
||||
# Calculate SHA256 hash of the content
|
||||
content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
|
||||
|
||||
# Prepare the request
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"x-amz-content-sha256": content_hash,
|
||||
"Content-Language": "en",
|
||||
"Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
|
||||
"Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
|
||||
}
|
||||
req = requests.Request("PUT", url, data=json_string, headers=headers)
|
||||
prepped = req.prepare()
|
||||
|
||||
# Sign the request
|
||||
aws_request = AWSRequest(
|
||||
method=prepped.method,
|
||||
url=prepped.url,
|
||||
data=prepped.body,
|
||||
headers=prepped.headers,
|
||||
)
|
||||
SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
|
||||
|
||||
# Prepare the signed headers
|
||||
signed_headers = dict(aws_request.headers.items())
|
||||
|
||||
httpx_client = _get_httpx_client()
|
||||
# Make the request
|
||||
response = httpx_client.put(url, data=json_string, headers=signed_headers)
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"Error uploading to s3: {str(e)}")
|
||||
|
|
|
@ -116,7 +116,6 @@ lagoLogger = None
|
|||
dataDogLogger = None
|
||||
prometheusLogger = None
|
||||
dynamoLogger = None
|
||||
s3Logger = None
|
||||
genericAPILogger = None
|
||||
clickHouseLogger = None
|
||||
greenscaleLogger = None
|
||||
|
@ -1346,36 +1345,6 @@ class Logging:
|
|||
user_id=kwargs.get("user", None),
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if callback == "s3":
|
||||
global s3Logger
|
||||
if s3Logger is None:
|
||||
s3Logger = S3Logger()
|
||||
if self.stream:
|
||||
if "complete_streaming_response" in self.model_call_details:
|
||||
print_verbose(
|
||||
"S3Logger Logger: Got Stream Event - Completed Stream Response"
|
||||
)
|
||||
s3Logger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
response_obj=self.model_call_details[
|
||||
"complete_streaming_response"
|
||||
],
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
else:
|
||||
print_verbose(
|
||||
"S3Logger Logger: Got Stream Event - No complete stream response as yet"
|
||||
)
|
||||
else:
|
||||
s3Logger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
response_obj=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if (
|
||||
callback == "openmeter"
|
||||
and self.model_call_details.get("litellm_params", {}).get(
|
||||
|
@ -2245,7 +2214,7 @@ def set_callbacks(callback_list, function_id=None):
|
|||
"""
|
||||
Globally sets the callback client
|
||||
"""
|
||||
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
|
||||
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
|
||||
|
||||
try:
|
||||
for callback in callback_list:
|
||||
|
@ -2319,8 +2288,6 @@ def set_callbacks(callback_list, function_id=None):
|
|||
dataDogLogger = DataDogLogger()
|
||||
elif callback == "dynamodb":
|
||||
dynamoLogger = DyanmoDBLogger()
|
||||
elif callback == "s3":
|
||||
s3Logger = S3Logger()
|
||||
elif callback == "wandb":
|
||||
weightsBiasesLogger = WeightsBiasesLogger()
|
||||
elif callback == "logfire":
|
||||
|
@ -2357,7 +2324,6 @@ def _init_custom_logger_compatible_class(
|
|||
llm_router: Optional[
|
||||
Any
|
||||
], # expect litellm.Router, but typing errors due to circular import
|
||||
premium_user: Optional[bool] = None,
|
||||
) -> Optional[CustomLogger]:
|
||||
if logging_integration == "lago":
|
||||
for callback in _in_memory_loggers:
|
||||
|
@ -2404,17 +2370,9 @@ def _init_custom_logger_compatible_class(
|
|||
if isinstance(callback, PrometheusLogger):
|
||||
return callback # type: ignore
|
||||
|
||||
if premium_user:
|
||||
_prometheus_logger = PrometheusLogger()
|
||||
_in_memory_loggers.append(_prometheus_logger)
|
||||
return _prometheus_logger # type: ignore
|
||||
elif premium_user is False:
|
||||
verbose_logger.warning(
|
||||
f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
_prometheus_logger = PrometheusLogger()
|
||||
_in_memory_loggers.append(_prometheus_logger)
|
||||
return _prometheus_logger # type: ignore
|
||||
elif logging_integration == "datadog":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, DataDogLogger):
|
||||
|
@ -2423,6 +2381,14 @@ def _init_custom_logger_compatible_class(
|
|||
_datadog_logger = DataDogLogger()
|
||||
_in_memory_loggers.append(_datadog_logger)
|
||||
return _datadog_logger # type: ignore
|
||||
elif logging_integration == "s3":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, S3Logger):
|
||||
return callback # type: ignore
|
||||
|
||||
_s3_logger = S3Logger()
|
||||
_in_memory_loggers.append(_s3_logger)
|
||||
return _s3_logger # type: ignore
|
||||
elif logging_integration == "gcs_bucket":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, GCSBucketLogger):
|
||||
|
@ -2589,6 +2555,10 @@ def get_custom_logger_compatible_class(
|
|||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, PrometheusLogger):
|
||||
return callback
|
||||
elif logging_integration == "s3":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, S3Logger):
|
||||
return callback
|
||||
elif logging_integration == "datadog":
|
||||
for callback in _in_memory_loggers:
|
||||
if isinstance(callback, DataDogLogger):
|
||||
|
|
112
litellm/litellm_core_utils/realtime_streaming.py
Normal file
112
litellm/litellm_core_utils/realtime_streaming.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
"""
|
||||
async with websockets.connect( # type: ignore
|
||||
url,
|
||||
extra_headers={
|
||||
"api-key": api_key, # type: ignore
|
||||
},
|
||||
) as backend_ws:
|
||||
forward_task = asyncio.create_task(
|
||||
forward_messages(websocket, backend_ws)
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_text()
|
||||
await backend_ws.send(message)
|
||||
except websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
forward_task.cancel()
|
||||
finally:
|
||||
if not forward_task.done():
|
||||
forward_task.cancel()
|
||||
try:
|
||||
await forward_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import traceback
|
||||
from asyncio import Task
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .litellm_logging import Logging as LiteLLMLogging
|
||||
|
||||
# Create a thread pool with a maximum of 10 threads
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
|
||||
class RealTimeStreaming:
|
||||
def __init__(
|
||||
self,
|
||||
websocket: Any,
|
||||
backend_ws: Any,
|
||||
logging_obj: Optional[LiteLLMLogging] = None,
|
||||
):
|
||||
self.websocket = websocket
|
||||
self.backend_ws = backend_ws
|
||||
self.logging_obj = logging_obj
|
||||
self.messages: List = []
|
||||
self.input_message: Dict = {}
|
||||
|
||||
def store_message(self, message: Union[str, bytes]):
|
||||
"""Store message in list"""
|
||||
self.messages.append(message)
|
||||
|
||||
def store_input(self, message: dict):
|
||||
"""Store input message"""
|
||||
self.input_message = message
|
||||
if self.logging_obj:
|
||||
self.logging_obj.pre_call(input=message, api_key="")
|
||||
|
||||
async def log_messages(self):
|
||||
"""Log messages in list"""
|
||||
if self.logging_obj:
|
||||
## ASYNC LOGGING
|
||||
# Create an event loop for the new thread
|
||||
asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
|
||||
## SYNC LOGGING
|
||||
executor.submit(self.logging_obj.success_handler(self.messages))
|
||||
|
||||
async def backend_to_client_send_messages(self):
|
||||
import websockets
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await self.backend_ws.recv()
|
||||
await self.websocket.send_text(message)
|
||||
|
||||
## LOGGING
|
||||
self.store_message(message)
|
||||
except websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
await self.log_messages()
|
||||
|
||||
async def client_ack_messages(self):
|
||||
try:
|
||||
while True:
|
||||
message = await self.websocket.receive_text()
|
||||
## LOGGING
|
||||
self.store_input(message=message)
|
||||
## FORWARD TO BACKEND
|
||||
await self.backend_ws.send(message)
|
||||
except self.websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
pass
|
||||
|
||||
async def bidirectional_forward(self):
|
||||
|
||||
forward_task = asyncio.create_task(self.backend_to_client_send_messages())
|
||||
try:
|
||||
await self.client_ack_messages()
|
||||
except self.websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
forward_task.cancel()
|
||||
finally:
|
||||
if not forward_task.done():
|
||||
forward_task.cancel()
|
||||
try:
|
||||
await forward_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
|
@ -7,6 +7,8 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
|
|||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ..azure import AzureChatCompletion
|
||||
|
||||
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
|
||||
|
@ -44,6 +46,7 @@ class AzureOpenAIRealtime(AzureChatCompletion):
|
|||
api_version: Optional[str] = None,
|
||||
azure_ad_token: Optional[str] = None,
|
||||
client: Optional[Any] = None,
|
||||
logging_obj: Optional[LiteLLMLogging] = None,
|
||||
timeout: Optional[float] = None,
|
||||
):
|
||||
import websockets
|
||||
|
@ -62,23 +65,10 @@ class AzureOpenAIRealtime(AzureChatCompletion):
|
|||
"api-key": api_key, # type: ignore
|
||||
},
|
||||
) as backend_ws:
|
||||
forward_task = asyncio.create_task(
|
||||
forward_messages(websocket, backend_ws)
|
||||
realtime_streaming = RealTimeStreaming(
|
||||
websocket, backend_ws, logging_obj
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_text()
|
||||
await backend_ws.send(message)
|
||||
except websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
forward_task.cancel()
|
||||
finally:
|
||||
if not forward_task.done():
|
||||
forward_task.cancel()
|
||||
try:
|
||||
await forward_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await realtime_streaming.bidirectional_forward()
|
||||
|
||||
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
|
||||
await websocket.close(code=e.status_code, reason=str(e))
|
||||
|
|
|
@ -7,20 +7,11 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
|
|||
import asyncio
|
||||
from typing import Any, Optional
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ..openai import OpenAIChatCompletion
|
||||
|
||||
|
||||
async def forward_messages(client_ws: Any, backend_ws: Any):
|
||||
import websockets
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await backend_ws.recv()
|
||||
await client_ws.send_text(message)
|
||||
except websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class OpenAIRealtime(OpenAIChatCompletion):
|
||||
def _construct_url(self, api_base: str, model: str) -> str:
|
||||
"""
|
||||
|
@ -35,6 +26,7 @@ class OpenAIRealtime(OpenAIChatCompletion):
|
|||
self,
|
||||
model: str,
|
||||
websocket: Any,
|
||||
logging_obj: LiteLLMLogging,
|
||||
api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
client: Optional[Any] = None,
|
||||
|
@ -57,25 +49,26 @@ class OpenAIRealtime(OpenAIChatCompletion):
|
|||
"OpenAI-Beta": "realtime=v1",
|
||||
},
|
||||
) as backend_ws:
|
||||
forward_task = asyncio.create_task(
|
||||
forward_messages(websocket, backend_ws)
|
||||
realtime_streaming = RealTimeStreaming(
|
||||
websocket, backend_ws, logging_obj
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_text()
|
||||
await backend_ws.send(message)
|
||||
except websockets.exceptions.ConnectionClosed: # type: ignore
|
||||
forward_task.cancel()
|
||||
finally:
|
||||
if not forward_task.done():
|
||||
forward_task.cancel()
|
||||
try:
|
||||
await forward_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await realtime_streaming.bidirectional_forward()
|
||||
|
||||
except websockets.exceptions.InvalidStatusCode as e: # type: ignore
|
||||
await websocket.close(code=e.status_code, reason=str(e))
|
||||
except Exception as e:
|
||||
await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
|
||||
try:
|
||||
await websocket.close(
|
||||
code=1011, reason=f"Internal server error: {str(e)}"
|
||||
)
|
||||
except RuntimeError as close_error:
|
||||
if "already completed" in str(close_error) or "websocket.close" in str(
|
||||
close_error
|
||||
):
|
||||
# The WebSocket is already closed or the response is completed, so we can ignore this error
|
||||
pass
|
||||
else:
|
||||
# If it's a different RuntimeError, we might want to log it or handle it differently
|
||||
raise Exception(
|
||||
f"Unexpected error while closing WebSocket: {close_error}"
|
||||
)
|
||||
|
|
|
@ -153,6 +153,8 @@ class CustomLLM(BaseLLM):
|
|||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
|
@ -166,6 +168,12 @@ class CustomLLM(BaseLLM):
|
|||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
api_key: Optional[
|
||||
str
|
||||
], # dynamically set api_key - https://docs.litellm.ai/docs/set_keys#api_key
|
||||
api_base: Optional[
|
||||
str
|
||||
], # dynamically set api_base - https://docs.litellm.ai/docs/set_keys#api_base
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
|
|
|
@ -1,464 +0,0 @@
|
|||
# What is this?
|
||||
## Handler file for calling claude-3 on vertex ai
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesTool,
|
||||
AnthropicMessagesToolChoice,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import ResponseFormatChunk
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||
|
||||
from ..prompt_templates.factory import (
|
||||
construct_tool_use_system_prompt,
|
||||
contains_tag,
|
||||
custom_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
prompt_factory,
|
||||
response_schema_prompt,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url=" https://cloud.google.com/vertex-ai/"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class VertexAIAnthropicConfig:
|
||||
"""
|
||||
Reference:https://docs.anthropic.com/claude/reference/messages_post
|
||||
|
||||
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
|
||||
|
||||
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
|
||||
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
|
||||
|
||||
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens,
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
|
||||
Note: Please make sure to modify the default parameters as required for your use case.
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = (
|
||||
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
|
||||
)
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key == "max_tokens" and value is None:
|
||||
value = self.max_tokens
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "tool_choice":
|
||||
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
|
||||
if value == "auto":
|
||||
_tool_choice = {"type": "auto"}
|
||||
elif value == "required":
|
||||
_tool_choice = {"type": "any"}
|
||||
elif isinstance(value, dict):
|
||||
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
|
||||
|
||||
if _tool_choice is not None:
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
json_schema: Optional[dict] = None
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
"""
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
_tool_choice = None
|
||||
_tool_choice = {"name": "json_tool_call", "type": "tool"}
|
||||
|
||||
_tool = AnthropicMessagesTool(
|
||||
name="json_tool_call",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {"values": json_schema}, # type: ignore
|
||||
},
|
||||
)
|
||||
|
||||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
|
||||
return optional_params
|
||||
|
||||
|
||||
"""
|
||||
- Run client init
|
||||
- Support async completion, streaming
|
||||
"""
|
||||
|
||||
|
||||
def refresh_auth(
|
||||
credentials,
|
||||
) -> str: # used when user passes in credentials as json string
|
||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||
|
||||
if credentials.token is None:
|
||||
credentials.refresh(Request())
|
||||
|
||||
if not credentials.token:
|
||||
raise RuntimeError("Could not resolve API token from the credentials")
|
||||
|
||||
return credentials.token
|
||||
|
||||
|
||||
def get_vertex_client(
|
||||
client: Any,
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
) -> Tuple[Any, Optional[str]]:
|
||||
args = locals()
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
|
||||
try:
|
||||
from anthropic import AnthropicVertex
|
||||
except Exception:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
|
||||
)
|
||||
|
||||
access_token: Optional[str] = None
|
||||
|
||||
if client is None:
|
||||
_credentials, cred_project_id = VertexLLM().load_auth(
|
||||
credentials=vertex_credentials, project_id=vertex_project
|
||||
)
|
||||
|
||||
vertex_ai_client = AnthropicVertex(
|
||||
project_id=vertex_project or cred_project_id,
|
||||
region=vertex_location or "us-central1",
|
||||
access_token=_credentials.token,
|
||||
)
|
||||
access_token = _credentials.token
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
access_token = client.access_token
|
||||
|
||||
return vertex_ai_client, access_token
|
||||
|
||||
|
||||
def create_vertex_anthropic_url(
|
||||
vertex_location: str, vertex_project: str, model: str, stream: bool
|
||||
) -> str:
|
||||
if stream is True:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
|
||||
else:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
|
||||
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict: dict,
|
||||
headers: Optional[dict],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool = False,
|
||||
client=None,
|
||||
):
|
||||
try:
|
||||
import vertexai
|
||||
except Exception:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
|
||||
)
|
||||
|
||||
from anthropic import AnthropicVertex
|
||||
|
||||
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
|
||||
if not (
|
||||
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
|
||||
):
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
|
||||
)
|
||||
try:
|
||||
|
||||
vertex_httpx_logic = VertexLLM()
|
||||
|
||||
access_token, project_id = vertex_httpx_logic._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
||||
## Load Config
|
||||
config = litellm.VertexAIAnthropicConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params:
|
||||
optional_params[k] = v
|
||||
|
||||
## CONSTRUCT API BASE
|
||||
stream = optional_params.get("stream", False)
|
||||
|
||||
api_base = create_vertex_anthropic_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
model=model,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
if headers is not None:
|
||||
vertex_headers = headers
|
||||
else:
|
||||
vertex_headers = {}
|
||||
|
||||
vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
|
||||
|
||||
optional_params.update(
|
||||
{"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
|
||||
)
|
||||
|
||||
return anthropic_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=access_token,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
acompletion=acompletion,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=vertex_headers,
|
||||
client=client,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
|
||||
async def async_completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
logging_obj,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
access_token=None,
|
||||
):
|
||||
from anthropic import AsyncAnthropicVertex
|
||||
|
||||
if client is None:
|
||||
vertex_ai_client = AsyncAnthropicVertex(
|
||||
project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
|
||||
)
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
},
|
||||
)
|
||||
message = await vertex_ai_client.messages.create(**data) # type: ignore
|
||||
text_content = message.content[0].text
|
||||
## TOOL CALLING - OUTPUT PARSE
|
||||
if text_content is not None and contains_tag("invoke", text_content):
|
||||
function_name = extract_between_tags("tool_name", text_content)[0]
|
||||
function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
|
||||
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
|
||||
function_arguments = parse_xml_params(function_arguments_str)
|
||||
_message = litellm.Message(
|
||||
tool_calls=[
|
||||
{
|
||||
"id": f"call_{uuid.uuid4()}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(function_arguments),
|
||||
},
|
||||
}
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
else:
|
||||
model_response.choices[0].message.content = text_content # type: ignore
|
||||
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
|
||||
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = message.usage.input_tokens
|
||||
completion_tokens = message.usage.output_tokens
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
|
||||
async def async_streaming(
|
||||
model: str,
|
||||
messages: list,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
logging_obj,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
access_token=None,
|
||||
):
|
||||
from anthropic import AsyncAnthropicVertex
|
||||
|
||||
if client is None:
|
||||
vertex_ai_client = AsyncAnthropicVertex(
|
||||
project_id=vertex_project, region=vertex_location, access_token=access_token # type: ignore
|
||||
)
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=None,
|
||||
additional_args={
|
||||
"complete_input_dict": optional_params,
|
||||
},
|
||||
)
|
||||
response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore
|
||||
logging_obj.post_call(input=messages, api_key=None, original_response=response)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
custom_llm_provider="vertex_ai",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
return streamwrapper
|
|
@ -0,0 +1,179 @@
|
|||
# What is this?
|
||||
## Handler file for calling claude-3 on vertex ai
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesTool,
|
||||
AnthropicMessagesToolChoice,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import ResponseFormatChunk
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||
|
||||
from ....prompt_templates.factory import (
|
||||
construct_tool_use_system_prompt,
|
||||
contains_tag,
|
||||
custom_prompt,
|
||||
extract_between_tags,
|
||||
parse_xml_params,
|
||||
prompt_factory,
|
||||
response_schema_prompt,
|
||||
)
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url=" https://cloud.google.com/vertex-ai/"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class VertexAIAnthropicConfig:
|
||||
"""
|
||||
Reference:https://docs.anthropic.com/claude/reference/messages_post
|
||||
|
||||
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
|
||||
|
||||
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
|
||||
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
|
||||
|
||||
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens,
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
|
||||
Note: Please make sure to modify the default parameters as required for your use case.
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = (
|
||||
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
|
||||
)
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key == "max_tokens" and value is None:
|
||||
value = self.max_tokens
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"response_format",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "tool_choice":
|
||||
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
|
||||
if value == "auto":
|
||||
_tool_choice = {"type": "auto"}
|
||||
elif value == "required":
|
||||
_tool_choice = {"type": "any"}
|
||||
elif isinstance(value, dict):
|
||||
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
|
||||
|
||||
if _tool_choice is not None:
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "response_format" and isinstance(value, dict):
|
||||
json_schema: Optional[dict] = None
|
||||
if "response_schema" in value:
|
||||
json_schema = value["response_schema"]
|
||||
elif "json_schema" in value:
|
||||
json_schema = value["json_schema"]["schema"]
|
||||
"""
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
|
||||
"""
|
||||
_tool_choice = None
|
||||
_tool_choice = {"name": "json_tool_call", "type": "tool"}
|
||||
|
||||
_tool = AnthropicMessagesTool(
|
||||
name="json_tool_call",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": {"values": json_schema}, # type: ignore
|
||||
},
|
||||
)
|
||||
|
||||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
|
||||
return optional_params
|
|
@ -9,13 +9,14 @@ import httpx # type: ignore
|
|||
import litellm
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
from ...base import BaseLLM
|
||||
from ..vertex_llm_base import VertexBase
|
||||
|
||||
|
||||
class VertexPartnerProvider(str, Enum):
|
||||
mistralai = "mistralai"
|
||||
llama = "llama"
|
||||
ai21 = "ai21"
|
||||
claude = "claude"
|
||||
|
||||
|
||||
class VertexAIError(Exception):
|
||||
|
@ -31,31 +32,38 @@ class VertexAIError(Exception):
|
|||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class VertexAIPartnerModels(BaseLLM):
|
||||
def create_vertex_url(
|
||||
vertex_location: str,
|
||||
vertex_project: str,
|
||||
partner: VertexPartnerProvider,
|
||||
stream: Optional[bool],
|
||||
model: str,
|
||||
api_base: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Return the base url for the vertex partner models"""
|
||||
if partner == VertexPartnerProvider.llama:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
|
||||
elif partner == VertexPartnerProvider.mistralai:
|
||||
if stream:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
|
||||
else:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
|
||||
elif partner == VertexPartnerProvider.ai21:
|
||||
if stream:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
|
||||
else:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
|
||||
elif partner == VertexPartnerProvider.claude:
|
||||
if stream:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
|
||||
else:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
|
||||
|
||||
|
||||
class VertexAIPartnerModels(VertexBase):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def create_vertex_url(
|
||||
self,
|
||||
vertex_location: str,
|
||||
vertex_project: str,
|
||||
partner: VertexPartnerProvider,
|
||||
stream: Optional[bool],
|
||||
model: str,
|
||||
) -> str:
|
||||
if partner == VertexPartnerProvider.llama:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
|
||||
elif partner == VertexPartnerProvider.mistralai:
|
||||
if stream:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
|
||||
else:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
|
||||
elif partner == VertexPartnerProvider.ai21:
|
||||
if stream:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
|
||||
else:
|
||||
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
|
||||
|
||||
def completion(
|
||||
self,
|
||||
model: str,
|
||||
|
@ -64,6 +72,7 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
api_base: Optional[str],
|
||||
optional_params: dict,
|
||||
custom_prompt_dict: dict,
|
||||
headers: Optional[dict],
|
||||
|
@ -80,6 +89,7 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
import vertexai
|
||||
from google.cloud import aiplatform
|
||||
|
||||
from litellm.llms.anthropic.chat import AnthropicChatCompletion
|
||||
from litellm.llms.databricks.chat import DatabricksChatCompletion
|
||||
from litellm.llms.OpenAI.openai import OpenAIChatCompletion
|
||||
from litellm.llms.text_completion_codestral import CodestralTextCompletion
|
||||
|
@ -112,6 +122,7 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
|
||||
openai_like_chat_completions = DatabricksChatCompletion()
|
||||
codestral_fim_completions = CodestralTextCompletion()
|
||||
anthropic_chat_completions = AnthropicChatCompletion()
|
||||
|
||||
## CONSTRUCT API BASE
|
||||
stream: bool = optional_params.get("stream", False) or False
|
||||
|
@ -126,8 +137,10 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
elif "jamba" in model:
|
||||
partner = VertexPartnerProvider.ai21
|
||||
optional_params["custom_endpoint"] = True
|
||||
elif "claude" in model:
|
||||
partner = VertexPartnerProvider.claude
|
||||
|
||||
api_base = self.create_vertex_url(
|
||||
default_api_base = create_vertex_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
partner=partner, # type: ignore
|
||||
|
@ -135,6 +148,21 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
model=model,
|
||||
)
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=stream,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
)
|
||||
|
||||
model = model.split("@")[0]
|
||||
|
||||
if "codestral" in model and litellm_params.get("text_completion") is True:
|
||||
|
@ -158,6 +186,35 @@ class VertexAIPartnerModels(BaseLLM):
|
|||
timeout=timeout,
|
||||
encoding=encoding,
|
||||
)
|
||||
elif "claude" in model:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers.update({"Authorization": "Bearer {}".format(access_token)})
|
||||
|
||||
optional_params.update(
|
||||
{
|
||||
"anthropic_version": "vertex-2023-10-16",
|
||||
"is_vertex_request": True,
|
||||
}
|
||||
)
|
||||
return anthropic_chat_completions.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
acompletion=acompletion,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding, # for calculating input/output tokens
|
||||
api_key=access_token,
|
||||
logging_obj=logging_obj,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
|
||||
return openai_like_chat_completions.completion(
|
||||
model=model,
|
||||
|
|
|
@ -117,10 +117,7 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
|
|||
from .llms.text_completion_codestral import CodestralTextCompletion
|
||||
from .llms.together_ai.completion.handler import TogetherAITextCompletion
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.vertex_ai_and_google_ai_studio import (
|
||||
vertex_ai_anthropic,
|
||||
vertex_ai_non_gemini,
|
||||
)
|
||||
from .llms.vertex_ai_and_google_ai_studio import vertex_ai_non_gemini
|
||||
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
VertexLLM,
|
||||
)
|
||||
|
@ -747,6 +744,11 @@ def completion( # type: ignore
|
|||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
fallbacks = kwargs.get("fallbacks", None)
|
||||
headers = kwargs.get("headers", None) or extra_headers
|
||||
if headers is None:
|
||||
headers = {}
|
||||
|
||||
if extra_headers is not None:
|
||||
headers.update(extra_headers)
|
||||
num_retries = kwargs.get(
|
||||
"num_retries", None
|
||||
) ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
|
||||
|
@ -964,7 +966,6 @@ def completion( # type: ignore
|
|||
max_retries=max_retries,
|
||||
logprobs=logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
extra_headers=extra_headers,
|
||||
api_version=api_version,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
messages=messages,
|
||||
|
@ -1067,6 +1068,9 @@ def completion( # type: ignore
|
|||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
if (
|
||||
litellm.enable_preview_features
|
||||
and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
|
||||
|
@ -1166,6 +1170,9 @@ def completion( # type: ignore
|
|||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
## LOAD CONFIG - if set
|
||||
config = litellm.AzureOpenAIConfig.get_config()
|
||||
for k, v in config.items():
|
||||
|
@ -1223,6 +1230,9 @@ def completion( # type: ignore
|
|||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
## LOAD CONFIG - if set
|
||||
config = litellm.AzureAIStudioConfig.get_config()
|
||||
for k, v in config.items():
|
||||
|
@ -1304,6 +1314,9 @@ def completion( # type: ignore
|
|||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
## LOAD CONFIG - if set
|
||||
config = litellm.OpenAITextCompletionConfig.get_config()
|
||||
for k, v in config.items():
|
||||
|
@ -1466,6 +1479,9 @@ def completion( # type: ignore
|
|||
|
||||
headers = headers or litellm.headers
|
||||
|
||||
if extra_headers is not None:
|
||||
optional_params["extra_headers"] = extra_headers
|
||||
|
||||
## LOAD CONFIG - if set
|
||||
config = litellm.OpenAIConfig.get_config()
|
||||
for k, v in config.items():
|
||||
|
@ -2228,31 +2244,12 @@ def completion( # type: ignore
|
|||
)
|
||||
|
||||
new_params = deepcopy(optional_params)
|
||||
if "claude-3" in model:
|
||||
model_response = vertex_ai_anthropic.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
optional_params=new_params,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_credentials=vertex_credentials,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
headers=headers,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
elif (
|
||||
if (
|
||||
model.startswith("meta/")
|
||||
or model.startswith("mistral")
|
||||
or model.startswith("codestral")
|
||||
or model.startswith("jamba")
|
||||
or model.startswith("claude")
|
||||
):
|
||||
model_response = vertex_partner_models_chat_completion.completion(
|
||||
model=model,
|
||||
|
@ -2263,6 +2260,7 @@ def completion( # type: ignore
|
|||
litellm_params=litellm_params, # type: ignore
|
||||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
api_base=api_base,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_project=vertex_ai_project,
|
||||
vertex_credentials=vertex_credentials,
|
||||
|
@ -4848,6 +4846,8 @@ def image_generation(
|
|||
model_response = custom_handler.aimage_generation( # type: ignore
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
|
@ -4863,6 +4863,8 @@ def image_generation(
|
|||
model_response = custom_handler.image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
logging_obj=litellm_logging_obj,
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
model_list:
|
||||
- model_name: gpt-4o-mini
|
||||
litellm_params:
|
||||
model: azure/my-gpt-4o-mini
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
# - model_name: openai-gpt-4o-realtime-audio
|
||||
# litellm_params:
|
||||
# model: azure/gpt-4o-realtime-preview
|
||||
# api_key: os.environ/AZURE_SWEDEN_API_KEY
|
||||
# api_base: os.environ/AZURE_SWEDEN_API_BASE
|
||||
|
||||
- model_name: openai-gpt-4o-realtime-audio
|
||||
litellm_params:
|
||||
model: openai/gpt-4o-realtime-preview-2024-10-01
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
api_base: http://localhost:8080
|
||||
|
||||
litellm_settings:
|
||||
turn_off_message_logging: true
|
||||
cache: True
|
||||
cache_params:
|
||||
type: local
|
||||
success_callback: ["langfuse"]
|
|
@ -241,6 +241,8 @@ def initialize_callbacks_on_proxy(
|
|||
litellm.callbacks = imported_list # type: ignore
|
||||
|
||||
if "prometheus" in value:
|
||||
if premium_user is not True:
|
||||
raise Exception(CommonProxyErrors.not_premium_user.value)
|
||||
from litellm.proxy.proxy_server import app
|
||||
|
||||
verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
model_list:
|
||||
- model_name: db-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/gpt-5
|
||||
model: openai/gpt-4
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
|
||||
- model_name: db-openai-endpoint
|
||||
litellm_params:
|
||||
model: openai/gpt-5
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
|
||||
litellm_settings:
|
||||
callbacks: ["prometheus"]
|
||||
success_callback: ["s3"]
|
||||
turn_off_message_logging: true
|
||||
s3_callback_params:
|
||||
s3_bucket_name: load-testing-oct # AWS Bucket Name for S3
|
||||
s3_region_name: us-west-2 # AWS Region Name for S3
|
||||
s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # us os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
|
||||
s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
|
||||
|
||||
|
||||
|
|
|
@ -1692,6 +1692,10 @@ class ProxyConfig:
|
|||
else:
|
||||
litellm.success_callback.append(callback)
|
||||
if "prometheus" in callback:
|
||||
if not premium_user:
|
||||
raise Exception(
|
||||
CommonProxyErrors.not_premium_user.value
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Starting Prometheus Metrics on /metrics"
|
||||
)
|
||||
|
@ -5433,12 +5437,11 @@ async def moderations(
|
|||
await proxy_logging_obj.post_call_failure_hook(
|
||||
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||
)
|
||||
verbose_proxy_logger.error(
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "message", str(e)),
|
||||
|
|
|
@ -408,7 +408,6 @@ class ProxyLogging:
|
|||
callback,
|
||||
internal_usage_cache=self.internal_usage_cache.dual_cache,
|
||||
llm_router=llm_router,
|
||||
premium_user=self.premium_user,
|
||||
)
|
||||
if callback is None:
|
||||
continue
|
||||
|
|
|
@ -8,13 +8,16 @@ from litellm import get_llm_provider
|
|||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
|
||||
from ..llms.OpenAI.realtime.handler import OpenAIRealtime
|
||||
from ..utils import client as wrapper_client
|
||||
|
||||
azure_realtime = AzureOpenAIRealtime()
|
||||
openai_realtime = OpenAIRealtime()
|
||||
|
||||
|
||||
@wrapper_client
|
||||
async def _arealtime(
|
||||
model: str,
|
||||
websocket: Any, # fastapi websocket
|
||||
|
@ -31,6 +34,12 @@ async def _arealtime(
|
|||
|
||||
For PROXY use only.
|
||||
"""
|
||||
litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
user = kwargs.get("user", None)
|
||||
litellm_params = GenericLiteLLMParams(**kwargs)
|
||||
|
||||
model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
|
||||
|
@ -39,6 +48,21 @@ async def _arealtime(
|
|||
api_key=api_key,
|
||||
)
|
||||
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=model,
|
||||
user=user,
|
||||
optional_params={},
|
||||
litellm_params={
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"proxy_server_request": proxy_server_request,
|
||||
"model_info": model_info,
|
||||
"metadata": metadata,
|
||||
"preset_cache_key": None,
|
||||
"stream_response": {},
|
||||
},
|
||||
custom_llm_provider=_custom_llm_provider,
|
||||
)
|
||||
|
||||
if _custom_llm_provider == "azure":
|
||||
api_base = (
|
||||
dynamic_api_base
|
||||
|
@ -63,6 +87,7 @@ async def _arealtime(
|
|||
azure_ad_token=None,
|
||||
client=None,
|
||||
timeout=timeout,
|
||||
logging_obj=litellm_logging_obj,
|
||||
)
|
||||
elif _custom_llm_provider == "openai":
|
||||
api_base = (
|
||||
|
@ -82,6 +107,7 @@ async def _arealtime(
|
|||
await openai_realtime.async_realtime(
|
||||
model=model,
|
||||
websocket=websocket,
|
||||
logging_obj=litellm_logging_obj,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
client=None,
|
||||
|
|
|
@ -60,7 +60,7 @@ class PatternMatchRouter:
|
|||
regex = re.escape(regex).replace(r"\.\*", ".*")
|
||||
return f"^{regex}$"
|
||||
|
||||
def route(self, request: str) -> Optional[List[Dict]]:
|
||||
def route(self, request: Optional[str]) -> Optional[List[Dict]]:
|
||||
"""
|
||||
Route a requested model to the corresponding llm deployments based on the regex pattern
|
||||
|
||||
|
@ -69,17 +69,20 @@ class PatternMatchRouter:
|
|||
if no pattern is found, return None
|
||||
|
||||
Args:
|
||||
request: str
|
||||
request: Optional[str]
|
||||
|
||||
Returns:
|
||||
Optional[List[Deployment]]: llm deployments
|
||||
"""
|
||||
try:
|
||||
if request is None:
|
||||
return None
|
||||
for pattern, llm_deployments in self.patterns.items():
|
||||
if re.match(pattern, request):
|
||||
return llm_deployments
|
||||
except Exception as e:
|
||||
verbose_router_logger.error(f"Error in PatternMatchRouter.route: {str(e)}")
|
||||
verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")
|
||||
|
||||
return None # No matching pattern found
|
||||
|
||||
|
||||
|
|
14
litellm/types/integrations/s3.py
Normal file
14
litellm/types/integrations/s3.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class s3BatchLoggingElement(BaseModel):
|
||||
"""
|
||||
Type of element stored in self.log_queue in S3Logger
|
||||
|
||||
"""
|
||||
|
||||
payload: Dict
|
||||
s3_object_key: str
|
||||
s3_object_download_filename: str
|
|
@ -129,6 +129,7 @@ class CallTypes(Enum):
|
|||
speech = "speech"
|
||||
rerank = "rerank"
|
||||
arerank = "arerank"
|
||||
arealtime = "_arealtime"
|
||||
|
||||
|
||||
class PassthroughCallTypes(Enum):
|
||||
|
|
|
@ -197,7 +197,6 @@ lagoLogger = None
|
|||
dataDogLogger = None
|
||||
prometheusLogger = None
|
||||
dynamoLogger = None
|
||||
s3Logger = None
|
||||
genericAPILogger = None
|
||||
clickHouseLogger = None
|
||||
greenscaleLogger = None
|
||||
|
@ -1410,6 +1409,8 @@ def client(original_function):
|
|||
)
|
||||
else:
|
||||
return result
|
||||
elif call_type == CallTypes.arealtime.value:
|
||||
return result
|
||||
|
||||
# ADD HIDDEN PARAMS - additional call metadata
|
||||
if hasattr(result, "_hidden_params"):
|
||||
|
@ -1799,8 +1800,9 @@ def calculate_tiles_needed(
|
|||
total_tiles = tiles_across * tiles_down
|
||||
return total_tiles
|
||||
|
||||
|
||||
def get_image_type(image_data: bytes) -> Union[str, None]:
|
||||
""" take an image (really only the first ~100 bytes max are needed)
|
||||
"""take an image (really only the first ~100 bytes max are needed)
|
||||
and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
|
||||
allow deprecation of imghdr in 3.13"""
|
||||
|
||||
|
@ -4336,16 +4338,18 @@ def get_api_base(
|
|||
_optional_params.vertex_location is not None
|
||||
and _optional_params.vertex_project is not None
|
||||
):
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
|
||||
create_vertex_anthropic_url,
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
|
||||
VertexPartnerProvider,
|
||||
create_vertex_url,
|
||||
)
|
||||
|
||||
if "claude" in model:
|
||||
_api_base = create_vertex_anthropic_url(
|
||||
_api_base = create_vertex_url(
|
||||
vertex_location=_optional_params.vertex_location,
|
||||
vertex_project=_optional_params.vertex_project,
|
||||
model=model,
|
||||
stream=stream,
|
||||
partner=VertexPartnerProvider.claude,
|
||||
)
|
||||
else:
|
||||
|
||||
|
@ -4442,19 +4446,7 @@ def get_supported_openai_params(
|
|||
elif custom_llm_provider == "volcengine":
|
||||
return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "groq":
|
||||
return [
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"stream",
|
||||
"stop",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
"extra_body",
|
||||
]
|
||||
return litellm.GroqChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "hosted_vllm":
|
||||
return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
|
||||
elif custom_llm_provider == "deepseek":
|
||||
|
@ -4599,6 +4591,8 @@ def get_supported_openai_params(
|
|||
return (
|
||||
litellm.MistralTextCompletionConfig().get_supported_openai_params()
|
||||
)
|
||||
if model.startswith("claude"):
|
||||
return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
|
||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||
elif request_type == "embeddings":
|
||||
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
|
||||
|
|
|
@ -296,7 +296,7 @@ def test_all_model_configs():
|
|||
optional_params={},
|
||||
) == {"max_tokens": 10}
|
||||
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
|
||||
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
|
||||
VertexAIAnthropicConfig,
|
||||
)
|
||||
|
||||
|
|
|
@ -12,7 +12,70 @@ import litellm
|
|||
litellm.num_retries = 3
|
||||
|
||||
import time, random
|
||||
from litellm._logging import verbose_logger
|
||||
import logging
|
||||
import pytest
|
||||
import boto3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_basic_s3_logging(sync_mode):
|
||||
verbose_logger.setLevel(level=logging.DEBUG)
|
||||
litellm.success_callback = ["s3"]
|
||||
litellm.s3_callback_params = {
|
||||
"s3_bucket_name": "load-testing-oct",
|
||||
"s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
|
||||
"s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
|
||||
"s3_region_name": "us-west-2",
|
||||
}
|
||||
litellm.set_verbose = True
|
||||
|
||||
if sync_mode is True:
|
||||
response = litellm.completion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
mock_response="It's simple to use and easy to get started",
|
||||
)
|
||||
else:
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "This is a test"}],
|
||||
mock_response="It's simple to use and easy to get started",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
|
||||
await asyncio.sleep(12)
|
||||
|
||||
total_objects, all_s3_keys = list_all_s3_objects("load-testing-oct")
|
||||
|
||||
# assert that atlest one key has response.id in it
|
||||
assert any(response.id in key for key in all_s3_keys)
|
||||
s3 = boto3.client("s3")
|
||||
# delete all objects
|
||||
for key in all_s3_keys:
|
||||
s3.delete_object(Bucket="load-testing-oct", Key=key)
|
||||
|
||||
|
||||
def list_all_s3_objects(bucket_name):
|
||||
s3 = boto3.client("s3")
|
||||
|
||||
all_s3_keys = []
|
||||
|
||||
paginator = s3.get_paginator("list_objects_v2")
|
||||
total_objects = 0
|
||||
|
||||
for page in paginator.paginate(Bucket=bucket_name):
|
||||
if "Contents" in page:
|
||||
total_objects += len(page["Contents"])
|
||||
all_s3_keys.extend([obj["Key"] for obj in page["Contents"]])
|
||||
|
||||
print(f"Total number of objects in {bucket_name}: {total_objects}")
|
||||
print(all_s3_keys)
|
||||
return total_objects, all_s3_keys
|
||||
|
||||
|
||||
list_all_s3_objects("load-testing-oct")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="AWS Suspended Account")
|
||||
|
|
|
@ -1616,9 +1616,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["vertex_ai_beta"]) # "vertex_ai",
|
||||
@pytest.mark.parametrize(
|
||||
"model", ["gemini-1.5-flash", "claude-3-sonnet@20240229"]
|
||||
) # "vertex_ai",
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_pro_httpx_custom_api_base(provider):
|
||||
async def test_gemini_pro_httpx_custom_api_base(model):
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
messages = [
|
||||
|
@ -1634,7 +1636,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
|
|||
with patch.object(client, "post", new=MagicMock()) as mock_call:
|
||||
try:
|
||||
response = completion(
|
||||
model="vertex_ai_beta/gemini-1.5-flash",
|
||||
model="vertex_ai/{}".format(model),
|
||||
messages=messages,
|
||||
response_format={"type": "json_object"},
|
||||
client=client,
|
||||
|
@ -1647,8 +1649,17 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
|
|||
|
||||
mock_call.assert_called_once()
|
||||
|
||||
assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
|
||||
assert "hello" in mock_call.call_args.kwargs["headers"]
|
||||
print(f"mock_call.call_args: {mock_call.call_args}")
|
||||
print(f"mock_call.call_args.kwargs: {mock_call.call_args.kwargs}")
|
||||
if "url" in mock_call.call_args.kwargs:
|
||||
assert (
|
||||
"my-custom-api-base:generateContent"
|
||||
== mock_call.call_args.kwargs["url"]
|
||||
)
|
||||
else:
|
||||
assert "my-custom-api-base:rawPredict" == mock_call.call_args[0][0]
|
||||
if "headers" in mock_call.call_args.kwargs:
|
||||
assert "hello" in mock_call.call_args.kwargs["headers"]
|
||||
|
||||
|
||||
# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
|
||||
|
|
|
@ -28,7 +28,6 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
@ -226,6 +225,8 @@ class MyCustomLLM(CustomLLM):
|
|||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
|
@ -242,6 +243,8 @@ class MyCustomLLM(CustomLLM):
|
|||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
model_response: ImageResponse,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
|
@ -362,3 +365,31 @@ async def test_simple_image_generation_async():
|
|||
)
|
||||
|
||||
print(resp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_generation_async_with_api_key_and_api_base():
|
||||
my_custom_llm = MyCustomLLM()
|
||||
litellm.custom_provider_map = [
|
||||
{"provider": "custom_llm", "custom_handler": my_custom_llm}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
my_custom_llm, "aimage_generation", new=AsyncMock()
|
||||
) as mock_client:
|
||||
try:
|
||||
resp = await litellm.aimage_generation(
|
||||
model="custom_llm/my-fake-model",
|
||||
prompt="Hello world",
|
||||
api_key="my-api-key",
|
||||
api_base="my-api-base",
|
||||
)
|
||||
|
||||
print(resp)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_client.assert_awaited_once()
|
||||
|
||||
mock_client.call_args.kwargs["api_key"] == "my-api-key"
|
||||
mock_client.call_args.kwargs["api_base"] == "my-api-base"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue