forked from phoenix/litellm-mirror
LiteLLM Minor Fixes & Improvements (10/10/2024) (#6158)
* refactor(vertex_ai_partner_models/anthropic): refactor anthropic to use partner model logic
* fix(vertex_ai/): support passing custom api base to partner models
  Fixes https://github.com/BerriAI/litellm/issues/4317
* fix(proxy_server.py): Fix prometheus premium user check logic
* docs(prometheus.md): update quick start docs
* fix(custom_llm.py): support passing dynamic api key + api base
* fix(realtime_api/main.py): Add request/response logging for realtime api endpoints
  Closes https://github.com/BerriAI/litellm/issues/6081
* feat(openai/realtime): add openai realtime api logging
  Closes https://github.com/BerriAI/litellm/issues/6081
* fix(realtime_streaming.py): fix linting errors
* fix(realtime_streaming.py): fix linting errors
* fix: fix linting errors
* fix pattern match router
* Add literalai in the sidebar observability category (#6163)
* fix: add literalai in the sidebar
* fix: typo
* update (#6160)
* Feat: Add Langtrace integration (#5341)
* Feat: Add Langtrace integration
* add langtrace service name
* fix timestamps for traces
* add tests
* Discard Callback + use existing otel logger
* cleanup
* remove print statments
* remove callback
* add docs
* docs
* add logging docs
* format logging
* remove emoji and add litellm proxy example
* format logging
* format `logging.md`
* add langtrace docs to logging.md
* sync conflict
* docs fix
* (perf) move s3 logging to Batch logging + async [94% faster perf under 100 RPS on 1 litellm instance] (#6165)
* fix move s3 to use customLogger
* add basic s3 logging test
* add s3 to custom logger compatible
* use batch logger for s3
* s3 set flush interval and batch size
* fix s3 logging
* add notes on s3 logging
* fix s3 logging
* add basic s3 logging test
* fix s3 type errors
* add test for sync logging on s3
* fix: fix to debug log

---------

Co-authored-by: Ishaan Jaff <ishaanjaffer0324@gmail.com>
Co-authored-by: Willy Douhard <willy.douhard@gmail.com>
Co-authored-by: yujonglee <yujonglee.dev@gmail.com>
Co-authored-by: Ali Waleed <ali@scale3labs.com>
This commit is contained in:
parent 9db4ccca9f
commit 11f9df923a
28 changed files with 966 additions and 760 deletions
@@ -27,8 +27,7 @@ model_list:
     litellm_params:
       model: gpt-3.5-turbo
 litellm_settings:
-  success_callback: ["prometheus"]
-  failure_callback: ["prometheus"]
+  callbacks: ["prometheus"]
 ```

 Start the proxy
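Note for reviewers: with this change the Prometheus integration is wired through the shared `callbacks` list instead of separate success/failure callbacks. A rough SDK-side equivalent is sketched below as an assumption for illustration; only the `callbacks: ["prometheus"]` setting itself comes from this diff.

```python
import litellm

# route both success and failure events through the prometheus custom logger
litellm.callbacks = ["prometheus"]
```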
@@ -53,6 +53,7 @@ _custom_logger_compatible_callbacks_literal = Literal[
     "arize",
     "langtrace",
     "gcs_bucket",
+    "s3",
     "opik",
 ]
 _known_custom_logger_compatible_callbacks: List = list(
@@ -931,7 +932,7 @@ from .llms.vertex_ai_and_google_ai_studio.vertex_embeddings.transformation import

 vertexAITextEmbeddingConfig = VertexAITextEmbeddingConfig()

-from .llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
     VertexAIAnthropicConfig,
 )
 from .llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import (
@@ -21,6 +21,7 @@ class CustomBatchLogger(CustomLogger):
         self,
         flush_lock: Optional[asyncio.Lock] = None,
         batch_size: Optional[int] = DEFAULT_BATCH_SIZE,
+        flush_interval: Optional[int] = DEFAULT_FLUSH_INTERVAL_SECONDS,
         **kwargs,
     ) -> None:
         """
@@ -28,7 +29,7 @@ class CustomBatchLogger(CustomLogger):
             flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
         """
         self.log_queue: List = []
-        self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS  # 10 seconds
+        self.flush_interval = flush_interval or DEFAULT_FLUSH_INTERVAL_SECONDS
         self.batch_size: int = batch_size or DEFAULT_BATCH_SIZE
         self.last_flush_time = time.time()
         self.flush_lock = flush_lock
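The new `flush_interval` parameter lets subclasses tune how often queued events are flushed instead of always waiting the default 10 seconds. A minimal sketch of a subclass using it follows; the class name, batch/interval values, import path, and the assumption that the base class's periodic flush calls `async_send_batch()` are illustrative, only the `__init__` signature comes from this diff.

```python
import asyncio

from litellm.integrations.custom_batch_logger import CustomBatchLogger  # assumed import path


class MyBatchLogger(CustomBatchLogger):
    def __init__(self, **kwargs):
        self.flush_lock = asyncio.Lock()
        # flush whenever 50 events are queued OR every 5 seconds, whichever comes first
        super().__init__(
            flush_lock=self.flush_lock,
            batch_size=50,
            flush_interval=5,
            **kwargs,
        )

    async def async_send_batch(self):
        # ship self.log_queue to your backend here (clearing behaviour depends on the base class)
        print(f"flushing {len(self.log_queue)} queued events")
```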
@@ -235,6 +235,14 @@ class LangFuseLogger:
         ):
             input = prompt
             output = response_obj.results
+        elif (
+            kwargs.get("call_type") is not None
+            and kwargs.get("call_type") == "_arealtime"
+            and response_obj is not None
+            and isinstance(response_obj, list)
+        ):
+            input = kwargs.get("input")
+            output = response_obj
         elif (
             kwargs.get("call_type") is not None
             and kwargs.get("call_type") == "pass_through_endpoint"
@@ -1,43 +1,67 @@
-#### What this does ####
-# On success + failure, log events to Supabase
+"""
+s3 Bucket Logging Integration

-import datetime
-import os
-import subprocess
-import sys
-import traceback
-import uuid
-from typing import Optional
+async_log_success_event: Processes the event, stores it in memory for 10 seconds or until MAX_BATCH_SIZE and then flushes to s3
+
+NOTE 1: S3 does not provide a BATCH PUT API endpoint, so we create tasks to upload each element individually
+NOTE 2: We create a httpx client with a concurrent limit of 1 to upload to s3. Files should get uploaded BUT they should not impact latency of LLM calling logic
+"""
+
+import asyncio
+import json
+from datetime import datetime
+from typing import Dict, List, Optional

 import litellm
 from litellm._logging import print_verbose, verbose_logger
+from litellm.llms.base_aws_llm import BaseAWSLLM
+from litellm.llms.custom_httpx.http_handler import (
+    _get_httpx_client,
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+from litellm.types.integrations.s3 import s3BatchLoggingElement
 from litellm.types.utils import StandardLoggingPayload

+from .custom_batch_logger import CustomBatchLogger
+
+# Default Flush interval and batch size for s3
+# Flush to s3 every 10 seconds OR every 1K requests in memory
+DEFAULT_S3_FLUSH_INTERVAL_SECONDS = 10
+DEFAULT_S3_BATCH_SIZE = 1000
+

-class S3Logger:
+class S3Logger(CustomBatchLogger, BaseAWSLLM):
     # Class variables or attributes
     def __init__(
         self,
-        s3_bucket_name=None,
-        s3_path=None,
-        s3_region_name=None,
-        s3_api_version=None,
-        s3_use_ssl=True,
-        s3_verify=None,
-        s3_endpoint_url=None,
-        s3_aws_access_key_id=None,
-        s3_aws_secret_access_key=None,
-        s3_aws_session_token=None,
+        s3_bucket_name: Optional[str] = None,
+        s3_path: Optional[str] = None,
+        s3_region_name: Optional[str] = None,
+        s3_api_version: Optional[str] = None,
+        s3_use_ssl: bool = True,
+        s3_verify: Optional[bool] = None,
+        s3_endpoint_url: Optional[str] = None,
+        s3_aws_access_key_id: Optional[str] = None,
+        s3_aws_secret_access_key: Optional[str] = None,
+        s3_aws_session_token: Optional[str] = None,
+        s3_flush_interval: Optional[int] = DEFAULT_S3_FLUSH_INTERVAL_SECONDS,
+        s3_batch_size: Optional[int] = DEFAULT_S3_BATCH_SIZE,
         s3_config=None,
         **kwargs,
     ):
-        import boto3
-
         try:
             verbose_logger.debug(
                 f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
             )
+
+            # IMPORTANT: We use a concurrent limit of 1 to upload to s3
+            # Files should get uploaded BUT they should not impact latency of LLM calling logic
+            self.async_httpx_client = get_async_httpx_client(
+                llm_provider=httpxSpecialProvider.LoggingCallback,
+                params={"concurrent_limit": 1},
+            )
+
             if litellm.s3_callback_params is not None:
                 # read in .env variables - example os.environ/AWS_BUCKET_NAME
                 for key, value in litellm.s3_callback_params.items():
@@ -63,107 +87,282 @@ class S3Logger:
                 s3_path = litellm.s3_callback_params.get("s3_path")
             # done reading litellm.s3_callback_params

+            s3_flush_interval = litellm.s3_callback_params.get(
+                "s3_flush_interval", DEFAULT_S3_FLUSH_INTERVAL_SECONDS
+            )
+            s3_batch_size = litellm.s3_callback_params.get(
+                "s3_batch_size", DEFAULT_S3_BATCH_SIZE
+            )
+
             self.bucket_name = s3_bucket_name
             self.s3_path = s3_path
             verbose_logger.debug(f"s3 logger using endpoint url {s3_endpoint_url}")
-            # Create an S3 client with custom endpoint URL
-            self.s3_client = boto3.client(
-                "s3",
-                region_name=s3_region_name,
-                endpoint_url=s3_endpoint_url,
-                api_version=s3_api_version,
-                use_ssl=s3_use_ssl,
-                verify=s3_verify,
-                aws_access_key_id=s3_aws_access_key_id,
-                aws_secret_access_key=s3_aws_secret_access_key,
-                aws_session_token=s3_aws_session_token,
-                config=s3_config,
-                **kwargs,
+            self.s3_bucket_name = s3_bucket_name
+            self.s3_region_name = s3_region_name
+            self.s3_api_version = s3_api_version
+            self.s3_use_ssl = s3_use_ssl
+            self.s3_verify = s3_verify
+            self.s3_endpoint_url = s3_endpoint_url
+            self.s3_aws_access_key_id = s3_aws_access_key_id
+            self.s3_aws_secret_access_key = s3_aws_secret_access_key
+            self.s3_aws_session_token = s3_aws_session_token
+            self.s3_config = s3_config
+            self.init_kwargs = kwargs
+
+            asyncio.create_task(self.periodic_flush())
+            self.flush_lock = asyncio.Lock()
+
+            verbose_logger.debug(
+                f"s3 flush interval: {s3_flush_interval}, s3 batch size: {s3_batch_size}"
             )
+            # Call CustomLogger's __init__
+            CustomBatchLogger.__init__(
+                self,
+                flush_lock=self.flush_lock,
+                flush_interval=s3_flush_interval,
+                batch_size=s3_batch_size,
+            )
+            self.log_queue: List[s3BatchLoggingElement] = []
+
+            # Call BaseAWSLLM's __init__
+            BaseAWSLLM.__init__(self)
+
         except Exception as e:
             print_verbose(f"Got exception on init s3 client {str(e)}")
             raise e

-    async def _async_log_event(
-        self, kwargs, response_obj, start_time, end_time, print_verbose
-    ):
-        self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
-
-    def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
             verbose_logger.debug(
                 f"s3 Logging - Enters logging function for model {kwargs}"
             )

-            # construct payload to send to s3
-            # follows the same params as langfuse.py
-            litellm_params = kwargs.get("litellm_params", {})
-            metadata = (
-                litellm_params.get("metadata", {}) or {}
-            )  # if litellm_params['metadata'] == None
-
-            # Clean Metadata before logging - never log raw metadata
-            # the raw metadata can contain circular references which leads to infinite recursion
-            # we clean out all extra litellm metadata params before logging
-            clean_metadata = {}
-            if isinstance(metadata, dict):
-                for key, value in metadata.items():
-                    # clean litellm metadata before logging
-                    if key in [
-                        "headers",
-                        "endpoint",
-                        "caching_groups",
-                        "previous_models",
-                    ]:
-                        continue
-                    else:
-                        clean_metadata[key] = value
-
-            # Ensure everything in the payload is converted to str
-            payload: Optional[StandardLoggingPayload] = kwargs.get(
-                "standard_logging_object", None
+            s3_batch_logging_element = self.create_s3_batch_logging_element(
+                start_time=start_time,
+                standard_logging_payload=kwargs.get("standard_logging_object", None),
+                s3_path=self.s3_path,
             )

-            if payload is None:
-                return
+            if s3_batch_logging_element is None:
+                raise ValueError("s3_batch_logging_element is None")

-            s3_file_name = litellm.utils.get_logging_id(start_time, payload) or ""
-            s3_object_key = (
-                (self.s3_path.rstrip("/") + "/" if self.s3_path else "")
-                + start_time.strftime("%Y-%m-%d")
-                + "/"
-                + s3_file_name
-            )  # we need the s3 key to include the time, so we log cache hits too
-            s3_object_key += ".json"
-
-            s3_object_download_filename = (
-                "time-"
-                + start_time.strftime("%Y-%m-%dT%H-%M-%S-%f")
-                + "_"
-                + payload["id"]
-                + ".json"
+            verbose_logger.debug(
+                "\ns3 Logger - Logging payload = %s", s3_batch_logging_element
             )

-            import json
-
-            payload_str = json.dumps(payload)
-
-            print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
-
-            response = self.s3_client.put_object(
-                Bucket=self.bucket_name,
-                Key=s3_object_key,
-                Body=payload_str,
-                ContentType="application/json",
-                ContentLanguage="en",
-                ContentDisposition=f'inline; filename="{s3_object_download_filename}"',
-                CacheControl="private, immutable, max-age=31536000, s-maxage=0",
-            )
-
-            print_verbose(f"Response from s3:{str(response)}")
-
-            print_verbose(f"s3 Layer Logging - final response object: {response_obj}")
-            return response
+            self.log_queue.append(s3_batch_logging_element)
+            verbose_logger.debug(
+                "s3 logging: queue length %s, batch size %s",
+                len(self.log_queue),
+                self.batch_size,
+            )
+            if len(self.log_queue) >= self.batch_size:
+                await self.flush_queue()
         except Exception as e:
             verbose_logger.exception(f"s3 Layer Error - {str(e)}")
             pass

+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        """
+        Synchronous logging function to log to s3
+
+        Does not batch logging requests, instantly logs on s3 Bucket
+        """
+        try:
+            s3_batch_logging_element = self.create_s3_batch_logging_element(
+                start_time=start_time,
+                standard_logging_payload=kwargs.get("standard_logging_object", None),
+                s3_path=self.s3_path,
+            )
+
+            if s3_batch_logging_element is None:
+                raise ValueError("s3_batch_logging_element is None")
+
+            verbose_logger.debug(
+                "\ns3 Logger - Logging payload = %s", s3_batch_logging_element
+            )
+
+            # log the element sync httpx client
+            self.upload_data_to_s3(s3_batch_logging_element)
+        except Exception as e:
+            verbose_logger.exception(f"s3 Layer Error - {str(e)}")
+            pass
+
+    async def async_upload_data_to_s3(
+        self, batch_logging_element: s3BatchLoggingElement
+    ):
+        try:
+            import hashlib
+
+            import boto3
+            import requests
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+            from botocore.credentials import Credentials
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+        try:
+            credentials: Credentials = self.get_credentials(
+                aws_access_key_id=self.s3_aws_access_key_id,
+                aws_secret_access_key=self.s3_aws_secret_access_key,
+                aws_session_token=self.s3_aws_session_token,
+                aws_region_name=self.s3_region_name,
+            )
+
+            # Prepare the URL
+            url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
+
+            if self.s3_endpoint_url:
+                url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
+
+            # Convert JSON to string
+            json_string = json.dumps(batch_logging_element.payload)
+
+            # Calculate SHA256 hash of the content
+            content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
+
+            # Prepare the request
+            headers = {
+                "Content-Type": "application/json",
+                "x-amz-content-sha256": content_hash,
+                "Content-Language": "en",
+                "Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
+                "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
+            }
+            req = requests.Request("PUT", url, data=json_string, headers=headers)
+            prepped = req.prepare()
+
+            # Sign the request
+            aws_request = AWSRequest(
+                method=prepped.method,
+                url=prepped.url,
+                data=prepped.body,
+                headers=prepped.headers,
+            )
+            SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
+
+            # Prepare the signed headers
+            signed_headers = dict(aws_request.headers.items())
+
+            # Make the request
+            response = await self.async_httpx_client.put(
+                url, data=json_string, headers=signed_headers
+            )
+            response.raise_for_status()
+        except Exception as e:
+            verbose_logger.exception(f"Error uploading to s3: {str(e)}")
+
+    async def async_send_batch(self):
+        """
+
+        Sends runs from self.log_queue
+
+        Returns: None
+
+        Raises: Does not raise an exception, will only verbose_logger.exception()
+        """
+        if not self.log_queue:
+            return
+
+        for payload in self.log_queue:
+            asyncio.create_task(self.async_upload_data_to_s3(payload))
+
+    def create_s3_batch_logging_element(
+        self,
+        start_time: datetime,
+        standard_logging_payload: Optional[StandardLoggingPayload],
+        s3_path: Optional[str],
+    ) -> Optional[s3BatchLoggingElement]:
+        """
+        Helper function to create an s3BatchLoggingElement.
+
+        Args:
+            start_time (datetime): The start time of the logging event.
+            standard_logging_payload (Optional[StandardLoggingPayload]): The payload to be logged.
+            s3_path (Optional[str]): The S3 path prefix.
+
+        Returns:
+            Optional[s3BatchLoggingElement]: The created s3BatchLoggingElement, or None if payload is None.
+        """
+        if standard_logging_payload is None:
+            return None
+
+        s3_file_name = (
+            litellm.utils.get_logging_id(start_time, standard_logging_payload) or ""
+        )
+        s3_object_key = (
+            (s3_path.rstrip("/") + "/" if s3_path else "")
+            + start_time.strftime("%Y-%m-%d")
+            + "/"
+            + s3_file_name
+            + ".json"
+        )
+
+        s3_object_download_filename = f"time-{start_time.strftime('%Y-%m-%dT%H-%M-%S-%f')}_{standard_logging_payload['id']}.json"
+
+        return s3BatchLoggingElement(
+            payload=standard_logging_payload,  # type: ignore
+            s3_object_key=s3_object_key,
+            s3_object_download_filename=s3_object_download_filename,
+        )
+
+    def upload_data_to_s3(self, batch_logging_element: s3BatchLoggingElement):
+        try:
+            import hashlib
+
+            import boto3
+            import requests
+            from botocore.auth import SigV4Auth
+            from botocore.awsrequest import AWSRequest
+            from botocore.credentials import Credentials
+        except ImportError:
+            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+        try:
+            credentials: Credentials = self.get_credentials(
+                aws_access_key_id=self.s3_aws_access_key_id,
+                aws_secret_access_key=self.s3_aws_secret_access_key,
+                aws_session_token=self.s3_aws_session_token,
+                aws_region_name=self.s3_region_name,
+            )
+
+            # Prepare the URL
+            url = f"https://{self.bucket_name}.s3.{self.s3_region_name}.amazonaws.com/{batch_logging_element.s3_object_key}"
+
+            if self.s3_endpoint_url:
+                url = self.s3_endpoint_url + "/" + batch_logging_element.s3_object_key
+
+            # Convert JSON to string
+            json_string = json.dumps(batch_logging_element.payload)
+
+            # Calculate SHA256 hash of the content
+            content_hash = hashlib.sha256(json_string.encode("utf-8")).hexdigest()
+
+            # Prepare the request
+            headers = {
+                "Content-Type": "application/json",
+                "x-amz-content-sha256": content_hash,
+                "Content-Language": "en",
+                "Content-Disposition": f'inline; filename="{batch_logging_element.s3_object_download_filename}"',
+                "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
+            }
+            req = requests.Request("PUT", url, data=json_string, headers=headers)
+            prepped = req.prepare()
+
+            # Sign the request
+            aws_request = AWSRequest(
+                method=prepped.method,
+                url=prepped.url,
+                data=prepped.body,
+                headers=prepped.headers,
+            )
+            SigV4Auth(credentials, "s3", self.s3_region_name).add_auth(aws_request)
+
+            # Prepare the signed headers
+            signed_headers = dict(aws_request.headers.items())
+
+            httpx_client = _get_httpx_client()
+            # Make the request
+            response = httpx_client.put(url, data=json_string, headers=signed_headers)
+            response.raise_for_status()
+        except Exception as e:
+            verbose_logger.exception(f"Error uploading to s3: {str(e)}")
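For reference, a minimal sketch of enabling the batched s3 logger added above from application code. The keys under `s3_callback_params` mirror the keys this file reads (`s3_flush_interval`, `s3_batch_size`); the bucket name, region, and model are placeholder values.

```python
import litellm

litellm.s3_callback_params = {
    "s3_bucket_name": "my-logs-bucket",  # placeholder
    "s3_region_name": "us-west-2",       # placeholder
    "s3_flush_interval": 10,             # flush queued events every 10 seconds
    "s3_batch_size": 1000,               # or as soon as 1K events are queued
}
litellm.callbacks = ["s3"]  # "s3" is now a custom-logger compatible callback

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
```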
@@ -116,7 +116,6 @@ lagoLogger = None
 dataDogLogger = None
 prometheusLogger = None
 dynamoLogger = None
-s3Logger = None
 genericAPILogger = None
 clickHouseLogger = None
 greenscaleLogger = None
@@ -1346,36 +1345,6 @@ class Logging:
                         user_id=kwargs.get("user", None),
                         print_verbose=print_verbose,
                     )
-                if callback == "s3":
-                    global s3Logger
-                    if s3Logger is None:
-                        s3Logger = S3Logger()
-                    if self.stream:
-                        if "complete_streaming_response" in self.model_call_details:
-                            print_verbose(
-                                "S3Logger Logger: Got Stream Event - Completed Stream Response"
-                            )
-                            s3Logger.log_event(
-                                kwargs=self.model_call_details,
-                                response_obj=self.model_call_details[
-                                    "complete_streaming_response"
-                                ],
-                                start_time=start_time,
-                                end_time=end_time,
-                                print_verbose=print_verbose,
-                            )
-                        else:
-                            print_verbose(
-                                "S3Logger Logger: Got Stream Event - No complete stream response as yet"
-                            )
-                    else:
-                        s3Logger.log_event(
-                            kwargs=self.model_call_details,
-                            response_obj=result,
-                            start_time=start_time,
-                            end_time=end_time,
-                            print_verbose=print_verbose,
-                        )
                 if (
                     callback == "openmeter"
                     and self.model_call_details.get("litellm_params", {}).get(
@@ -2245,7 +2214,7 @@ def set_callbacks(callback_list, function_id=None):
     """
     Globally sets the callback client
     """
-    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
+    global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, logfireLogger, dynamoLogger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger

     try:
         for callback in callback_list:
@@ -2319,8 +2288,6 @@ def set_callbacks(callback_list, function_id=None):
                 dataDogLogger = DataDogLogger()
             elif callback == "dynamodb":
                 dynamoLogger = DyanmoDBLogger()
-            elif callback == "s3":
-                s3Logger = S3Logger()
             elif callback == "wandb":
                 weightsBiasesLogger = WeightsBiasesLogger()
             elif callback == "logfire":
@@ -2357,7 +2324,6 @@ def _init_custom_logger_compatible_class(
     llm_router: Optional[
         Any
     ],  # expect litellm.Router, but typing errors due to circular import
-    premium_user: Optional[bool] = None,
 ) -> Optional[CustomLogger]:
     if logging_integration == "lago":
         for callback in _in_memory_loggers:
@@ -2404,17 +2370,9 @@ def _init_custom_logger_compatible_class(
             if isinstance(callback, PrometheusLogger):
                 return callback  # type: ignore

-        if premium_user:
-            _prometheus_logger = PrometheusLogger()
-            _in_memory_loggers.append(_prometheus_logger)
-            return _prometheus_logger  # type: ignore
-        elif premium_user is False:
-            verbose_logger.warning(
-                f"🚨🚨🚨 Prometheus Metrics is on LiteLLM Enterprise\n🚨 {CommonProxyErrors.not_premium_user.value}"
-            )
-            return None
-        else:
-            return None
+        _prometheus_logger = PrometheusLogger()
+        _in_memory_loggers.append(_prometheus_logger)
+        return _prometheus_logger  # type: ignore
     elif logging_integration == "datadog":
         for callback in _in_memory_loggers:
             if isinstance(callback, DataDogLogger):
@@ -2423,6 +2381,14 @@ def _init_custom_logger_compatible_class(
         _datadog_logger = DataDogLogger()
         _in_memory_loggers.append(_datadog_logger)
         return _datadog_logger  # type: ignore
+    elif logging_integration == "s3":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, S3Logger):
+                return callback  # type: ignore
+
+        _s3_logger = S3Logger()
+        _in_memory_loggers.append(_s3_logger)
+        return _s3_logger  # type: ignore
     elif logging_integration == "gcs_bucket":
         for callback in _in_memory_loggers:
             if isinstance(callback, GCSBucketLogger):
@@ -2589,6 +2555,10 @@ def get_custom_logger_compatible_class(
         for callback in _in_memory_loggers:
            if isinstance(callback, PrometheusLogger):
                return callback
+    elif logging_integration == "s3":
+        for callback in _in_memory_loggers:
+            if isinstance(callback, S3Logger):
+                return callback
     elif logging_integration == "datadog":
         for callback in _in_memory_loggers:
             if isinstance(callback, DataDogLogger):
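The s3 branches added to `_init_custom_logger_compatible_class` and `get_custom_logger_compatible_class` follow the same reuse pattern as the other integrations: return an already-registered `S3Logger` if one exists, otherwise create and register one. A stripped-down sketch of that pattern, with the helper name and import path assumed for illustration:

```python
from litellm.integrations.s3 import S3Logger  # assumed import path


def get_or_create_s3_logger(in_memory_loggers: list) -> S3Logger:
    # reuse an existing instance so the batch queue and flush task are shared
    for callback in in_memory_loggers:
        if isinstance(callback, S3Logger):
            return callback
    _s3_logger = S3Logger()
    in_memory_loggers.append(_s3_logger)
    return _s3_logger
```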
litellm/litellm_core_utils/realtime_streaming.py (new file, 112 additions)
@@ -0,0 +1,112 @@
+"""
+async with websockets.connect(  # type: ignore
+    url,
+    extra_headers={
+        "api-key": api_key,  # type: ignore
+    },
+) as backend_ws:
+    forward_task = asyncio.create_task(
+        forward_messages(websocket, backend_ws)
+    )
+
+    try:
+        while True:
+            message = await websocket.receive_text()
+            await backend_ws.send(message)
+    except websockets.exceptions.ConnectionClosed:  # type: ignore
+        forward_task.cancel()
+    finally:
+        if not forward_task.done():
+            forward_task.cancel()
+            try:
+                await forward_task
+            except asyncio.CancelledError:
+                pass
+"""
+
+import asyncio
+import concurrent.futures
+import traceback
+from asyncio import Task
+from typing import Any, Dict, List, Optional, Union
+
+from .litellm_logging import Logging as LiteLLMLogging
+
+# Create a thread pool with a maximum of 10 threads
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
+
+
+class RealTimeStreaming:
+    def __init__(
+        self,
+        websocket: Any,
+        backend_ws: Any,
+        logging_obj: Optional[LiteLLMLogging] = None,
+    ):
+        self.websocket = websocket
+        self.backend_ws = backend_ws
+        self.logging_obj = logging_obj
+        self.messages: List = []
+        self.input_message: Dict = {}
+
+    def store_message(self, message: Union[str, bytes]):
+        """Store message in list"""
+        self.messages.append(message)
+
+    def store_input(self, message: dict):
+        """Store input message"""
+        self.input_message = message
+        if self.logging_obj:
+            self.logging_obj.pre_call(input=message, api_key="")
+
+    async def log_messages(self):
+        """Log messages in list"""
+        if self.logging_obj:
+            ## ASYNC LOGGING
+            # Create an event loop for the new thread
+            asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
+            ## SYNC LOGGING
+            executor.submit(self.logging_obj.success_handler(self.messages))
+
+    async def backend_to_client_send_messages(self):
+        import websockets
+
+        try:
+            while True:
+                message = await self.backend_ws.recv()
+                await self.websocket.send_text(message)
+
+                ## LOGGING
+                self.store_message(message)
+        except websockets.exceptions.ConnectionClosed:  # type: ignore
+            pass
+        except Exception:
+            pass
+        finally:
+            await self.log_messages()
+
+    async def client_ack_messages(self):
+        try:
+            while True:
+                message = await self.websocket.receive_text()
+                ## LOGGING
+                self.store_input(message=message)
+                ## FORWARD TO BACKEND
+                await self.backend_ws.send(message)
+        except self.websockets.exceptions.ConnectionClosed:  # type: ignore
+            pass
+
+    async def bidirectional_forward(self):
+
+        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
+        try:
+            await self.client_ack_messages()
+        except self.websockets.exceptions.ConnectionClosed:  # type: ignore
+            forward_task.cancel()
+        finally:
+            if not forward_task.done():
+                forward_task.cancel()
+                try:
+                    await forward_task
+                except asyncio.CancelledError:
+                    pass
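The azure and openai realtime handlers below swap their hand-rolled forwarding loops for this class. A minimal sketch of the intended call pattern: one task relays backend-to-client messages while the current coroutine relays client-to-backend input, and every message is recorded for logging. The websocket arguments here are stand-ins for whatever the proxy hands the handler.

```python
from litellm.litellm_core_utils.realtime_streaming import RealTimeStreaming


async def proxy_realtime(client_ws, backend_ws, logging_obj=None):
    # relay traffic in both directions and log request/response messages
    realtime_streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
    await realtime_streaming.bidirectional_forward()
```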
@@ -7,6 +7,8 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
 import asyncio
 from typing import Any, Optional

+from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
 from ..azure import AzureChatCompletion

 # BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
@@ -44,6 +46,7 @@ class AzureOpenAIRealtime(AzureChatCompletion):
         api_version: Optional[str] = None,
         azure_ad_token: Optional[str] = None,
         client: Optional[Any] = None,
+        logging_obj: Optional[LiteLLMLogging] = None,
         timeout: Optional[float] = None,
     ):
         import websockets
@@ -62,23 +65,10 @@ class AzureOpenAIRealtime(AzureChatCompletion):
                     "api-key": api_key,  # type: ignore
                 },
             ) as backend_ws:
-                forward_task = asyncio.create_task(
-                    forward_messages(websocket, backend_ws)
-                )
-
-                try:
-                    while True:
-                        message = await websocket.receive_text()
-                        await backend_ws.send(message)
-                except websockets.exceptions.ConnectionClosed:  # type: ignore
-                    forward_task.cancel()
-                finally:
-                    if not forward_task.done():
-                        forward_task.cancel()
-                        try:
-                            await forward_task
-                        except asyncio.CancelledError:
-                            pass
+                realtime_streaming = RealTimeStreaming(
+                    websocket, backend_ws, logging_obj
+                )
+                await realtime_streaming.bidirectional_forward()

         except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
             await websocket.close(code=e.status_code, reason=str(e))
@@ -7,20 +7,11 @@ This requires websockets, and is currently only supported on LiteLLM Proxy.
 import asyncio
 from typing import Any, Optional

+from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
 from ..openai import OpenAIChatCompletion


-async def forward_messages(client_ws: Any, backend_ws: Any):
-    import websockets
-
-    try:
-        while True:
-            message = await backend_ws.recv()
-            await client_ws.send_text(message)
-    except websockets.exceptions.ConnectionClosed:  # type: ignore
-        pass
-
-
 class OpenAIRealtime(OpenAIChatCompletion):
     def _construct_url(self, api_base: str, model: str) -> str:
         """
@@ -35,6 +26,7 @@ class OpenAIRealtime(OpenAIChatCompletion):
         self,
         model: str,
         websocket: Any,
+        logging_obj: LiteLLMLogging,
         api_base: Optional[str] = None,
         api_key: Optional[str] = None,
         client: Optional[Any] = None,
@@ -57,25 +49,26 @@ class OpenAIRealtime(OpenAIChatCompletion):
                     "OpenAI-Beta": "realtime=v1",
                 },
             ) as backend_ws:
-                forward_task = asyncio.create_task(
-                    forward_messages(websocket, backend_ws)
-                )
-
-                try:
-                    while True:
-                        message = await websocket.receive_text()
-                        await backend_ws.send(message)
-                except websockets.exceptions.ConnectionClosed:  # type: ignore
-                    forward_task.cancel()
-                finally:
-                    if not forward_task.done():
-                        forward_task.cancel()
-                        try:
-                            await forward_task
-                        except asyncio.CancelledError:
-                            pass
+                realtime_streaming = RealTimeStreaming(
+                    websocket, backend_ws, logging_obj
+                )
+                await realtime_streaming.bidirectional_forward()

         except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
             await websocket.close(code=e.status_code, reason=str(e))
         except Exception as e:
-            await websocket.close(code=1011, reason=f"Internal server error: {str(e)}")
+            try:
+                await websocket.close(
+                    code=1011, reason=f"Internal server error: {str(e)}"
+                )
+            except RuntimeError as close_error:
+                if "already completed" in str(close_error) or "websocket.close" in str(
+                    close_error
+                ):
+                    # The WebSocket is already closed or the response is completed, so we can ignore this error
+                    pass
+                else:
+                    # If it's a different RuntimeError, we might want to log it or handle it differently
+                    raise Exception(
+                        f"Unexpected error while closing WebSocket: {close_error}"
+                    )
@@ -153,6 +153,8 @@ class CustomLLM(BaseLLM):
         self,
         model: str,
         prompt: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
         model_response: ImageResponse,
         optional_params: dict,
         logging_obj: Any,
@@ -166,6 +168,12 @@ class CustomLLM(BaseLLM):
         model: str,
         prompt: str,
         model_response: ImageResponse,
+        api_key: Optional[
+            str
+        ],  # dynamically set api_key - https://docs.litellm.ai/docs/set_keys#api_key
+        api_base: Optional[
+            str
+        ],  # dynamically set api_base - https://docs.litellm.ai/docs/set_keys#api_base
         optional_params: dict,
         logging_obj: Any,
         timeout: Optional[Union[float, httpx.Timeout]] = None,
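With `api_key` and `api_base` now threaded through, a custom image-generation provider can pick up per-request credentials. A minimal sketch of a subclass consuming them follows; the provider URL, the httpx call, and the import paths are assumptions for illustration, only the added `api_key` / `api_base` parameters come from this diff.

```python
from typing import Any, Optional

import httpx

from litellm.llms.custom_llm import CustomLLM  # assumed import path
from litellm.types.utils import ImageResponse  # assumed import path


class MyImageLLM(CustomLLM):
    def image_generation(
        self,
        model: str,
        prompt: str,
        api_key: Optional[str],
        api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
        **kwargs,
    ) -> ImageResponse:
        base = api_base or "https://example.invalid/v1"  # placeholder default
        resp = httpx.post(
            f"{base}/images/generations",
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": model, "prompt": prompt, **optional_params},
        )
        resp.raise_for_status()
        model_response.data = resp.json().get("data", [])
        return model_response
```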
@@ -1,464 +0,0 @@
-# What is this?
-## Handler file for calling claude-3 on vertex ai
-import copy
-import json
-import os
-import time
-import types
-import uuid
-from enum import Enum
-from typing import Any, Callable, List, Optional, Tuple, Union
-
-import httpx  # type: ignore
-import requests  # type: ignore
-
-import litellm
-from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
-from litellm.types.llms.anthropic import (
-    AnthropicMessagesTool,
-    AnthropicMessagesToolChoice,
-)
-from litellm.types.llms.openai import (
-    ChatCompletionToolParam,
-    ChatCompletionToolParamFunctionChunk,
-)
-from litellm.types.utils import ResponseFormatChunk
-from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
-
-from ..prompt_templates.factory import (
-    construct_tool_use_system_prompt,
-    contains_tag,
-    custom_prompt,
-    extract_between_tags,
-    parse_xml_params,
-    prompt_factory,
-    response_schema_prompt,
-)
-
-
-class VertexAIError(Exception):
-    def __init__(self, status_code, message):
-        self.status_code = status_code
-        self.message = message
-        self.request = httpx.Request(
-            method="POST", url=" https://cloud.google.com/vertex-ai/"
-        )
-        self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )  # Call the base class constructor with the parameters it needs
-
-
-class VertexAIAnthropicConfig:
-    """
-    Reference:https://docs.anthropic.com/claude/reference/messages_post
-
-    Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
-
-    - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
-    - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
-
-    The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
-
-    - `max_tokens` Required (integer) max tokens,
-    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
-    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
-    - `temperature` Optional (float) The amount of randomness injected into the response
-    - `top_p` Optional (float) Use nucleus sampling.
-    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
-    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
-
-    Note: Please make sure to modify the default parameters as required for your use case.
-    """
-
-    max_tokens: Optional[int] = (
-        4096  # anthropic max - setting this doesn't impact response, but is required by anthropic.
-    )
-    system: Optional[str] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    top_k: Optional[int] = None
-    stop_sequences: Optional[List[str]] = None
-
-    def __init__(
-        self,
-        max_tokens: Optional[int] = None,
-        anthropic_version: Optional[str] = None,
-    ) -> None:
-        locals_ = locals()
-        for key, value in locals_.items():
-            if key == "max_tokens" and value is None:
-                value = self.max_tokens
-            if key != "self" and value is not None:
-                setattr(self.__class__, key, value)
-
-    @classmethod
-    def get_config(cls):
-        return {
-            k: v
-            for k, v in cls.__dict__.items()
-            if not k.startswith("__")
-            and not isinstance(
-                v,
-                (
-                    types.FunctionType,
-                    types.BuiltinFunctionType,
-                    classmethod,
-                    staticmethod,
-                ),
-            )
-            and v is not None
-        }
-
-    def get_supported_openai_params(self):
-        return [
-            "max_tokens",
-            "max_completion_tokens",
-            "tools",
-            "tool_choice",
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "response_format",
-        ]
-
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "tools":
-                optional_params["tools"] = value
-            if param == "tool_choice":
-                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
-                if value == "auto":
-                    _tool_choice = {"type": "auto"}
-                elif value == "required":
-                    _tool_choice = {"type": "any"}
-                elif isinstance(value, dict):
-                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
-
-                if _tool_choice is not None:
-                    optional_params["tool_choice"] = _tool_choice
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop_sequences"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-            if param == "response_format" and isinstance(value, dict):
-                json_schema: Optional[dict] = None
-                if "response_schema" in value:
-                    json_schema = value["response_schema"]
-                elif "json_schema" in value:
-                    json_schema = value["json_schema"]["schema"]
-                """
-                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
-                - You usually want to provide a single tool
-                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
-                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
-                """
-                _tool_choice = None
-                _tool_choice = {"name": "json_tool_call", "type": "tool"}
-
-                _tool = AnthropicMessagesTool(
-                    name="json_tool_call",
-                    input_schema={
-                        "type": "object",
-                        "properties": {"values": json_schema},  # type: ignore
-                    },
-                )
-
-                optional_params["tools"] = [_tool]
-                optional_params["tool_choice"] = _tool_choice
-                optional_params["json_mode"] = True
-
-        return optional_params
-
-
-"""
-- Run client init
-- Support async completion, streaming
-"""
-
-
-def refresh_auth(
-    credentials,
-) -> str:  # used when user passes in credentials as json string
-    from google.auth.transport.requests import Request  # type: ignore[import-untyped]
-
-    if credentials.token is None:
-        credentials.refresh(Request())
-
-    if not credentials.token:
-        raise RuntimeError("Could not resolve API token from the credentials")
-
-    return credentials.token
-
-
-def get_vertex_client(
-    client: Any,
-    vertex_project: Optional[str],
-    vertex_location: Optional[str],
-    vertex_credentials: Optional[str],
-) -> Tuple[Any, Optional[str]]:
-    args = locals()
-    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
-        VertexLLM,
-    )
-
-    try:
-        from anthropic import AnthropicVertex
-    except Exception:
-        raise VertexAIError(
-            status_code=400,
-            message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
-        )
-
-    access_token: Optional[str] = None
-
-    if client is None:
-        _credentials, cred_project_id = VertexLLM().load_auth(
-            credentials=vertex_credentials, project_id=vertex_project
-        )
-
-        vertex_ai_client = AnthropicVertex(
-            project_id=vertex_project or cred_project_id,
-            region=vertex_location or "us-central1",
-            access_token=_credentials.token,
-        )
-        access_token = _credentials.token
-    else:
-        vertex_ai_client = client
-        access_token = client.access_token
-
-    return vertex_ai_client, access_token
-
-
-def create_vertex_anthropic_url(
-    vertex_location: str, vertex_project: str, model: str, stream: bool
-) -> str:
-    if stream is True:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
-    else:
-        return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
-
-
-def completion(
-    model: str,
-    messages: list,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    encoding,
-    logging_obj,
-    optional_params: dict,
-    custom_prompt_dict: dict,
-    headers: Optional[dict],
-    timeout: Union[float, httpx.Timeout],
-    vertex_project=None,
-    vertex_location=None,
-    vertex_credentials=None,
-    litellm_params=None,
-    logger_fn=None,
-    acompletion: bool = False,
-    client=None,
-):
-    try:
-        import vertexai
-    except Exception:
-        raise VertexAIError(
-            status_code=400,
-            message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
-        )
-
-    from anthropic import AnthropicVertex
-
-    from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
-    from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
-        VertexLLM,
-    )
-
-    if not (
-        hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
-    ):
-        raise VertexAIError(
-            status_code=400,
-            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
-        )
-    try:
-
-        vertex_httpx_logic = VertexLLM()
-
-        access_token, project_id = vertex_httpx_logic._ensure_access_token(
-            credentials=vertex_credentials,
-            project_id=vertex_project,
-            custom_llm_provider="vertex_ai",
-        )
-
-        anthropic_chat_completions = AnthropicChatCompletion()
-
-        ## Load Config
-        config = litellm.VertexAIAnthropicConfig.get_config()
-        for k, v in config.items():
-            if k not in optional_params:
-                optional_params[k] = v
-
-        ## CONSTRUCT API BASE
-        stream = optional_params.get("stream", False)
-
-        api_base = create_vertex_anthropic_url(
-            vertex_location=vertex_location or "us-central1",
-            vertex_project=vertex_project or project_id,
-            model=model,
-            stream=stream,
-        )
-
-        if headers is not None:
-            vertex_headers = headers
-        else:
-            vertex_headers = {}
-
-        vertex_headers.update({"Authorization": "Bearer {}".format(access_token)})
-
-        optional_params.update(
-            {"anthropic_version": "vertex-2023-10-16", "is_vertex_request": True}
-        )
-
-        return anthropic_chat_completions.completion(
-            model=model,
-            messages=messages,
-            api_base=api_base,
-            custom_prompt_dict=custom_prompt_dict,
-            model_response=model_response,
-            print_verbose=print_verbose,
-            encoding=encoding,
-            api_key=access_token,
-            logging_obj=logging_obj,
-            optional_params=optional_params,
-            acompletion=acompletion,
-            litellm_params=litellm_params,
-            logger_fn=logger_fn,
-            headers=vertex_headers,
-            client=client,
-            timeout=timeout,
-        )
-
-    except Exception as e:
-        raise VertexAIError(status_code=500, message=str(e))
-
-
-async def async_completion(
-    model: str,
-    messages: list,
-    data: dict,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    logging_obj,
-    vertex_project=None,
-    vertex_location=None,
-    optional_params=None,
-    client=None,
-    access_token=None,
-):
-    from anthropic import AsyncAnthropicVertex
-
-    if client is None:
-        vertex_ai_client = AsyncAnthropicVertex(
-            project_id=vertex_project, region=vertex_location, access_token=access_token  # type: ignore
-        )
-    else:
-        vertex_ai_client = client
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=messages,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-        },
-    )
-    message = await vertex_ai_client.messages.create(**data)  # type: ignore
-    text_content = message.content[0].text
-    ## TOOL CALLING - OUTPUT PARSE
-    if text_content is not None and contains_tag("invoke", text_content):
-        function_name = extract_between_tags("tool_name", text_content)[0]
-        function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
-        function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
-        function_arguments = parse_xml_params(function_arguments_str)
-        _message = litellm.Message(
-            tool_calls=[
-                {
-                    "id": f"call_{uuid.uuid4()}",
-                    "type": "function",
-                    "function": {
-                        "name": function_name,
-                        "arguments": json.dumps(function_arguments),
-                    },
-                }
-            ],
-            content=None,
-        )
-        model_response.choices[0].message = _message  # type: ignore
-    else:
-        model_response.choices[0].message.content = text_content  # type: ignore
-    model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
-
-    ## CALCULATING USAGE
-    prompt_tokens = message.usage.input_tokens
-    completion_tokens = message.usage.output_tokens
-
-    model_response.created = int(time.time())
-    model_response.model = model
-    usage = Usage(
-        prompt_tokens=prompt_tokens,
-        completion_tokens=completion_tokens,
-        total_tokens=prompt_tokens + completion_tokens,
-    )
-    setattr(model_response, "usage", usage)
-    return model_response
-
-
-async def async_streaming(
-    model: str,
-    messages: list,
-    data: dict,
-    model_response: ModelResponse,
-    print_verbose: Callable,
-    logging_obj,
-    vertex_project=None,
-    vertex_location=None,
-    optional_params=None,
-    client=None,
-    access_token=None,
-):
-    from anthropic import AsyncAnthropicVertex
-
-    if client is None:
-        vertex_ai_client = AsyncAnthropicVertex(
-            project_id=vertex_project, region=vertex_location, access_token=access_token  # type: ignore
-        )
-    else:
-        vertex_ai_client = client
-
-    ## LOGGING
-    logging_obj.pre_call(
-        input=messages,
-        api_key=None,
-        additional_args={
-            "complete_input_dict": optional_params,
-        },
-    )
-    response = await vertex_ai_client.messages.create(**data, stream=True)  # type: ignore
-    logging_obj.post_call(input=messages, api_key=None, original_response=response)
-
-    streamwrapper = CustomStreamWrapper(
-        completion_stream=response,
-        model=model,
-        custom_llm_provider="vertex_ai",
-        logging_obj=logging_obj,
-    )
-
-    return streamwrapper
@@ -0,0 +1,179 @@
+# What is this?
+## Handler file for calling claude-3 on vertex ai
+import copy
+import json
+import os
+import time
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from ....prompt_templates.factory import (
+    construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
+)
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url=" https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexAIAnthropicConfig:
+    """
+    Reference: https://docs.anthropic.com/claude/reference/messages_post
+
+    Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
+
+    - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
+    - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
+
+    The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
+
+    - `max_tokens` Required (integer) max tokens,
+    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
+    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
+    - `temperature` Optional (float) The amount of randomness injected into the response
+    - `top_p` Optional (float) Use nucleus sampling.
+    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
+    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    max_tokens: Optional[int] = (
+        4096  # anthropic max - setting this doesn't impact response, but is required by anthropic.
+    )
+    system: Optional[str] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    stop_sequences: Optional[List[str]] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        anthropic_version: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key == "max_tokens" and value is None:
+                value = self.max_tokens
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "max_tokens",
+            "max_completion_tokens",
+            "tools",
+            "tool_choice",
+            "stream",
+            "stop",
+            "temperature",
+            "top_p",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens" or param == "max_completion_tokens":
+                optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+                if value == "auto":
+                    _tool_choice = {"type": "auto"}
+                elif value == "required":
+                    _tool_choice = {"type": "any"}
+                elif isinstance(value, dict):
+                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+                if _tool_choice is not None:
+                    optional_params["tool_choice"] = _tool_choice
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "response_format" and isinstance(value, dict):
+                json_schema: Optional[dict] = None
+                if "response_schema" in value:
+                    json_schema = value["response_schema"]
+                elif "json_schema" in value:
+                    json_schema = value["json_schema"]["schema"]
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": json_schema},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
+
+        return optional_params
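For context, a minimal sketch of how the new `VertexAIAnthropicConfig.map_openai_params` translates OpenAI-style parameters (including `response_format`) into the Anthropic-on-Vertex payload. The parameter values below are illustrative only, not from this commit:

```python
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
    VertexAIAnthropicConfig,
)

config = VertexAIAnthropicConfig()
optional_params = config.map_openai_params(
    non_default_params={
        "max_completion_tokens": 256,          # mapped onto anthropic's max_tokens
        "temperature": 0.2,
        "response_format": {
            "type": "json_schema",
            "json_schema": {"schema": {"type": "object"}},
        },
    },
    optional_params={},
)
# Expected shape: max_tokens / temperature copied over, plus a forced
# "json_tool_call" tool with a matching tool_choice and json_mode=True.
print(optional_params)
```

Per the inline comment in the method above, JSON mode is implemented by forcing a synthetic `json_tool_call` tool rather than a dedicated response-format field, matching Anthropic's documented tool-use approach to structured output.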
@@ -9,13 +9,14 @@ import httpx  # type: ignore
import litellm
from litellm.utils import ModelResponse

-from ...base import BaseLLM
+from ..vertex_llm_base import VertexBase


class VertexPartnerProvider(str, Enum):
    mistralai = "mistralai"
    llama = "llama"
    ai21 = "ai21"
+    claude = "claude"


class VertexAIError(Exception):
@@ -31,31 +32,38 @@ class VertexAIError(Exception):
        )  # Call the base class constructor with the parameters it needs


-class VertexAIPartnerModels(BaseLLM):
+def create_vertex_url(
+    vertex_location: str,
+    vertex_project: str,
+    partner: VertexPartnerProvider,
+    stream: Optional[bool],
+    model: str,
+    api_base: Optional[str] = None,
+) -> str:
+    """Return the base url for the vertex partner models"""
+    if partner == VertexPartnerProvider.llama:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+    elif partner == VertexPartnerProvider.mistralai:
+        if stream:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
+        else:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+    elif partner == VertexPartnerProvider.ai21:
+        if stream:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+        else:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
+    elif partner == VertexPartnerProvider.claude:
+        if stream:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+        else:
+            return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
+
+
+class VertexAIPartnerModels(VertexBase):
    def __init__(self) -> None:
        pass

-    def create_vertex_url(
-        self,
-        vertex_location: str,
-        vertex_project: str,
-        partner: VertexPartnerProvider,
-        stream: Optional[bool],
-        model: str,
-    ) -> str:
-        if partner == VertexPartnerProvider.llama:
-            return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
-        elif partner == VertexPartnerProvider.mistralai:
-            if stream:
-                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
-            else:
-                return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
-        elif partner == VertexPartnerProvider.ai21:
-            if stream:
-                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
-            else:
-                return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
-
    def completion(
        self,
        model: str,
@@ -64,6 +72,7 @@ class VertexAIPartnerModels(BaseLLM):
        print_verbose: Callable,
        encoding,
        logging_obj,
+        api_base: Optional[str],
        optional_params: dict,
        custom_prompt_dict: dict,
        headers: Optional[dict],
@@ -80,6 +89,7 @@ class VertexAIPartnerModels(BaseLLM):
            import vertexai
            from google.cloud import aiplatform

+            from litellm.llms.anthropic.chat import AnthropicChatCompletion
            from litellm.llms.databricks.chat import DatabricksChatCompletion
            from litellm.llms.OpenAI.openai import OpenAIChatCompletion
            from litellm.llms.text_completion_codestral import CodestralTextCompletion
@@ -112,6 +122,7 @@ class VertexAIPartnerModels(BaseLLM):

            openai_like_chat_completions = DatabricksChatCompletion()
            codestral_fim_completions = CodestralTextCompletion()
+            anthropic_chat_completions = AnthropicChatCompletion()

            ## CONSTRUCT API BASE
            stream: bool = optional_params.get("stream", False) or False
@@ -126,8 +137,10 @@ class VertexAIPartnerModels(BaseLLM):
            elif "jamba" in model:
                partner = VertexPartnerProvider.ai21
                optional_params["custom_endpoint"] = True
+            elif "claude" in model:
+                partner = VertexPartnerProvider.claude

-            api_base = self.create_vertex_url(
+            default_api_base = create_vertex_url(
                vertex_location=vertex_location or "us-central1",
                vertex_project=vertex_project or project_id,
                partner=partner,  # type: ignore
@@ -135,6 +148,21 @@ class VertexAIPartnerModels(BaseLLM):
                model=model,
            )

+            if len(default_api_base.split(":")) > 1:
+                endpoint = default_api_base.split(":")[-1]
+            else:
+                endpoint = ""
+
+            _, api_base = self._check_custom_proxy(
+                api_base=api_base,
+                custom_llm_provider="vertex_ai",
+                gemini_api_key=None,
+                endpoint=endpoint,
+                stream=stream,
+                auth_header=None,
+                url=default_api_base,
+            )
+
            model = model.split("@")[0]

            if "codestral" in model and litellm_params.get("text_completion") is True:
@@ -158,6 +186,35 @@ class VertexAIPartnerModels(BaseLLM):
                    timeout=timeout,
                    encoding=encoding,
                )
+            elif "claude" in model:
+                if headers is None:
+                    headers = {}
+                headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+                optional_params.update(
+                    {
+                        "anthropic_version": "vertex-2023-10-16",
+                        "is_vertex_request": True,
+                    }
+                )
+                return anthropic_chat_completions.completion(
+                    model=model,
+                    messages=messages,
+                    api_base=api_base,
+                    acompletion=acompletion,
+                    custom_prompt_dict=litellm.custom_prompt_dict,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,  # for calculating input/output tokens
+                    api_key=access_token,
+                    logging_obj=logging_obj,
+                    headers=headers,
+                    timeout=timeout,
+                    client=client,
+                )

            return openai_like_chat_completions.completion(
                model=model,
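For context, a quick sketch of what the new module-level `create_vertex_url` helper returns for a Claude partner model. The project id, location, and model version below are placeholders, not values from this commit:

```python
from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
    VertexPartnerProvider,
    create_vertex_url,
)

# Non-streaming Claude request -> the anthropic publisher's :rawPredict endpoint
url = create_vertex_url(
    vertex_location="us-east5",                 # placeholder region
    vertex_project="my-gcp-project",            # placeholder project id
    partner=VertexPartnerProvider.claude,
    stream=False,
    model="claude-3-sonnet@20240229",           # example model from the tests in this PR
)
print(url)
# https://us-east5-aiplatform.googleapis.com/v1/projects/my-gcp-project/locations/us-east5/publishers/anthropic/models/claude-3-sonnet@20240229:rawPredict
```

The handler then splits the `:endpoint` suffix off this default URL and passes both to `_check_custom_proxy`, which is how a user-supplied `api_base` can replace the Google host while keeping the correct `rawPredict` / `streamRawPredict` action.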
@@ -117,10 +117,7 @@ from .llms.sagemaker.sagemaker import SagemakerLLM
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton import TritonChatCompletion
-from .llms.vertex_ai_and_google_ai_studio import (
-    vertex_ai_anthropic,
-    vertex_ai_non_gemini,
-)
+from .llms.vertex_ai_and_google_ai_studio import vertex_ai_non_gemini
from .llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import (
    VertexLLM,
)
@@ -747,6 +744,11 @@ def completion(  # type: ignore
    proxy_server_request = kwargs.get("proxy_server_request", None)
    fallbacks = kwargs.get("fallbacks", None)
    headers = kwargs.get("headers", None) or extra_headers
+    if headers is None:
+        headers = {}
+
+    if extra_headers is not None:
+        headers.update(extra_headers)
    num_retries = kwargs.get(
        "num_retries", None
    )  ## alt. param for 'max_retries'. Use this to pass retries w/ instructor.
@@ -964,7 +966,6 @@ def completion(  # type: ignore
        max_retries=max_retries,
        logprobs=logprobs,
        top_logprobs=top_logprobs,
-        extra_headers=extra_headers,
        api_version=api_version,
        parallel_tool_calls=parallel_tool_calls,
        messages=messages,
@@ -1067,6 +1068,9 @@ def completion(  # type: ignore

            headers = headers or litellm.headers

+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+
            if (
                litellm.enable_preview_features
                and litellm.AzureOpenAIO1Config().is_o1_model(model=model)
@@ -1166,6 +1170,9 @@ def completion(  # type: ignore

            headers = headers or litellm.headers

+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+
            ## LOAD CONFIG - if set
            config = litellm.AzureOpenAIConfig.get_config()
            for k, v in config.items():
@@ -1223,6 +1230,9 @@ def completion(  # type: ignore

            headers = headers or litellm.headers

+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+
            ## LOAD CONFIG - if set
            config = litellm.AzureAIStudioConfig.get_config()
            for k, v in config.items():
@@ -1304,6 +1314,9 @@ def completion(  # type: ignore

            headers = headers or litellm.headers

+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+
            ## LOAD CONFIG - if set
            config = litellm.OpenAITextCompletionConfig.get_config()
            for k, v in config.items():
@@ -1466,6 +1479,9 @@ def completion(  # type: ignore

            headers = headers or litellm.headers

+            if extra_headers is not None:
+                optional_params["extra_headers"] = extra_headers
+
            ## LOAD CONFIG - if set
            config = litellm.OpenAIConfig.get_config()
            for k, v in config.items():
@@ -2228,31 +2244,12 @@ def completion(  # type: ignore
            )

            new_params = deepcopy(optional_params)
-            if "claude-3" in model:
-                model_response = vertex_ai_anthropic.completion(
-                    model=model,
-                    messages=messages,
-                    model_response=model_response,
-                    print_verbose=print_verbose,
-                    optional_params=new_params,
-                    litellm_params=litellm_params,
-                    logger_fn=logger_fn,
-                    encoding=encoding,
-                    vertex_location=vertex_ai_location,
-                    vertex_project=vertex_ai_project,
-                    vertex_credentials=vertex_credentials,
-                    logging_obj=logging,
-                    acompletion=acompletion,
-                    headers=headers,
-                    custom_prompt_dict=custom_prompt_dict,
-                    timeout=timeout,
-                    client=client,
-                )
-            elif (
+            if (
                model.startswith("meta/")
                or model.startswith("mistral")
                or model.startswith("codestral")
                or model.startswith("jamba")
+                or model.startswith("claude")
            ):
                model_response = vertex_partner_models_chat_completion.completion(
                    model=model,
@@ -2263,6 +2260,7 @@ def completion(  # type: ignore
                    litellm_params=litellm_params,  # type: ignore
                    logger_fn=logger_fn,
                    encoding=encoding,
+                    api_base=api_base,
                    vertex_location=vertex_ai_location,
                    vertex_project=vertex_ai_project,
                    vertex_credentials=vertex_credentials,
||||||
model_response = custom_handler.aimage_generation( # type: ignore
|
model_response = custom_handler.aimage_generation( # type: ignore
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
logging_obj=litellm_logging_obj,
|
logging_obj=litellm_logging_obj,
|
||||||
|
@ -4863,6 +4863,8 @@ def image_generation(
|
||||||
model_response = custom_handler.image_generation(
|
model_response = custom_handler.image_generation(
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
model_response=model_response,
|
model_response=model_response,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
logging_obj=litellm_logging_obj,
|
logging_obj=litellm_logging_obj,
|
||||||
|
|
|
@@ -1,12 +1,15 @@
model_list:
-  - model_name: gpt-4o-mini
-    litellm_params:
-      model: azure/my-gpt-4o-mini
-      api_key: os.environ/AZURE_API_KEY
-      api_base: os.environ/AZURE_API_BASE
+  # - model_name: openai-gpt-4o-realtime-audio
+  #   litellm_params:
+  #     model: azure/gpt-4o-realtime-preview
+  #     api_key: os.environ/AZURE_SWEDEN_API_KEY
+  #     api_base: os.environ/AZURE_SWEDEN_API_BASE
+
+  - model_name: openai-gpt-4o-realtime-audio
+    litellm_params:
+      model: openai/gpt-4o-realtime-preview-2024-10-01
+      api_key: os.environ/OPENAI_API_KEY
+      api_base: http://localhost:8080

litellm_settings:
-  turn_off_message_logging: true
-  cache: True
-  cache_params:
-    type: local
+  success_callback: ["langfuse"]
@@ -241,6 +241,8 @@ def initialize_callbacks_on_proxy(
        litellm.callbacks = imported_list  # type: ignore

        if "prometheus" in value:
+            if premium_user is not True:
+                raise Exception(CommonProxyErrors.not_premium_user.value)
            from litellm.proxy.proxy_server import app

            verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics")
@@ -1,16 +1,17 @@
model_list:
  - model_name: db-openai-endpoint
    litellm_params:
-      model: openai/gpt-5
+      model: openai/gpt-4
      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwaz.app/
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  - model_name: db-openai-endpoint
-    litellm_params:
-      model: openai/gpt-5
-      api_key: fake-key
-      api_base: https://exampleopenaiendpoint-production.up.railwxaz.app/

litellm_settings:
-  callbacks: ["prometheus"]
+  success_callback: ["s3"]
+  turn_off_message_logging: true
+  s3_callback_params:
+    s3_bucket_name: load-testing-oct # AWS Bucket Name for S3
+    s3_region_name: us-west-2 # AWS Region Name for S3
+    s3_aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID # use os.environ/<variable name> to pass environment variables. This is AWS Access Key ID for S3
+    s3_aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY # AWS Secret Access Key for S3
@@ -1692,6 +1692,10 @@ class ProxyConfig:
                        else:
                            litellm.success_callback.append(callback)
                            if "prometheus" in callback:
+                                if not premium_user:
+                                    raise Exception(
+                                        CommonProxyErrors.not_premium_user.value
+                                    )
                                verbose_proxy_logger.debug(
                                    "Starting Prometheus Metrics on /metrics"
                                )
@@ -5433,12 +5437,11 @@ async def moderations(
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
-        verbose_proxy_logger.error(
+        verbose_proxy_logger.exception(
            "litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
                str(e)
            )
        )
-        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e)),
@@ -408,7 +408,6 @@ class ProxyLogging:
                    callback,
                    internal_usage_cache=self.internal_usage_cache.dual_cache,
                    llm_router=llm_router,
-                    premium_user=self.premium_user,
                )
                if callback is None:
                    continue
@@ -8,13 +8,16 @@ from litellm import get_llm_provider
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams

+from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.AzureOpenAI.realtime.handler import AzureOpenAIRealtime
from ..llms.OpenAI.realtime.handler import OpenAIRealtime
+from ..utils import client as wrapper_client

azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()


+@wrapper_client
async def _arealtime(
    model: str,
    websocket: Any,  # fastapi websocket
@@ -31,6 +34,12 @@ async def _arealtime(

    For PROXY use only.
    """
+    litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
+    litellm_call_id: Optional[str] = kwargs.get("litellm_call_id", None)
+    proxy_server_request = kwargs.get("proxy_server_request", None)
+    model_info = kwargs.get("model_info", None)
+    metadata = kwargs.get("metadata", {})
+    user = kwargs.get("user", None)
    litellm_params = GenericLiteLLMParams(**kwargs)

    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
@@ -39,6 +48,21 @@ async def _arealtime(
        api_key=api_key,
    )

+    litellm_logging_obj.update_environment_variables(
+        model=model,
+        user=user,
+        optional_params={},
+        litellm_params={
+            "litellm_call_id": litellm_call_id,
+            "proxy_server_request": proxy_server_request,
+            "model_info": model_info,
+            "metadata": metadata,
+            "preset_cache_key": None,
+            "stream_response": {},
+        },
+        custom_llm_provider=_custom_llm_provider,
+    )
+
    if _custom_llm_provider == "azure":
        api_base = (
            dynamic_api_base
@@ -63,6 +87,7 @@ async def _arealtime(
            azure_ad_token=None,
            client=None,
            timeout=timeout,
+            logging_obj=litellm_logging_obj,
        )
    elif _custom_llm_provider == "openai":
        api_base = (
@@ -82,6 +107,7 @@ async def _arealtime(
        await openai_realtime.async_realtime(
            model=model,
            websocket=websocket,
+            logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            client=None,
@@ -60,7 +60,7 @@ class PatternMatchRouter:
        regex = re.escape(regex).replace(r"\.\*", ".*")
        return f"^{regex}$"

-    def route(self, request: str) -> Optional[List[Dict]]:
+    def route(self, request: Optional[str]) -> Optional[List[Dict]]:
        """
        Route a requested model to the corresponding llm deployments based on the regex pattern

@@ -69,17 +69,20 @@
        if no pattern is found, return None

        Args:
-            request: str
+            request: Optional[str]

        Returns:
            Optional[List[Deployment]]: llm deployments
        """
        try:
+            if request is None:
+                return None
            for pattern, llm_deployments in self.patterns.items():
                if re.match(pattern, request):
                    return llm_deployments
        except Exception as e:
-            verbose_router_logger.error(f"Error in PatternMatchRouter.route: {str(e)}")
+            verbose_router_logger.debug(f"Error in PatternMatchRouter.route: {str(e)}")

        return None  # No matching pattern found
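For context, a standalone sketch of the wildcard matching this router change guards. The function names below are illustrative, not litellm's, but the escaping and the new `None` handling mirror the lines shown above (assuming the configured model name's `*` wildcard is expanded to `.*` before escaping):

```python
import re
from typing import Dict, List, Optional


def to_pattern(model_name: str) -> str:
    # "openai/*" -> "^openai/.*$": escape everything, then re-enable ".*" wildcards
    regex = model_name.replace("*", ".*")
    regex = re.escape(regex).replace(r"\.\*", ".*")
    return f"^{regex}$"


def route(request: Optional[str], patterns: Dict[str, List[dict]]) -> Optional[List[dict]]:
    if request is None:  # the guard added in this diff: no model name, no match
        return None
    for pattern, deployments in patterns.items():
        if re.match(pattern, request):
            return deployments
    return None


patterns = {to_pattern("openai/*"): [{"model_name": "openai/*"}]}
print(route("openai/gpt-4o", patterns))  # -> the wildcard deployment
print(route(None, patterns))             # -> None, instead of raising on re.match(None)
```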
litellm/types/integrations/s3.py (new file, +14 lines)
@@ -0,0 +1,14 @@
+from typing import Dict
+
+from pydantic import BaseModel
+
+
+class s3BatchLoggingElement(BaseModel):
+    """
+    Type of element stored in self.log_queue in S3Logger
+    """
+
+    payload: Dict
+    s3_object_key: str
+    s3_object_download_filename: str
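For context, a small sketch of constructing the new `s3BatchLoggingElement`, which per its docstring is the unit queued by the S3 logger before a batch flush. The field values below are made up:

```python
from litellm.types.integrations.s3 import s3BatchLoggingElement

element = s3BatchLoggingElement(
    payload={"model": "gpt-3.5-turbo", "response_cost": 0.0001},  # the logged request/response dict
    s3_object_key="2024-10-10/chat-completion-abc123.json",        # where it lands in the bucket
    s3_object_download_filename="chat-completion-abc123.json",
)
print(element)
```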
@@ -129,6 +129,7 @@ class CallTypes(Enum):
    speech = "speech"
    rerank = "rerank"
    arerank = "arerank"
+    arealtime = "_arealtime"


class PassthroughCallTypes(Enum):
@@ -197,7 +197,6 @@ lagoLogger = None
dataDogLogger = None
prometheusLogger = None
dynamoLogger = None
-s3Logger = None
genericAPILogger = None
clickHouseLogger = None
greenscaleLogger = None
@@ -1410,6 +1409,8 @@ def client(original_function):
                )
            else:
                return result
+        elif call_type == CallTypes.arealtime.value:
+            return result

        # ADD HIDDEN PARAMS - additional call metadata
        if hasattr(result, "_hidden_params"):
@@ -1799,8 +1800,9 @@ def calculate_tiles_needed(
    total_tiles = tiles_across * tiles_down
    return total_tiles

+
def get_image_type(image_data: bytes) -> Union[str, None]:
-    """ take an image (really only the first ~100 bytes max are needed)
+    """take an image (really only the first ~100 bytes max are needed)
    and return 'png' 'gif' 'jpeg' 'heic' or None. method added to
    allow deprecation of imghdr in 3.13"""

@@ -4336,16 +4338,18 @@ def get_api_base(
        _optional_params.vertex_location is not None
        and _optional_params.vertex_project is not None
    ):
-        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
-            create_vertex_anthropic_url,
+        from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import (
+            VertexPartnerProvider,
+            create_vertex_url,
        )

        if "claude" in model:
-            _api_base = create_vertex_anthropic_url(
+            _api_base = create_vertex_url(
                vertex_location=_optional_params.vertex_location,
                vertex_project=_optional_params.vertex_project,
                model=model,
                stream=stream,
+                partner=VertexPartnerProvider.claude,
            )
        else:
@@ -4442,19 +4446,7 @@ def get_supported_openai_params(
    elif custom_llm_provider == "volcengine":
        return litellm.VolcEngineConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "groq":
-        return [
-            "temperature",
-            "max_tokens",
-            "top_p",
-            "stream",
-            "stop",
-            "tools",
-            "tool_choice",
-            "response_format",
-            "seed",
-            "extra_headers",
-            "extra_body",
-        ]
+        return litellm.GroqChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "hosted_vllm":
        return litellm.HostedVLLMChatConfig().get_supported_openai_params(model=model)
    elif custom_llm_provider == "deepseek":
@@ -4599,6 +4591,8 @@ def get_supported_openai_params(
            return (
                litellm.MistralTextCompletionConfig().get_supported_openai_params()
            )
+        if model.startswith("claude"):
+            return litellm.VertexAIAnthropicConfig().get_supported_openai_params()
        return litellm.VertexAIConfig().get_supported_openai_params()
    elif request_type == "embeddings":
        return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
@@ -296,7 +296,7 @@ def test_all_model_configs():
        optional_params={},
    ) == {"max_tokens": 10}

-    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import (
+    from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.anthropic.transformation import (
        VertexAIAnthropicConfig,
    )
@@ -12,7 +12,70 @@ import litellm
litellm.num_retries = 3

import time, random
+from litellm._logging import verbose_logger
+import logging
import pytest
+import boto3
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_basic_s3_logging(sync_mode):
+    verbose_logger.setLevel(level=logging.DEBUG)
+    litellm.success_callback = ["s3"]
+    litellm.s3_callback_params = {
+        "s3_bucket_name": "load-testing-oct",
+        "s3_aws_secret_access_key": "os.environ/AWS_SECRET_ACCESS_KEY",
+        "s3_aws_access_key_id": "os.environ/AWS_ACCESS_KEY_ID",
+        "s3_region_name": "us-west-2",
+    }
+    litellm.set_verbose = True
+
+    if sync_mode is True:
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "This is a test"}],
+            mock_response="It's simple to use and easy to get started",
+        )
+    else:
+        response = await litellm.acompletion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "This is a test"}],
+            mock_response="It's simple to use and easy to get started",
+        )
+    print(f"response: {response}")
+
+    await asyncio.sleep(12)
+
+    total_objects, all_s3_keys = list_all_s3_objects("load-testing-oct")
+
+    # assert that at least one key has response.id in it
+    assert any(response.id in key for key in all_s3_keys)
+    s3 = boto3.client("s3")
+    # delete all objects
+    for key in all_s3_keys:
+        s3.delete_object(Bucket="load-testing-oct", Key=key)
+
+
+def list_all_s3_objects(bucket_name):
+    s3 = boto3.client("s3")
+
+    all_s3_keys = []
+
+    paginator = s3.get_paginator("list_objects_v2")
+    total_objects = 0
+
+    for page in paginator.paginate(Bucket=bucket_name):
+        if "Contents" in page:
+            total_objects += len(page["Contents"])
+            all_s3_keys.extend([obj["Key"] for obj in page["Contents"]])
+
+    print(f"Total number of objects in {bucket_name}: {total_objects}")
+    print(all_s3_keys)
+    return total_objects, all_s3_keys
+
+
+list_all_s3_objects("load-testing-oct")
+
+
@pytest.mark.skip(reason="AWS Suspended Account")
@@ -1616,9 +1616,11 @@ async def test_gemini_pro_json_schema_args_sent_httpx_openai_schema(
    )


-@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize(
+    "model", ["gemini-1.5-flash", "claude-3-sonnet@20240229"]
+)  # "vertex_ai",
@pytest.mark.asyncio
-async def test_gemini_pro_httpx_custom_api_base(provider):
+async def test_gemini_pro_httpx_custom_api_base(model):
    load_vertex_ai_credentials()
    litellm.set_verbose = True
    messages = [
@@ -1634,7 +1636,7 @@ async def test_gemini_pro_httpx_custom_api_base(provider):
    with patch.object(client, "post", new=MagicMock()) as mock_call:
        try:
            response = completion(
-                model="vertex_ai_beta/gemini-1.5-flash",
+                model="vertex_ai/{}".format(model),
                messages=messages,
                response_format={"type": "json_object"},
                client=client,
@@ -1647,8 +1649,17 @@ async def test_gemini_pro_httpx_custom_api_base(provider):

        mock_call.assert_called_once()

-        assert "my-custom-api-base:generateContent" == mock_call.call_args.kwargs["url"]
-        assert "hello" in mock_call.call_args.kwargs["headers"]
+        print(f"mock_call.call_args: {mock_call.call_args}")
+        print(f"mock_call.call_args.kwargs: {mock_call.call_args.kwargs}")
+        if "url" in mock_call.call_args.kwargs:
+            assert (
+                "my-custom-api-base:generateContent"
+                == mock_call.call_args.kwargs["url"]
+            )
+        else:
+            assert "my-custom-api-base:rawPredict" == mock_call.call_args[0][0]
+        if "headers" in mock_call.call_args.kwargs:
+            assert "hello" in mock_call.call_args.kwargs["headers"]


# @pytest.mark.skip(reason="exhausted vertex quota. need to refactor to mock the call")
@@ -28,7 +28,6 @@ from typing import (
    Union,
)
from unittest.mock import AsyncMock, MagicMock, patch
-
import httpx
from dotenv import load_dotenv

@@ -226,6 +225,8 @@ class MyCustomLLM(CustomLLM):
        self,
        model: str,
        prompt: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
@@ -242,6 +243,8 @@ class MyCustomLLM(CustomLLM):
        self,
        model: str,
        prompt: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: Any,
@@ -362,3 +365,31 @@ async def test_simple_image_generation_async():
    )

    print(resp)
+
+
+@pytest.mark.asyncio
+async def test_image_generation_async_with_api_key_and_api_base():
+    my_custom_llm = MyCustomLLM()
+    litellm.custom_provider_map = [
+        {"provider": "custom_llm", "custom_handler": my_custom_llm}
+    ]
+
+    with patch.object(
+        my_custom_llm, "aimage_generation", new=AsyncMock()
+    ) as mock_client:
+        try:
+            resp = await litellm.aimage_generation(
+                model="custom_llm/my-fake-model",
+                prompt="Hello world",
+                api_key="my-api-key",
+                api_base="my-api-base",
+            )
+
+            print(resp)
+        except Exception as e:
+            print(e)
+
+        mock_client.assert_awaited_once()
+
+        mock_client.call_args.kwargs["api_key"] == "my-api-key"
+        mock_client.call_args.kwargs["api_base"] == "my-api-base"