diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a33473b72..d429bc6b8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,12 +1,12 @@
 repos:
 - repo: local
   hooks:
-  - id: mypy
-    name: mypy
-    entry: python3 -m mypy --ignore-missing-imports
-    language: system
-    types: [python]
-    files: ^litellm/
+  # - id: mypy
+  #   name: mypy
+  #   entry: python3 -m mypy --ignore-missing-imports
+  #   language: system
+  #   types: [python]
+  #   files: ^litellm/
   - id: isort
     name: isort
     entry: isort
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 2e7914fab..3f22e41b6 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -118,6 +118,8 @@ in_memory_llm_clients_cache: dict = {}
 safe_memory_mode: bool = False
 ### DEFAULT AZURE API VERSION ###
 AZURE_DEFAULT_API_VERSION = "2024-07-01-preview"  # this is updated to the latest
+### COHERE EMBEDDINGS DEFAULT TYPE ###
+COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
 ### GUARDRAILS ###
 llamaguard_model_name: Optional[str] = None
 openai_moderations_model_name: Optional[str] = None
@@ -880,13 +882,13 @@ from .llms.sagemaker.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
 from .llms.maritalk import MaritTalkConfig
-from .llms.bedrock_httpx import (
+from .llms.bedrock.chat import (
     AmazonCohereChatConfig,
     AmazonConverseConfig,
     BEDROCK_CONVERSE_MODELS,
     bedrock_tool_name_mappings,
 )
-from .llms.bedrock import (
+from .llms.bedrock.common_utils import (
     AmazonTitanConfig,
     AmazonAI21Config,
     AmazonAnthropicConfig,
diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 682d1a4b5..025c0e9a3 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -608,17 +608,17 @@ class Logging:
                     self.model_call_details["litellm_params"]["metadata"][
                         "hidden_params"
                     ] = result._hidden_params
-                ## STANDARDIZED LOGGING PAYLOAD
+                    ## STANDARDIZED LOGGING PAYLOAD
 
-                self.model_call_details["standard_logging_object"] = (
-                    get_standard_logging_object_payload(
-                        kwargs=self.model_call_details,
-                        init_response_obj=result,
-                        start_time=start_time,
-                        end_time=end_time,
-                        logging_obj=self,
+                    self.model_call_details["standard_logging_object"] = (
+                        get_standard_logging_object_payload(
+                            kwargs=self.model_call_details,
+                            init_response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time,
+                            logging_obj=self,
+                        )
                     )
-                )
             else:  # streaming chunks + image gen.
self.model_call_details["response_cost"] = None diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py deleted file mode 100644 index 916826ce3..000000000 --- a/litellm/llms/bedrock.py +++ /dev/null @@ -1,1531 +0,0 @@ -#################################### -######### DEPRECATED FILE ########## -#################################### -# logic moved to `bedrock_httpx.py` # - -import copy -import json -import os -import time -import types -import uuid -from enum import Enum -from typing import Any, Callable, List, Optional, Union - -import httpx -from openai.types.image import Image - -import litellm -from litellm.litellm_core_utils.core_helpers import map_finish_reason -from litellm.types.utils import ImageResponse, ModelResponse, Usage -from litellm.utils import get_secret - -from .prompt_templates.factory import ( - construct_tool_use_system_prompt, - contains_tag, - custom_prompt, - extract_between_tags, - parse_xml_params, - prompt_factory, -) - - -class BedrockError(Exception): - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock" - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs - - -class AmazonBedrockGlobalConfig: - def __init__(self): - pass - - def get_mapped_special_auth_params(self) -> dict: - """ - Mapping of common auth params across bedrock/vertex/azure/watsonx - """ - return {"region_name": "aws_region_name"} - - def map_special_auth_params(self, non_default_params: dict, optional_params: dict): - mapped_params = self.get_mapped_special_auth_params() - for param, value in non_default_params.items(): - if param in mapped_params: - optional_params[mapped_params[param]] = value - return optional_params - - def get_eu_regions(self) -> List[str]: - """ - Source: https://www.aws-services.info/bedrock.html - """ - return [ - "eu-west-1", - "eu-west-3", - "eu-central-1", - ] - - -class AmazonTitanConfig: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1 - - Supported Params for the Amazon Titan models: - - - `maxTokenCount` (integer) max tokens, - - `stopSequences` (string[]) list of stop sequence strings - - `temperature` (float) temperature for model, - - `topP` (int) top p for model - """ - - maxTokenCount: Optional[int] = None - stopSequences: Optional[list] = None - temperature: Optional[float] = None - topP: Optional[int] = None - - def __init__( - self, - maxTokenCount: Optional[int] = None, - stopSequences: Optional[list] = None, - temperature: Optional[float] = None, - topP: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -class AmazonAnthropicClaude3Config: - """ - Reference: - https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude - https://docs.anthropic.com/claude/docs/models-overview#model-comparison - - Supported Params for the Amazon / Anthropic Claude 3 models: - - - 
`max_tokens` Required (integer) max tokens. Default is 4096 - - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" - - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py - - `temperature` Optional (float) The amount of randomness injected into the response - - `top_p` Optional (float) Use nucleus sampling. - - `top_k` Optional (int) Only sample from the top K options for each subsequent token - - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating - """ - - max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default - anthropic_version: Optional[str] = "bedrock-2023-05-31" - system: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - top_k: Optional[int] = None - stop_sequences: Optional[List[str]] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - anthropic_version: Optional[str] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params(self): - return [ - "max_tokens", - "tools", - "tool_choice", - "stream", - "stop", - "temperature", - "top_p", - "extra_headers", - ] - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - for param, value in non_default_params.items(): - if param == "max_tokens": - optional_params["max_tokens"] = value - if param == "tools": - optional_params["tools"] = value - if param == "stream": - optional_params["stream"] = value - if param == "stop": - optional_params["stop_sequences"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["top_p"] = value - return optional_params - - -class AmazonAnthropicConfig: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude - - Supported Params for the Amazon / Anthropic models: - - - `max_tokens_to_sample` (integer) max tokens, - - `temperature` (float) model temperature, - - `top_k` (integer) top k, - - `top_p` (integer) top p, - - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"], - - `anthropic_version` (string) version of anthropic for bedrock - e.g. 
"bedrock-2023-05-31" - """ - - max_tokens_to_sample: Optional[int] = litellm.max_tokens - stop_sequences: Optional[list] = None - temperature: Optional[float] = None - top_k: Optional[int] = None - top_p: Optional[int] = None - anthropic_version: Optional[str] = None - - def __init__( - self, - max_tokens_to_sample: Optional[int] = None, - stop_sequences: Optional[list] = None, - temperature: Optional[float] = None, - top_k: Optional[int] = None, - top_p: Optional[int] = None, - anthropic_version: Optional[str] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params( - self, - ): - return ["max_tokens", "temperature", "stop", "top_p", "stream"] - - def map_openai_params(self, non_default_params: dict, optional_params: dict): - for param, value in non_default_params.items(): - if param == "max_tokens": - optional_params["max_tokens_to_sample"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["top_p"] = value - if param == "stop": - optional_params["stop_sequences"] = value - if param == "stream" and value == True: - optional_params["stream"] = value - return optional_params - - -class AmazonCohereConfig: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command - - Supported Params for the Amazon / Cohere models: - - - `max_tokens` (integer) max tokens, - - `temperature` (float) model temperature, - - `return_likelihood` (string) n/a - """ - - max_tokens: Optional[int] = None - temperature: Optional[float] = None - return_likelihood: Optional[str] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - return_likelihood: Optional[str] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -class AmazonAI21Config: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra - - Supported Params for the Amazon / AI21 models: - - - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`. - - - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding. - - - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass. - - - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional. - - - `frequencyPenalty` (object): Placeholder for frequency penalty object. 
- - - `presencePenalty` (object): Placeholder for presence penalty object. - - - `countPenalty` (object): Placeholder for count penalty object. - """ - - maxTokens: Optional[int] = None - temperature: Optional[float] = None - topP: Optional[float] = None - stopSequences: Optional[list] = None - frequencePenalty: Optional[dict] = None - presencePenalty: Optional[dict] = None - countPenalty: Optional[dict] = None - - def __init__( - self, - maxTokens: Optional[int] = None, - temperature: Optional[float] = None, - topP: Optional[float] = None, - stopSequences: Optional[list] = None, - frequencePenalty: Optional[dict] = None, - presencePenalty: Optional[dict] = None, - countPenalty: Optional[dict] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -class AnthropicConstants(Enum): - HUMAN_PROMPT = "\n\nHuman: " - AI_PROMPT = "\n\nAssistant: " - - -class AmazonLlamaConfig: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1 - - Supported Params for the Amazon / Meta Llama models: - - - `max_gen_len` (integer) max tokens, - - `temperature` (float) temperature for model, - - `top_p` (float) top p for model - """ - - max_gen_len: Optional[int] = None - temperature: Optional[float] = None - topP: Optional[float] = None - - def __init__( - self, - maxTokenCount: Optional[int] = None, - temperature: Optional[float] = None, - topP: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -class AmazonMistralConfig: - """ - Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html - Supported Params for the Amazon / Mistral models: - - - `max_tokens` (integer) max tokens, - - `temperature` (float) temperature for model, - - `top_p` (float) top p for model - - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output. 
- - `top_k` (float) top k for model - """ - - max_tokens: Optional[int] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - top_k: Optional[float] = None - stop: Optional[List[str]] = None - - def __init__( - self, - max_tokens: Optional[int] = None, - temperature: Optional[float] = None, - top_p: Optional[int] = None, - top_k: Optional[float] = None, - stop: Optional[List[str]] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -class AmazonStabilityConfig: - """ - Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 - - Supported Params for the Amazon / Stable Diffusion models: - - - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt) - - - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed) - - - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run. - - - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64. - Engine-specific dimension validation: - - - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. - - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 - - SDXL v1.0: same as SDXL v0.9 - - SD v1.6: must be between 320x320 and 1536x1536 - - - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64. - Engine-specific dimension validation: - - - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. 
- - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 - - SDXL v1.0: same as SDXL v0.9 - - SD v1.6: must be between 320x320 and 1536x1536 - """ - - cfg_scale: Optional[int] = None - seed: Optional[float] = None - steps: Optional[List[str]] = None - width: Optional[int] = None - height: Optional[int] = None - - def __init__( - self, - cfg_scale: Optional[int] = None, - seed: Optional[float] = None, - steps: Optional[List[str]] = None, - width: Optional[int] = None, - height: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - -def add_custom_header(headers): - """Closure to capture the headers and add them.""" - - def callback(request, **kwargs): - """Actual callback function that Boto3 will call.""" - for header_name, header_value in headers.items(): - request.headers.add_header(header_name, header_value) - - return callback - - -def init_bedrock_client( - region_name=None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_bedrock_runtime_endpoint: Optional[str] = None, - aws_session_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - aws_role_name: Optional[str] = None, - aws_web_identity_token: Optional[str] = None, - extra_headers: Optional[dict] = None, - timeout: Optional[Union[float, httpx.Timeout]] = None, -): - # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client - litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) - standard_aws_region_name = get_secret("AWS_REGION", None) - ## CHECK IS 'os.environ/' passed in - # Define the list of parameters to check - params_to_check = [ - aws_access_key_id, - aws_secret_access_key, - aws_region_name, - aws_bedrock_runtime_endpoint, - aws_session_name, - aws_profile_name, - aws_role_name, - aws_web_identity_token, - ] - - # Iterate over parameters and update if needed - for i, param in enumerate(params_to_check): - if param and param.startswith("os.environ/"): - params_to_check[i] = get_secret(param) - # Assign updated values back to parameters - ( - aws_access_key_id, - aws_secret_access_key, - aws_region_name, - aws_bedrock_runtime_endpoint, - aws_session_name, - aws_profile_name, - aws_role_name, - aws_web_identity_token, - ) = params_to_check - - # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. 
- ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify) - - ### SET REGION NAME - if region_name: - pass - elif aws_region_name: - region_name = aws_region_name - elif litellm_aws_region_name: - region_name = litellm_aws_region_name - elif standard_aws_region_name: - region_name = standard_aws_region_name - else: - raise BedrockError( - message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", - status_code=401, - ) - - # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client - env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") - if aws_bedrock_runtime_endpoint: - endpoint_url = aws_bedrock_runtime_endpoint - elif env_aws_bedrock_runtime_endpoint: - endpoint_url = env_aws_bedrock_runtime_endpoint - else: - endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com" - - import boto3 - - if isinstance(timeout, float): - config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) - elif isinstance(timeout, httpx.Timeout): - config = boto3.session.Config( - connect_timeout=timeout.connect, read_timeout=timeout.read - ) - else: - config = boto3.session.Config() - - ### CHECK STS ### - if ( - aws_web_identity_token is not None - and aws_role_name is not None - and aws_session_name is not None - ): - oidc_token = get_secret(aws_web_identity_token) - - if oidc_token is None: - raise BedrockError( - message="OIDC token could not be retrieved from secret manager.", - status_code=401, - ) - - sts_client = boto3.client("sts") - - # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html - sts_response = sts_client.assume_role_with_web_identity( - RoleArn=aws_role_name, - RoleSessionName=aws_session_name, - WebIdentityToken=oidc_token, - DurationSeconds=3600, - ) - - client = boto3.client( - service_name="bedrock-runtime", - aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], - aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], - aws_session_token=sts_response["Credentials"]["SessionToken"], - region_name=region_name, - endpoint_url=endpoint_url, - config=config, - verify=ssl_verify, - ) - elif aws_role_name is not None and aws_session_name is not None: - # use sts if role name passed in - sts_client = boto3.client( - "sts", - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - ) - - sts_response = sts_client.assume_role( - RoleArn=aws_role_name, RoleSessionName=aws_session_name - ) - - client = boto3.client( - service_name="bedrock-runtime", - aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], - aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], - aws_session_token=sts_response["Credentials"]["SessionToken"], - region_name=region_name, - endpoint_url=endpoint_url, - config=config, - verify=ssl_verify, - ) - elif aws_access_key_id is not None: - # uses auth params passed to completion - # aws_access_key_id is not None, assume user is trying to auth using litellm.completion - - client = boto3.client( - service_name="bedrock-runtime", - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=region_name, - endpoint_url=endpoint_url, - config=config, - verify=ssl_verify, - ) - elif aws_profile_name is not None: - # uses auth values from AWS profile usually stored in ~/.aws/credentials - - 
client = boto3.Session(profile_name=aws_profile_name).client( - service_name="bedrock-runtime", - region_name=region_name, - endpoint_url=endpoint_url, - config=config, - verify=ssl_verify, - ) - else: - # aws_access_key_id is None, assume user is trying to auth using env variables - # boto3 automatically reads env variables - - client = boto3.client( - service_name="bedrock-runtime", - region_name=region_name, - endpoint_url=endpoint_url, - config=config, - verify=ssl_verify, - ) - if extra_headers: - client.meta.events.register( - "before-sign.bedrock-runtime.*", add_custom_header(extra_headers) - ) - - return client - - -def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict): - # handle anthropic prompts and amazon titan prompts - chat_template_provider = ["anthropic", "amazon", "mistral", "meta"] - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details["roles"], - initial_prompt_value=model_prompt_details["initial_prompt_value"], - final_prompt_value=model_prompt_details["final_prompt_value"], - messages=messages, - ) - else: - if provider in chat_template_provider: - prompt = prompt_factory( - model=model, messages=messages, custom_llm_provider="bedrock" - ) - else: - prompt = "" - for message in messages: - if "role" in message: - if message["role"] == "user": - prompt += f"{message['content']}" - else: - prompt += f"{message['content']}" - else: - prompt += f"{message['content']}" - return prompt - - -""" -BEDROCK AUTH Keys/Vars -os.environ['AWS_ACCESS_KEY_ID'] = "" -os.environ['AWS_SECRET_ACCESS_KEY'] = "" -""" - - -# set os.environ['AWS_REGION_NAME'] = - - -def completion( - model: str, - messages: list, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params=None, - litellm_params=None, - logger_fn=None, - timeout=None, - extra_headers: Optional[dict] = None, -): - exception_mapping_worked = False - _is_function_call = False - json_schemas: dict = {} - try: - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_region_name = optional_params.pop("aws_region_name", None) - aws_role_name = optional_params.pop("aws_role_name", None) - aws_session_name = optional_params.pop("aws_session_name", None) - aws_profile_name = optional_params.pop("aws_profile_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop( - "aws_bedrock_runtime_endpoint", None - ) - aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - - # use passed in BedrockRuntime.Client if provided, otherwise create a new one - client = optional_params.pop("aws_bedrock_client", None) - - # only init client, if user did not pass one - if client is None: - client = init_bedrock_client( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_region_name=aws_region_name, - aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_role_name=aws_role_name, - aws_session_name=aws_session_name, - aws_profile_name=aws_profile_name, - aws_web_identity_token=aws_web_identity_token, - extra_headers=extra_headers, - timeout=timeout, - ) - - model = model - modelId = ( - optional_params.pop("model_id", None) or model 
- ) # default to model if not passed - provider = model.split(".")[0] - prompt = convert_messages_to_prompt( - model, messages, provider, custom_prompt_dict - ) - inference_params = copy.deepcopy(optional_params) - stream = inference_params.pop("stream", False) - if provider == "anthropic": - if model.startswith("anthropic.claude-3"): - # Separate system prompt from rest of message - system_prompt_idx: list[int] = [] - system_messages: list[str] = [] - for idx, message in enumerate(messages): - if message["role"] == "system": - system_messages.append(message["content"]) - system_prompt_idx.append(idx) - if len(system_prompt_idx) > 0: - inference_params["system"] = "\n".join(system_messages) - messages = [ - i for j, i in enumerate(messages) if j not in system_prompt_idx - ] - # Format rest of message according to anthropic guidelines - messages = prompt_factory( - model=model, messages=messages, custom_llm_provider="anthropic_xml" - ) - ## LOAD CONFIG - config = litellm.AmazonAnthropicClaude3Config.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - ## Handle Tool Calling - if "tools" in inference_params: - _is_function_call = True - for tool in inference_params["tools"]: - json_schemas[tool["function"]["name"]] = tool["function"].get( - "parameters", None - ) - tool_calling_system_prompt = construct_tool_use_system_prompt( - tools=inference_params["tools"] - ) - inference_params["system"] = ( - inference_params.get("system", "\n") - + tool_calling_system_prompt - ) # add the anthropic tool calling prompt to the system prompt - inference_params.pop("tools") - data = json.dumps({"messages": messages, **inference_params}) - else: - ## LOAD CONFIG - config = litellm.AmazonAnthropicConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - data = json.dumps({"prompt": prompt, **inference_params}) - elif provider == "ai21": - ## LOAD CONFIG - config = litellm.AmazonAI21Config.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - - data = json.dumps({"prompt": prompt, **inference_params}) - elif provider == "cohere": - ## LOAD CONFIG - config = litellm.AmazonCohereConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - if optional_params.get("stream", False) == True: - inference_params["stream"] = ( - True # cohere requires stream = True in inference params - ) - data = json.dumps({"prompt": prompt, **inference_params}) - elif provider == "meta": - ## LOAD CONFIG - config = litellm.AmazonLlamaConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - data = json.dumps({"prompt": prompt, **inference_params}) - elif provider == "amazon": # amazon titan - ## LOAD CONFIG - config = litellm.AmazonTitanConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > 
amazon_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - - data = json.dumps( - { - "inputText": prompt, - "textGenerationConfig": inference_params, - } - ) - elif provider == "mistral": - ## LOAD CONFIG - config = litellm.AmazonMistralConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - - data = json.dumps({"prompt": prompt, **inference_params}) - else: - data = json.dumps({}) - - ## COMPLETION CALL - accept = "application/json" - contentType = "application/json" - if stream == True and _is_function_call == False: - if provider == "ai21": - ## LOGGING - request_str = f""" - response = client.invoke_model( - body={data}, - modelId={modelId}, - accept=accept, - contentType=contentType - ) - """ - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - }, - ) - - response = client.invoke_model( - body=data, modelId=modelId, accept=accept, contentType=contentType - ) - - response = response.get("body").read() - return response - else: - ## LOGGING - request_str = f""" - response = client.invoke_model_with_response_stream( - body={data}, - modelId={modelId}, - accept=accept, - contentType=contentType - ) - """ - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - }, - ) - - response = client.invoke_model_with_response_stream( - body=data, modelId=modelId, accept=accept, contentType=contentType - ) - response = response.get("body") - return response - try: - ## LOGGING - request_str = f""" - response = client.invoke_model( - body={data}, - modelId={modelId}, - accept=accept, - contentType=contentType - ) - """ - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - }, - ) - response = client.invoke_model( - body=data, modelId=modelId, accept=accept, contentType=contentType - ) - except client.exceptions.ValidationException as e: - if "The provided model identifier is invalid" in str(e): - raise BedrockError(status_code=404, message=str(e)) - raise BedrockError(status_code=400, message=str(e)) - except Exception as e: - raise BedrockError(status_code=500, message=str(e)) - - response_body = json.loads(response.get("body").read()) - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=json.dumps(response_body), - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response_body}") - ## RESPONSE OBJECT - outputText = "default" - if provider == "ai21": - outputText = response_body.get("completions")[0].get("data").get("text") - elif provider == "anthropic": - if model.startswith("anthropic.claude-3"): - outputText = response_body.get("content")[0].get("text", None) - if outputText is not None and contains_tag( - "invoke", outputText - ): # OUTPUT PARSE FUNCTION CALL - function_name = extract_between_tags("tool_name", outputText)[0] - function_arguments_str = extract_between_tags("invoke", outputText)[ - 0 - ].strip() - function_arguments_str = ( - f"{function_arguments_str}" - ) - function_arguments = parse_xml_params( - function_arguments_str, - json_schema=json_schemas.get( - function_name, None - ), # check if we have a json schema for this function name) - ) - _message = 
litellm.Message( - tool_calls=[ - { - "id": f"call_{uuid.uuid4()}", - "type": "function", - "function": { - "name": function_name, - "arguments": json.dumps(function_arguments), - }, - } - ], - content=None, - ) - model_response.choices[0].message = _message # type: ignore - model_response._hidden_params["original_response"] = ( - outputText # allow user to access raw anthropic tool calling response - ) - if _is_function_call == True and stream is not None and stream == True: - print_verbose( - f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK" - ) - # return an iterator - streaming_model_response = ModelResponse(stream=True) - streaming_model_response.choices[0].finish_reason = ( - model_response.choices[0].finish_reason - ) - # streaming_model_response.choices = [litellm.utils.StreamingChoices()] - streaming_choice = litellm.utils.StreamingChoices() - streaming_choice.index = model_response.choices[0].index - _tool_calls = [] - print_verbose( - f"type of model_response.choices[0]: {type(model_response.choices[0])}" - ) - print_verbose(f"type of streaming_choice: {type(streaming_choice)}") - if isinstance(model_response.choices[0], litellm.Choices): - if getattr( - model_response.choices[0].message, "tool_calls", None - ) is not None and isinstance( - model_response.choices[0].message.tool_calls, list - ): - for tool_call in model_response.choices[ - 0 - ].message.tool_calls: - _tool_call = {**tool_call.dict(), "index": 0} - _tool_calls.append(_tool_call) - delta_obj = litellm.utils.Delta( - content=getattr( - model_response.choices[0].message, "content", None - ), - role=model_response.choices[0].message.role, - tool_calls=_tool_calls, - ) - streaming_choice.delta = delta_obj - streaming_model_response.choices = [streaming_choice] - completion_stream = ModelResponseIterator( - model_response=streaming_model_response - ) - print_verbose( - f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" - ) - return litellm.CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - - model_response.choices[0].finish_reason = map_finish_reason( - response_body["stop_reason"] - ) - _usage = litellm.Usage( - prompt_tokens=response_body["usage"]["input_tokens"], - completion_tokens=response_body["usage"]["output_tokens"], - total_tokens=response_body["usage"]["input_tokens"] - + response_body["usage"]["output_tokens"], - ) - setattr(model_response, "usage", _usage) - else: - outputText = response_body["completion"] - model_response.choices[0].finish_reason = response_body["stop_reason"] - elif provider == "cohere": - outputText = response_body["generations"][0]["text"] - elif provider == "meta": - outputText = response_body["generation"] - elif provider == "mistral": - outputText = response_body["outputs"][0]["text"] - model_response.choices[0].finish_reason = response_body["outputs"][0][ - "stop_reason" - ] - else: # amazon titan - outputText = response_body.get("results")[0].get("outputText") - - response_metadata = response.get("ResponseMetadata", {}) - - if response_metadata.get("HTTPStatusCode", 500) >= 400: - raise BedrockError( - message=outputText, - status_code=response_metadata.get("HTTPStatusCode", 500), - ) - else: - try: - if ( - len(outputText) > 0 - and hasattr(model_response.choices[0], "message") - and getattr(model_response.choices[0].message, "tool_calls", None) - is None - ): - model_response.choices[0].message.content = outputText - elif ( - 
hasattr(model_response.choices[0], "message") - and getattr(model_response.choices[0].message, "tool_calls", None) - is not None - ): - pass - else: - raise Exception() - except: - raise BedrockError( - message=json.dumps(outputText), - status_code=response_metadata.get("HTTPStatusCode", 500), - ) - - ## CALCULATING USAGE - bedrock charges on time, not tokens - have some mapping of cost here. - if not hasattr(model_response, "usage"): - setattr(model_response, "usage", Usage()) - if getattr(model_response.usage, "total_tokens", None) is None: # type: ignore - prompt_tokens = response_metadata.get( - "x-amzn-bedrock-input-token-count", len(encoding.encode(prompt)) - ) - _text_response = model_response["choices"][0]["message"].get("content", "") - completion_tokens = response_metadata.get( - "x-amzn-bedrock-output-token-count", - len( - encoding.encode( - _text_response, - disallowed_special=(), - ) - ), - ) - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - - model_response.created = int(time.time()) - model_response.model = model - - model_response._hidden_params["region_name"] = client.meta.region_name - print_verbose(f"model_response._hidden_params: {model_response._hidden_params}") - return model_response - except BedrockError as e: - exception_mapping_worked = True - raise e - except Exception as e: - if exception_mapping_worked: - raise e - else: - import traceback - - raise BedrockError(status_code=500, message=traceback.format_exc()) - - -class ModelResponseIterator: - def __init__(self, model_response): - self.model_response = model_response - self.is_done = False - - # Sync iterator - def __iter__(self): - return self - - def __next__(self): - if self.is_done: - raise StopIteration - self.is_done = True - return self.model_response - - # Async iterator - def __aiter__(self): - return self - - async def __anext__(self): - if self.is_done: - raise StopAsyncIteration - self.is_done = True - return self.model_response - - -def _embedding_func_single( - model: str, - input: str, - client: Any, - optional_params=None, - encoding=None, - logging_obj=None, -): - if isinstance(input, str) is False: - raise BedrockError( - message="Bedrock Embedding API input must be type str | List[str]", - status_code=400, - ) - # logic for parsing in - calling - parsing out model embedding calls - ## FORMAT EMBEDDING INPUT ## - provider = model.split(".")[0] - inference_params = copy.deepcopy(optional_params) - inference_params.pop( - "user", None - ) # make sure user is not passed in for bedrock call - modelId = ( - optional_params.pop("model_id", None) or model - ) # default to model if not passed - if provider == "amazon": - input = input.replace(os.linesep, " ") - data = {"inputText": input, **inference_params} - # data = json.dumps(data) - elif provider == "cohere": - inference_params["input_type"] = inference_params.get( - "input_type", "search_document" - ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 - data = {"texts": [input], **inference_params} # type: ignore - body = json.dumps(data).encode("utf-8") # type: ignore - ## LOGGING - request_str = f""" - response = client.invoke_model( - body={body}, - modelId={modelId}, - accept="*/*", - contentType="application/json", - )""" # type: ignore - logging_obj.pre_call( - input=input, - api_key="", # boto3 is used for init. 
- additional_args={ - "complete_input_dict": {"model": modelId, "texts": input}, - "request_str": request_str, - }, - ) - try: - response = client.invoke_model( - body=body, - modelId=modelId, - accept="*/*", - contentType="application/json", - ) - response_body = json.loads(response.get("body").read()) - ## LOGGING - logging_obj.post_call( - input=input, - api_key="", - additional_args={"complete_input_dict": data}, - original_response=json.dumps(response_body), - ) - if provider == "cohere": - response = response_body.get("embeddings") - # flatten list - response = [item for sublist in response for item in sublist] - return response - elif provider == "amazon": - return response_body.get("embedding") - except Exception as e: - raise BedrockError( - message=f"Embedding Error with model {model}: {e}", status_code=500 - ) - - -def embedding( - model: str, - input: Union[list, str], - model_response: litellm.EmbeddingResponse, - api_key: Optional[str] = None, - logging_obj=None, - optional_params=None, - encoding=None, -): - ### BOTO3 INIT ### - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_region_name = optional_params.pop("aws_region_name", None) - aws_role_name = optional_params.pop("aws_role_name", None) - aws_session_name = optional_params.pop("aws_session_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop( - "aws_bedrock_runtime_endpoint", None - ) - aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - - # use passed in BedrockRuntime.Client if provided, otherwise create a new one - client = init_bedrock_client( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_region_name=aws_region_name, - aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_web_identity_token=aws_web_identity_token, - aws_role_name=aws_role_name, - aws_session_name=aws_session_name, - ) - if isinstance(input, str): - ## Embedding Call - embeddings = [ - _embedding_func_single( - model, - input, - optional_params=optional_params, - client=client, - logging_obj=logging_obj, - ) - ] - elif isinstance(input, list): - ## Embedding Call - assuming this is a List[str] - embeddings = [ - _embedding_func_single( - model, - i, - optional_params=optional_params, - client=client, - logging_obj=logging_obj, - ) - for i in input - ] # [TODO]: make these parallel calls - else: - # enters this branch if input = int, ex. 
input=2 - raise BedrockError( - message="Bedrock Embedding API input must be type str | List[str]", - status_code=400, - ) - - ## Populate OpenAI compliant dictionary - embedding_response = [] - for idx, embedding in enumerate(embeddings): - embedding_response.append( - { - "object": "embedding", - "index": idx, - "embedding": embedding, - } - ) - model_response.object = "list" - model_response.data = embedding_response - model_response.model = model - input_tokens = 0 - - input_str = "".join(input) - - input_tokens += len(encoding.encode(input_str)) - - usage = Usage( - prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens + 0 - ) - model_response.usage = usage - - return model_response - - -def image_generation( - model: str, - prompt: str, - model_response: ImageResponse, - optional_params: dict, - timeout=None, - logging_obj=None, - aimg_generation=False, -): - """ - Bedrock Image Gen endpoint support - """ - ### BOTO3 INIT ### - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_region_name = optional_params.pop("aws_region_name", None) - aws_role_name = optional_params.pop("aws_role_name", None) - aws_session_name = optional_params.pop("aws_session_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop( - "aws_bedrock_runtime_endpoint", None - ) - aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - - # use passed in BedrockRuntime.Client if provided, otherwise create a new one - client = init_bedrock_client( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_region_name=aws_region_name, - aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_web_identity_token=aws_web_identity_token, - aws_role_name=aws_role_name, - aws_session_name=aws_session_name, - timeout=timeout, - ) - - ### FORMAT IMAGE GENERATION INPUT ### - modelId = model - provider = model.split(".")[0] - inference_params = copy.deepcopy(optional_params) - inference_params.pop( - "user", None - ) # make sure user is not passed in for bedrock call - data = {} - if provider == "stability": - prompt = prompt.replace(os.linesep, " ") - ## LOAD CONFIG - config = litellm.AmazonStabilityConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v - data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params} - else: - raise BedrockError( - status_code=422, message=f"Unsupported model={model}, passed in" - ) - - body = json.dumps(data).encode("utf-8") - ## LOGGING - request_str = f""" - response = client.invoke_model( - body={body}, # type: ignore - modelId={modelId}, - accept="application/json", - contentType="application/json", - )""" # type: ignore - logging_obj.pre_call( - input=prompt, - api_key="", # boto3 is used for init. 
- additional_args={ - "complete_input_dict": {"model": modelId, "texts": prompt}, - "request_str": request_str, - }, - ) - try: - response = client.invoke_model( - body=body, - modelId=modelId, - accept="application/json", - contentType="application/json", - ) - response_body = json.loads(response.get("body").read()) - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - additional_args={"complete_input_dict": data}, - original_response=json.dumps(response_body), - ) - except Exception as e: - raise BedrockError( - message=f"Embedding Error with model {model}: {e}", status_code=500 - ) - - ### FORMAT RESPONSE TO OPENAI FORMAT ### - if response_body is None: - raise Exception("Error in response object format") - - if model_response is None: - model_response = ImageResponse() - - image_list: List[Image] = [] - for artifact in response_body["artifacts"]: - _image = Image(b64_json=artifact["base64"]) - image_list.append(_image) - - model_response.data = image_list - return model_response diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock/chat.py similarity index 97% rename from litellm/llms/bedrock_httpx.py rename to litellm/llms/bedrock/chat.py index 5980463e5..0289b5dc3 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock/chat.py @@ -1,6 +1,7 @@ -# What is this? -## Initial implementation of calling bedrock via httpx client (allows for async calls). -## V1 - covers cohere + anthropic claude-3 support +""" +Manages calling Bedrock's `/converse` API + `/invoke` API +""" + import copy import json import os @@ -28,7 +29,7 @@ import requests # type: ignore import litellm from litellm import verbose_logger -from litellm.caching import DualCache, InMemoryCache +from litellm.caching import InMemoryCache from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.litellm_logging import Logging from litellm.llms.custom_httpx.http_handler import ( @@ -39,21 +40,16 @@ from litellm.llms.custom_httpx.http_handler import ( ) from litellm.types.llms.bedrock import * from litellm.types.llms.openai import ( - ChatCompletionDeltaChunk, ChatCompletionResponseMessage, ChatCompletionToolCallChunk, ChatCompletionToolCallFunctionChunk, ChatCompletionUsageBlock, ) -from litellm.types.utils import Choices from litellm.types.utils import GenericStreamingChunk as GChunk -from litellm.types.utils import Message from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret -from .base import BaseLLM -from .base_aws_llm import BaseAWSLLM -from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt -from .prompt_templates.factory import ( +from ..base_aws_llm import BaseAWSLLM +from ..prompt_templates.factory import ( _bedrock_converse_messages_pt, _bedrock_tools_pt, cohere_message_pt, @@ -64,6 +60,7 @@ from .prompt_templates.factory import ( parse_xml_params, prompt_factory, ) +from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint BEDROCK_CONVERSE_MODELS = [ "anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -727,22 +724,13 @@ class BedrockLLM(BaseAWSLLM): ) ### SET RUNTIME ENDPOINT ### - endpoint_url = "" - env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") - if api_base is not None: - endpoint_url = api_base - elif aws_bedrock_runtime_endpoint is not None and isinstance( - aws_bedrock_runtime_endpoint, str - ): - endpoint_url = aws_bedrock_runtime_endpoint - elif env_aws_bedrock_runtime_endpoint and isinstance( - 
env_aws_bedrock_runtime_endpoint, str - ): - endpoint_url = env_aws_bedrock_runtime_endpoint - else: - endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + endpoint_url = get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_region_name=aws_region_name, + ) - if (stream is not None and stream == True) and provider != "ai21": + if (stream is not None and stream is True) and provider != "ai21": endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream" else: endpoint_url = f"{endpoint_url}/model/{modelId}/invoke" @@ -1561,21 +1549,11 @@ class BedrockConverseLLM(BaseAWSLLM): ) ### SET RUNTIME ENDPOINT ### - endpoint_url = "" - env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") - if api_base is not None: - endpoint_url = api_base - elif aws_bedrock_runtime_endpoint is not None and isinstance( - aws_bedrock_runtime_endpoint, str - ): - endpoint_url = aws_bedrock_runtime_endpoint - elif env_aws_bedrock_runtime_endpoint and isinstance( - env_aws_bedrock_runtime_endpoint, str - ): - endpoint_url = env_aws_bedrock_runtime_endpoint - else: - endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" - + endpoint_url = get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_region_name=aws_region_name, + ) if (stream is not None and stream is True) and provider != "ai21": endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" else: diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py new file mode 100644 index 000000000..19a4f0986 --- /dev/null +++ b/litellm/llms/bedrock/common_utils.py @@ -0,0 +1,773 @@ +""" +Common utilities used across bedrock chat/embedding/image generation +""" + +import os +import types +from enum import Enum +from typing import List, Optional, Union + +import httpx + +import litellm +from litellm import get_secret + + +class BedrockError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class AmazonBedrockGlobalConfig: + def __init__(self): + pass + + def get_mapped_special_auth_params(self) -> dict: + """ + Mapping of common auth params across bedrock/vertex/azure/watsonx + """ + return {"region_name": "aws_region_name"} + + def map_special_auth_params(self, non_default_params: dict, optional_params: dict): + mapped_params = self.get_mapped_special_auth_params() + for param, value in non_default_params.items(): + if param in mapped_params: + optional_params[mapped_params[param]] = value + return optional_params + + def get_eu_regions(self) -> List[str]: + """ + Source: https://www.aws-services.info/bedrock.html + """ + return [ + "eu-west-1", + "eu-west-3", + "eu-central-1", + ] + + +class AmazonTitanConfig: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1 + + Supported Params for the Amazon Titan models: + + - `maxTokenCount` (integer) max tokens, + - `stopSequences` (string[]) list of stop sequence strings + - `temperature` (float) temperature for model, + - `topP` (int) top p for model + """ + + 
maxTokenCount: Optional[int] = None + stopSequences: Optional[list] = None + temperature: Optional[float] = None + topP: Optional[int] = None + + def __init__( + self, + maxTokenCount: Optional[int] = None, + stopSequences: Optional[list] = None, + temperature: Optional[float] = None, + topP: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +class AmazonAnthropicClaude3Config: + """ + Reference: + https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude + https://docs.anthropic.com/claude/docs/models-overview#model-comparison + + Supported Params for the Amazon / Anthropic Claude 3 models: + + - `max_tokens` Required (integer) max tokens. Default is 4096 + - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" + - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py + - `temperature` Optional (float) The amount of randomness injected into the response + - `top_p` Optional (float) Use nucleus sampling. + - `top_k` Optional (int) Only sample from the top K options for each subsequent token + - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating + """ + + max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default + anthropic_version: Optional[str] = "bedrock-2023-05-31" + system: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + stop_sequences: Optional[List[str]] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + anthropic_version: Optional[str] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "max_tokens", + "tools", + "tool_choice", + "stream", + "stop", + "temperature", + "top_p", + "extra_headers", + ] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "tools": + optional_params["tools"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + return optional_params + + +class AmazonAnthropicConfig: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude + + Supported Params for the Amazon / Anthropic models: + + - `max_tokens_to_sample` (integer) max tokens, + - `temperature` (float) model temperature, + - `top_k` (integer) top k, + - `top_p` (integer) top p, + - 
`stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"], + - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" + """ + + max_tokens_to_sample: Optional[int] = litellm.max_tokens + stop_sequences: Optional[list] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[int] = None + anthropic_version: Optional[str] = None + + def __init__( + self, + max_tokens_to_sample: Optional[int] = None, + stop_sequences: Optional[list] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[int] = None, + anthropic_version: Optional[str] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params( + self, + ): + return ["max_tokens", "temperature", "stop", "top_p", "stream"] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens_to_sample"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "stream" and value == True: + optional_params["stream"] = value + return optional_params + + +class AmazonCohereConfig: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command + + Supported Params for the Amazon / Cohere models: + + - `max_tokens` (integer) max tokens, + - `temperature` (float) model temperature, + - `return_likelihood` (string) n/a + """ + + max_tokens: Optional[int] = None + temperature: Optional[float] = None + return_likelihood: Optional[str] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + return_likelihood: Optional[str] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +class AmazonAI21Config: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra + + Supported Params for the Amazon / AI21 models: + + - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`. + + - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding. + + - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass. 
+ + - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional. + + - `frequencyPenalty` (object): Placeholder for frequency penalty object. + + - `presencePenalty` (object): Placeholder for presence penalty object. + + - `countPenalty` (object): Placeholder for count penalty object. + """ + + maxTokens: Optional[int] = None + temperature: Optional[float] = None + topP: Optional[float] = None + stopSequences: Optional[list] = None + frequencePenalty: Optional[dict] = None + presencePenalty: Optional[dict] = None + countPenalty: Optional[dict] = None + + def __init__( + self, + maxTokens: Optional[int] = None, + temperature: Optional[float] = None, + topP: Optional[float] = None, + stopSequences: Optional[list] = None, + frequencePenalty: Optional[dict] = None, + presencePenalty: Optional[dict] = None, + countPenalty: Optional[dict] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +class AnthropicConstants(Enum): + HUMAN_PROMPT = "\n\nHuman: " + AI_PROMPT = "\n\nAssistant: " + + +class AmazonLlamaConfig: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1 + + Supported Params for the Amazon / Meta Llama models: + + - `max_gen_len` (integer) max tokens, + - `temperature` (float) temperature for model, + - `top_p` (float) top p for model + """ + + max_gen_len: Optional[int] = None + temperature: Optional[float] = None + topP: Optional[float] = None + + def __init__( + self, + maxTokenCount: Optional[int] = None, + temperature: Optional[float] = None, + topP: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +class AmazonMistralConfig: + """ + Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html + Supported Params for the Amazon / Mistral models: + + - `max_tokens` (integer) max tokens, + - `temperature` (float) temperature for model, + - `top_p` (float) top p for model + - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output. 
+ - `top_k` (float) top k for model + """ + + max_tokens: Optional[int] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[float] = None + stop: Optional[List[str]] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[int] = None, + top_k: Optional[float] = None, + stop: Optional[List[str]] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +class AmazonStabilityConfig: + """ + Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0 + + Supported Params for the Amazon / Stable Diffusion models: + + - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt) + + - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed) + + - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run. + + - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64. + Engine-specific dimension validation: + + - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. + - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 + - SDXL v1.0: same as SDXL v0.9 + - SD v1.6: must be between 320x320 and 1536x1536 + + - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64. + Engine-specific dimension validation: + + - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512. 
+ - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152 + - SDXL v1.0: same as SDXL v0.9 + - SD v1.6: must be between 320x320 and 1536x1536 + """ + + cfg_scale: Optional[int] = None + seed: Optional[float] = None + steps: Optional[List[str]] = None + width: Optional[int] = None + height: Optional[int] = None + + def __init__( + self, + cfg_scale: Optional[int] = None, + seed: Optional[float] = None, + steps: Optional[List[str]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + +def add_custom_header(headers): + """Closure to capture the headers and add them.""" + + def callback(request, **kwargs): + """Actual callback function that Boto3 will call.""" + for header_name, header_value in headers.items(): + request.headers.add_header(header_name, header_value) + + return callback + + +def init_bedrock_client( + region_name=None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_bedrock_runtime_endpoint: Optional[str] = None, + aws_session_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + aws_role_name: Optional[str] = None, + aws_web_identity_token: Optional[str] = None, + extra_headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, +): + # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + standard_aws_region_name = get_secret("AWS_REGION", None) + ## CHECK IS 'os.environ/' passed in + # Define the list of parameters to check + params_to_check = [ + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_bedrock_runtime_endpoint, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + ] + + # Iterate over parameters and update if needed + for i, param in enumerate(params_to_check): + if param and param.startswith("os.environ/"): + params_to_check[i] = get_secret(param) + # Assign updated values back to parameters + ( + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_bedrock_runtime_endpoint, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + ) = params_to_check + + # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. 
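+    # Illustrative (editor's note, hypothetical path): setting SSL_VERIFY=/etc/ssl/certs/corp-ca.pem
+    # in the environment points boto3 at a custom CA bundle; when unset, litellm.ssl_verify
+    # is used and forwarded to every boto3 client below via verify=ssl_verify.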
+ ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify) + + ### SET REGION NAME + if region_name: + pass + elif aws_region_name: + region_name = aws_region_name + elif litellm_aws_region_name: + region_name = litellm_aws_region_name + elif standard_aws_region_name: + region_name = standard_aws_region_name + else: + raise BedrockError( + message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file", + status_code=401, + ) + + # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if aws_bedrock_runtime_endpoint: + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint: + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com" + + import boto3 + + if isinstance(timeout, float): + config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) + elif isinstance(timeout, httpx.Timeout): + config = boto3.session.Config( + connect_timeout=timeout.connect, read_timeout=timeout.read + ) + else: + config = boto3.session.Config() + + ### CHECK STS ### + if ( + aws_web_identity_token is not None + and aws_role_name is not None + and aws_session_name is not None + ): + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) + + sts_client = boto3.client("sts") + + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) + + client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + verify=ssl_verify, + ) + elif aws_role_name is not None and aws_session_name is not None: + # use sts if role name passed in + sts_client = boto3.client( + "sts", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + + sts_response = sts_client.assume_role( + RoleArn=aws_role_name, RoleSessionName=aws_session_name + ) + + client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + verify=ssl_verify, + ) + elif aws_access_key_id is not None: + # uses auth params passed to completion + # aws_access_key_id is not None, assume user is trying to auth using litellm.completion + + client = boto3.client( + service_name="bedrock-runtime", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + verify=ssl_verify, + ) + elif aws_profile_name is not None: + # uses auth values from AWS profile usually stored in ~/.aws/credentials + + 
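+        # Illustrative (editor's note, hypothetical profile name): aws_profile_name="bedrock-dev"
+        # makes boto3 resolve credentials from the [bedrock-dev] section of ~/.aws/credentials.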
client = boto3.Session(profile_name=aws_profile_name).client( + service_name="bedrock-runtime", + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + verify=ssl_verify, + ) + else: + # aws_access_key_id is None, assume user is trying to auth using env variables + # boto3 automatically reads env variables + + client = boto3.client( + service_name="bedrock-runtime", + region_name=region_name, + endpoint_url=endpoint_url, + config=config, + verify=ssl_verify, + ) + if extra_headers: + client.meta.events.register( + "before-sign.bedrock-runtime.*", add_custom_header(extra_headers) + ) + + return client + + +def get_runtime_endpoint( + api_base: Optional[str], + aws_bedrock_runtime_endpoint: Optional[str], + aws_region_name: str, +) -> str: + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if api_base is not None: + endpoint_url = api_base + elif aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + return endpoint_url + + +class ModelResponseIterator: + def __init__(self, model_response): + self.model_response = model_response + self.is_done = False + + # Sync iterator + def __iter__(self): + return self + + def __next__(self): + if self.is_done: + raise StopIteration + self.is_done = True + return self.model_response + + # Async iterator + def __aiter__(self): + return self + + async def __anext__(self): + if self.is_done: + raise StopAsyncIteration + self.is_done = True + return self.model_response diff --git a/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py new file mode 100644 index 000000000..c86bade5d --- /dev/null +++ b/litellm/llms/bedrock/embed/amazon_titan_g1_transformation.py @@ -0,0 +1,149 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format. + +Why separate file? 
Make it easy to see how transformation works + +Convers +- G1 request format + +Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html +""" + +import types +from typing import List, Optional + +from litellm.types.llms.bedrock import ( + AmazonTitanG1EmbeddingRequest, + AmazonTitanG1EmbeddingResponse, + AmazonTitanV2EmbeddingRequest, + AmazonTitanV2EmbeddingResponse, +) +from litellm.types.utils import Embedding, EmbeddingResponse, Usage + + +class AmazonTitanG1Config: + """ + Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html + """ + + def __init__( + self, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def _transform_request( + self, input: str, inference_params: dict + ) -> AmazonTitanG1EmbeddingRequest: + return AmazonTitanG1EmbeddingRequest(inputText=input) + + def _transform_response( + self, response_list: List[dict], model: str + ) -> EmbeddingResponse: + total_prompt_tokens = 0 + + transformed_responses: List[Embedding] = [] + for index, response in enumerate(response_list): + _parsed_response = AmazonTitanG1EmbeddingResponse(**response) # type: ignore + transformed_responses.append( + Embedding( + embedding=_parsed_response["embedding"], + index=index, + object="embedding", + ) + ) + total_prompt_tokens += _parsed_response["inputTextTokenCount"] + + usage = Usage( + prompt_tokens=total_prompt_tokens, + completion_tokens=0, + total_tokens=total_prompt_tokens, + ) + return EmbeddingResponse(model=model, usage=usage, data=transformed_responses) + + +class AmazonTitanV2Config: + """ + Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html + + normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true + dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256. 
+ """ + + normalize: Optional[bool] = None + dimensions: Optional[int] = None + + def __init__( + self, normalize: Optional[bool] = None, dimensions: Optional[int] = None + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def _transform_request( + self, input: str, inference_params: dict + ) -> AmazonTitanV2EmbeddingRequest: + return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params) # type: ignore + + def _transform_response( + self, response_list: List[dict], model: str + ) -> EmbeddingResponse: + total_prompt_tokens = 0 + + transformed_responses: List[Embedding] = [] + for index, response in enumerate(response_list): + _parsed_response = AmazonTitanV2EmbeddingResponse(**response) # type: ignore + transformed_responses.append( + Embedding( + embedding=_parsed_response["embedding"], + index=index, + object="embedding", + ) + ) + total_prompt_tokens += _parsed_response["inputTextTokenCount"] + + usage = Usage( + prompt_tokens=total_prompt_tokens, + completion_tokens=0, + total_tokens=total_prompt_tokens, + ) + return EmbeddingResponse(model=model, usage=usage, data=transformed_responses) diff --git a/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py new file mode 100644 index 000000000..7e2b6176d --- /dev/null +++ b/litellm/llms/bedrock/embed/amazon_titan_multimodal_transformation.py @@ -0,0 +1,54 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format. + +Why separate file? 
Make it easy to see how transformation works + +Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html +""" + +from typing import List + +from litellm.types.llms.bedrock import ( + AmazonTitanMultimodalEmbeddingConfig, + AmazonTitanMultimodalEmbeddingRequest, + AmazonTitanMultimodalEmbeddingResponse, +) +from litellm.types.utils import Embedding, EmbeddingResponse, Usage +from litellm.utils import is_base64_encoded + + +def _transform_request( + input: str, inference_params: dict +) -> AmazonTitanMultimodalEmbeddingRequest: + ## check if b64 encoded str or not ## + is_encoded = is_base64_encoded(input) + if is_encoded: # check if string is b64 encoded image or not + transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputImage=input) + else: + transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input) + + for k, v in inference_params.items(): + transformed_request[k] = v # type: ignore + + return transformed_request + + +def _transform_response(response_list: List[dict], model: str) -> EmbeddingResponse: + + total_prompt_tokens = 0 + transformed_responses: List[Embedding] = [] + for index, response in enumerate(response_list): + _parsed_response = AmazonTitanMultimodalEmbeddingResponse(**response) # type: ignore + transformed_responses.append( + Embedding( + embedding=_parsed_response["embedding"], index=index, object="embedding" + ) + ) + total_prompt_tokens += _parsed_response["inputTextTokenCount"] + + usage = Usage( + prompt_tokens=total_prompt_tokens, + completion_tokens=0, + total_tokens=total_prompt_tokens, + ) + return EmbeddingResponse(model=model, usage=usage, data=transformed_responses) diff --git a/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py b/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py new file mode 100644 index 000000000..a9c980dbb --- /dev/null +++ b/litellm/llms/bedrock/embed/amazon_titan_v2_transformation.py @@ -0,0 +1,86 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format. + +Why separate file? Make it easy to see how transformation works + +Convers +- v2 request format + +Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html +""" + +import types +from typing import List, Optional + +from litellm.types.llms.bedrock import ( + AmazonTitanV2EmbeddingRequest, + AmazonTitanV2EmbeddingResponse, +) +from litellm.types.utils import Embedding, EmbeddingResponse, Usage + + +class AmazonTitanV2Config: + """ + Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html + + normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true + dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256. 
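+
+    Illustrative request produced by _transform_request (editor's example; the
+    input text and parameter values here are hypothetical):
+        {"inputText": "hello world", "dimensions": 512, "normalize": True}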
+ """ + + normalize: Optional[bool] = None + dimensions: Optional[int] = None + + def __init__( + self, normalize: Optional[bool] = None, dimensions: Optional[int] = None + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def _transform_request( + self, input: str, inference_params: dict + ) -> AmazonTitanV2EmbeddingRequest: + return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params) # type: ignore + + def _transform_response( + self, response_list: List[dict], model: str + ) -> EmbeddingResponse: + total_prompt_tokens = 0 + + transformed_responses: List[Embedding] = [] + for index, response in enumerate(response_list): + _parsed_response = AmazonTitanV2EmbeddingResponse(**response) # type: ignore + transformed_responses.append( + Embedding( + embedding=_parsed_response["embedding"], + index=index, + object="embedding", + ) + ) + total_prompt_tokens += _parsed_response["inputTextTokenCount"] + + usage = Usage( + prompt_tokens=total_prompt_tokens, + completion_tokens=0, + total_tokens=total_prompt_tokens, + ) + return EmbeddingResponse(model=model, usage=usage, data=transformed_responses) diff --git a/litellm/llms/bedrock/embed/cohere_transformation.py b/litellm/llms/bedrock/embed/cohere_transformation.py new file mode 100644 index 000000000..2d5fbe8c2 --- /dev/null +++ b/litellm/llms/bedrock/embed/cohere_transformation.py @@ -0,0 +1,25 @@ +""" +Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format. + +Why separate file? 
Make it easy to see how transformation works +""" + +from typing import List + +import litellm +from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingResponse +from litellm.types.utils import Embedding, EmbeddingResponse + + +def _transform_request( + input: List[str], inference_params: dict +) -> CohereEmbeddingRequest: + transformed_request = CohereEmbeddingRequest( + texts=input, + input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE, # type: ignore + ) + + for k, v in inference_params.items(): + transformed_request[k] = v # type: ignore + + return transformed_request diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py new file mode 100644 index 000000000..6ad463cd0 --- /dev/null +++ b/litellm/llms/bedrock/embed/embedding.py @@ -0,0 +1,498 @@ +""" +Handles embedding calls to Bedrock's `/invoke` endpoint +""" + +import copy +import json +import os +from copy import deepcopy +from typing import Any, Callable, List, Literal, Optional, Tuple, Union + +import httpx + +import litellm +from litellm import get_secret +from litellm.llms.cohere.embed import embedding as cohere_embedding +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, +) +from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest +from litellm.types.utils import Embedding, EmbeddingResponse, Usage + +from ...base_aws_llm import BaseAWSLLM +from ..common_utils import BedrockError, get_runtime_endpoint +from .amazon_titan_g1_transformation import AmazonTitanG1Config +from .amazon_titan_multimodal_transformation import ( + _transform_request as amazon_multimodal_transform_request, +) +from .amazon_titan_multimodal_transformation import ( + _transform_response as amazon_multimodal_transform_response, +) +from .amazon_titan_v2_transformation import AmazonTitanV2Config +from .cohere_transformation import _transform_request as cohere_transform_request + + +class BedrockEmbedding(BaseAWSLLM): + def _load_credentials( + self, + optional_params: dict, + ) -> Tuple[Any, str]: + try: + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + return credentials, aws_region_name + + async def async_embeddings(self): + pass + + def _make_sync_call( + self, + client: Optional[HTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + api_base: str, + headers: dict, + data: dict, + ) -> dict: + if client is None or not isinstance(client, HTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = _get_httpx_client(_params) # type: ignore + else: + client = client + try: + response = client.post(url=api_base, headers=headers, data=json.dumps(data)) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return response.json() + + def _single_func_embeddings( + self, + client: Optional[HTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + batch_data: List[dict], + credentials: Any, + extra_headers: Optional[dict], + endpoint_url: str, + aws_region_name: str, + model: str, + logging_obj: Any, + ): + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + + responses: List[dict] = [] + for data in batch_data: + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=json.dumps(data), headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=data, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepped.url, + "headers": prepped.headers, + }, + ) + response = self._make_sync_call( + client=client, + timeout=timeout, + api_base=prepped.url, + headers=prepped.headers, + data=data, + ) + + ## LOGGING + logging_obj.post_call( + input=data, + api_key="", + original_response=response, + additional_args={"complete_input_dict": data}, + ) + + responses.append(response) + + returned_response: Optional[EmbeddingResponse] = None + + ## TRANSFORM RESPONSE ## + if model == "amazon.titan-embed-image-v1": + returned_response = amazon_multimodal_transform_response( + response_list=responses, model=model + ) + elif model == "amazon.titan-embed-text-v1": + returned_response = AmazonTitanG1Config()._transform_response( + response_list=responses, model=model + ) + elif model == "amazon.titan-embed-text-v2:0": + returned_response = AmazonTitanV2Config()._transform_response( + response_list=responses, model=model + ) + + if returned_response is None: + raise Exception( + "Unable to map model response to known provider format. model={}".format( + model + ) + ) + + return returned_response + + def embeddings( + self, + model: str, + input: List[str], + api_base: Optional[str], + model_response: EmbeddingResponse, + print_verbose: Callable, + encoding, + logging_obj, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]], + timeout: Optional[Union[float, httpx.Timeout]], + aembedding: Optional[bool], + extra_headers: Optional[dict], + optional_params=None, + litellm_params=None, + ) -> EmbeddingResponse: + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + + credentials, aws_region_name = self._load_credentials(optional_params) + + ### TRANSFORMATION ### + provider = model.split(".")[0] + inference_params = copy.deepcopy(optional_params) + inference_params.pop( + "user", None + ) # make sure user is not passed in for bedrock call + modelId = ( + optional_params.pop("model_id", None) or model + ) # default to model if not passed + + data: Optional[CohereEmbeddingRequest] = None + batch_data: Optional[List] = None + if provider == "cohere": + data = cohere_transform_request( + input=input, inference_params=inference_params + ) + elif provider == "amazon" and model in [ + "amazon.titan-embed-image-v1", + "amazon.titan-embed-text-v1", + "amazon.titan-embed-text-v2:0", + ]: + batch_data = [] + for i in input: + if model == "amazon.titan-embed-image-v1": + transformed_request: AmazonEmbeddingRequest = ( + amazon_multimodal_transform_request( + input=i, inference_params=inference_params + ) + ) + elif model == "amazon.titan-embed-text-v1": + transformed_request = AmazonTitanG1Config()._transform_request( + input=i, inference_params=inference_params + ) + elif model == "amazon.titan-embed-text-v2:0": + transformed_request = AmazonTitanV2Config()._transform_request( + input=i, inference_params=inference_params + ) + batch_data.append(transformed_request) + + ### SET RUNTIME ENDPOINT ### + endpoint_url = get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ), + aws_region_name=aws_region_name, + ) + endpoint_url = f"{endpoint_url}/model/{modelId}/invoke" + + if batch_data is not None: + return self._single_func_embeddings( + client=( + client + if client is not None and isinstance(client, HTTPHandler) + else None + ), + timeout=timeout, + batch_data=batch_data, + credentials=credentials, + extra_headers=extra_headers, + endpoint_url=endpoint_url, + aws_region_name=aws_region_name, + model=model, + logging_obj=logging_obj, + ) + elif data is None: + raise Exception("Unable to map request to provider") + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=json.dumps(data), headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + ## ROUTING ## + return cohere_embedding( + model=model, + input=input, + model_response=model_response, + logging_obj=logging_obj, + optional_params=optional_params, + encoding=encoding, + data=data, # type: ignore + complete_api_base=prepped.url, + api_key=None, + aembedding=aembedding, + timeout=timeout, + client=client, + headers=prepped.headers, + ) + + # def _embedding_func_single( + # model: str, + # input: str, + # client: Any, + # optional_params=None, + # encoding=None, + # logging_obj=None, + # ): + # if isinstance(input, str) is False: + # raise BedrockError( + # message="Bedrock Embedding API input must be type str | List[str]", + # status_code=400, + # ) + # # logic for parsing in - calling - parsing out model embedding calls + # ## FORMAT EMBEDDING INPUT ## + # provider = model.split(".")[0] + # inference_params = copy.deepcopy(optional_params) + # inference_params.pop( + # 
"user", None + # ) # make sure user is not passed in for bedrock call + # modelId = ( + # optional_params.pop("model_id", None) or model + # ) # default to model if not passed + # if provider == "amazon": + # input = input.replace(os.linesep, " ") + # data = {"inputText": input, **inference_params} + # # data = json.dumps(data) + # elif provider == "cohere": + # inference_params["input_type"] = inference_params.get( + # "input_type", "search_document" + # ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 + # data = {"texts": [input], **inference_params} # type: ignore + # body = json.dumps(data).encode("utf-8") # type: ignore + # ## LOGGING + # request_str = f""" + # response = client.invoke_model( + # body={body}, + # modelId={modelId}, + # accept="*/*", + # contentType="application/json", + # )""" # type: ignore + # logging_obj.pre_call( + # input=input, + # api_key="", # boto3 is used for init. + # additional_args={ + # "complete_input_dict": {"model": modelId, "texts": input}, + # "request_str": request_str, + # }, + # ) + # try: + # response = client.invoke_model( + # body=body, + # modelId=modelId, + # accept="*/*", + # contentType="application/json", + # ) + # response_body = json.loads(response.get("body").read()) + # ## LOGGING + # logging_obj.post_call( + # input=input, + # api_key="", + # additional_args={"complete_input_dict": data}, + # original_response=json.dumps(response_body), + # ) + # if provider == "cohere": + # response = response_body.get("embeddings") + # # flatten list + # response = [item for sublist in response for item in sublist] + # return response + # elif provider == "amazon": + # return response_body.get("embedding") + # except Exception as e: + # raise BedrockError( + # message=f"Embedding Error with model {model}: {e}", status_code=500 + # ) + + # def embedding( + # model: str, + # input: Union[list, str], + # model_response: litellm.EmbeddingResponse, + # api_key: Optional[str] = None, + # logging_obj=None, + # optional_params=None, + # encoding=None, + # ): + # ### BOTO3 INIT ### + # # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + # aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + # aws_access_key_id = optional_params.pop("aws_access_key_id", None) + # aws_region_name = optional_params.pop("aws_region_name", None) + # aws_role_name = optional_params.pop("aws_role_name", None) + # aws_session_name = optional_params.pop("aws_session_name", None) + # aws_bedrock_runtime_endpoint = optional_params.pop( + # "aws_bedrock_runtime_endpoint", None + # ) + # aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + + # # use passed in BedrockRuntime.Client if provided, otherwise create a new one + # client = init_bedrock_client( + # aws_access_key_id=aws_access_key_id, + # aws_secret_access_key=aws_secret_access_key, + # aws_region_name=aws_region_name, + # aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + # aws_web_identity_token=aws_web_identity_token, + # aws_role_name=aws_role_name, + # aws_session_name=aws_session_name, + # ) + # if isinstance(input, str): + # ## Embedding Call + # embeddings = [ + # _embedding_func_single( + # model, + # input, + # optional_params=optional_params, + # client=client, + # logging_obj=logging_obj, + # ) + # ] + # elif isinstance(input, list): + # ## Embedding Call - assuming this is a List[str] + # 
embeddings = [ + # _embedding_func_single( + # model, + # i, + # optional_params=optional_params, + # client=client, + # logging_obj=logging_obj, + # ) + # for i in input + # ] # [TODO]: make these parallel calls + # else: + # # enters this branch if input = int, ex. input=2 + # raise BedrockError( + # message="Bedrock Embedding API input must be type str | List[str]", + # status_code=400, + # ) + + # ## Populate OpenAI compliant dictionary + # embedding_response = [] + # for idx, embedding in enumerate(embeddings): + # embedding_response.append( + # { + # "object": "embedding", + # "index": idx, + # "embedding": embedding, + # } + # ) + # model_response.object = "list" + # model_response.data = embedding_response + # model_response.model = model + # input_tokens = 0 + + # input_str = "".join(input) + + # input_tokens += len(encoding.encode(input_str)) + + # usage = Usage( + # prompt_tokens=input_tokens, + # completion_tokens=0, + # total_tokens=input_tokens + 0, + # ) + # model_response.usage = usage + + # return model_response diff --git a/litellm/llms/bedrock/image_generation.py b/litellm/llms/bedrock/image_generation.py new file mode 100644 index 000000000..a6ddd38cb --- /dev/null +++ b/litellm/llms/bedrock/image_generation.py @@ -0,0 +1,127 @@ +""" +Handles image gen calls to Bedrock's `/invoke` endpoint +""" + +import copy +import json +import os +from typing import List + +from openai.types.image import Image + +import litellm +from litellm.types.utils import ImageResponse + +from .common_utils import BedrockError, init_bedrock_client + + +def image_generation( + model: str, + prompt: str, + model_response: ImageResponse, + optional_params: dict, + timeout=None, + logging_obj=None, + aimg_generation=False, +): + """ + Bedrock Image Gen endpoint support + """ + ### BOTO3 INIT ### + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + + # use passed in BedrockRuntime.Client if provided, otherwise create a new one + client = init_bedrock_client( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_web_identity_token=aws_web_identity_token, + aws_role_name=aws_role_name, + aws_session_name=aws_session_name, + timeout=timeout, + ) + + ### FORMAT IMAGE GENERATION INPUT ### + modelId = model + provider = model.split(".")[0] + inference_params = copy.deepcopy(optional_params) + inference_params.pop( + "user", None + ) # make sure user is not passed in for bedrock call + data = {} + if provider == "stability": + prompt = prompt.replace(os.linesep, " ") + ## LOAD CONFIG + config = litellm.AmazonStabilityConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params} + 
else: + raise BedrockError( + status_code=422, message=f"Unsupported model={model}, passed in" + ) + + body = json.dumps(data).encode("utf-8") + ## LOGGING + request_str = f""" + response = client.invoke_model( + body={body}, # type: ignore + modelId={modelId}, + accept="application/json", + contentType="application/json", + )""" # type: ignore + logging_obj.pre_call( + input=prompt, + api_key="", # boto3 is used for init. + additional_args={ + "complete_input_dict": {"model": modelId, "texts": prompt}, + "request_str": request_str, + }, + ) + try: + response = client.invoke_model( + body=body, + modelId=modelId, + accept="application/json", + contentType="application/json", + ) + response_body = json.loads(response.get("body").read()) + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + additional_args={"complete_input_dict": data}, + original_response=json.dumps(response_body), + ) + except Exception as e: + raise BedrockError( + message=f"Embedding Error with model {model}: {e}", status_code=500 + ) + + ### FORMAT RESPONSE TO OPENAI FORMAT ### + if response_body is None: + raise Exception("Error in response object format") + + if model_response is None: + model_response = ImageResponse() + + image_list: List[Image] = [] + for artifact in response_body["artifacts"]: + _image = Image(b64_json=artifact["base64"]) + image_list.append(_image) + + model_response.data = image_list + return model_response diff --git a/litellm/llms/cohere/embed.py b/litellm/llms/cohere/embed.py index 81c84c422..5d640b506 100644 --- a/litellm/llms/cohere/embed.py +++ b/litellm/llms/cohere/embed.py @@ -76,7 +76,7 @@ async def async_embedding( data: dict, input: list, model_response: litellm.utils.EmbeddingResponse, - timeout: Union[float, httpx.Timeout], + timeout: Optional[Union[float, httpx.Timeout]], logging_obj: LiteLLMLoggingObj, optional_params: dict, api_base: str, @@ -98,16 +98,35 @@ async def async_embedding( ) ## COMPLETION CALL if client is None: - client = AsyncHTTPHandler(concurrent_limit=1) + client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout) - response = await client.post(api_base, headers=headers, data=json.dumps(data)) + try: + response = await client.post(api_base, headers=headers, data=json.dumps(data)) + except httpx.HTTPStatusError as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=e.response.text, + ) + raise e + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e ## LOGGING logging_obj.post_call( input=input, api_key=api_key, additional_args={"complete_input_dict": data}, - original_response=response, + original_response=response.text, ) embeddings = response.json()["embeddings"] @@ -130,27 +149,22 @@ def embedding( optional_params: dict, headers: dict, encoding: Any, + data: Optional[dict] = None, + complete_api_base: Optional[str] = None, api_key: Optional[str] = None, aembedding: Optional[bool] = None, - timeout: Union[float, httpx.Timeout] = httpx.Timeout(None), + timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None), client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, ): headers = validate_environment(api_key, headers=headers) - embed_url = "https://api.cohere.ai/v1/embed" + embed_url = complete_api_base or "https://api.cohere.ai/v1/embed" model = model - data = {"model": model, "texts": input, 
**optional_params} + data = data or {"model": model, "texts": input, **optional_params} if "3" in model and "input_type" not in data: # cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document" data["input_type"] = "search_document" - ## LOGGING - logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - ) - ## ROUTING if aembedding is True: return async_embedding( @@ -166,9 +180,18 @@ def embedding( headers=headers, encoding=encoding, ) + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + ) + ## COMPLETION CALL if client is None or not isinstance(client, HTTPHandler): client = HTTPHandler(concurrent_limit=1) + response = client.post(embed_url, headers=headers, data=json.dumps(data)) ## LOGGING logging_obj.post_call( diff --git a/litellm/main.py b/litellm/main.py index 7f1431073..70cd40f31 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -78,7 +78,6 @@ from .llms import ( ai21, aleph_alpha, baseten, - bedrock, clarifai, cloudflare, maritalk, @@ -96,7 +95,9 @@ from .llms.anthropic.chat import AnthropicChatCompletion from .llms.anthropic.completion import AnthropicTextCompletion from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params from .llms.azure_text import AzureTextCompletion -from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM +from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore +from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM +from .llms.bedrock.embed.embedding import BedrockEmbedding from .llms.cohere import chat as cohere_chat from .llms.cohere import completion as cohere_completion # type: ignore from .llms.cohere import embed as cohere_embed @@ -176,6 +177,7 @@ codestral_text_completions = CodestralTextCompletion() triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() bedrock_converse_chat_completion = BedrockConverseLLM() +bedrock_embedding = BedrockEmbedding() vertex_chat_completion = VertexLLM() vertex_multimodal_embedding = VertexMultimodalEmbedding() google_batch_embeddings = GoogleBatchEmbeddings() @@ -3151,6 +3153,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse: or custom_llm_provider == "watsonx" or custom_llm_provider == "cohere" or custom_llm_provider == "huggingface" + or custom_llm_provider == "bedrock" ): # currently implemented aiohttp calls for just azure and openai, soon all. 
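            # Illustrative usage of the async path this change enables for bedrock
            # (editor's example, mirroring test_bedrock_embedding_titan below):
            #   await litellm.aembedding(
            #       model="bedrock/amazon.titan-embed-text-v1",
            #       input="good morning from litellm",
            #       aws_region_name="us-west-2",
            #   )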
            # Await normally
            init_response = await loop.run_in_executor(None, func_with_context)
@@ -3519,13 +3522,24 @@ def embedding(
             aembedding=aembedding,
         )
     elif custom_llm_provider == "bedrock":
-        response = bedrock.embedding(
+        if isinstance(input, str):
+            transformed_input = [input]
+        else:
+            transformed_input = input
+        response = bedrock_embedding.embeddings(
             model=model,
-            input=input,
+            input=transformed_input,
             encoding=encoding,
             logging_obj=logging,
             optional_params=optional_params,
             model_response=EmbeddingResponse(),
+            client=client,
+            timeout=timeout,
+            aembedding=aembedding,
+            litellm_params=litellm_params,
+            api_base=api_base,
+            print_verbose=print_verbose,
+            extra_headers=extra_headers,
         )
     elif custom_llm_provider == "triton":
         if api_base is None:
@@ -4493,7 +4507,7 @@ def image_generation(
     elif custom_llm_provider == "bedrock":
         if model is None:
             raise Exception("Model needs to be set for bedrock")
-        model_response = bedrock.image_generation(
+        model_response = bedrock_image_generation.image_generation(
             model=model,
             prompt=prompt,
             timeout=timeout,
diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
index 4efca5efd..9f2dca31e 100644
--- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
@@ -178,7 +178,7 @@ async def bedrock_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)

     # Add or update query parameters
-    from litellm.llms.bedrock_httpx import BedrockConverseLLM
+    from litellm.llms.bedrock.chat import BedrockConverseLLM

     credentials: Credentials = BedrockConverseLLM().get_credentials()
     sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index 9a830d353..075c98621 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -25,7 +25,7 @@ from litellm import (
     completion_cost,
     embedding,
 )
-from litellm.llms.bedrock_httpx import BedrockLLM, ToolBlock
+from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import _bedrock_tools_pt

diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index ec85a782d..4067ef047 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -311,7 +311,17 @@ async def test_cohere_embedding3(custom_llm_provider):
 # test_cohere_embedding3()


-def test_bedrock_embedding_titan():
+@pytest.mark.parametrize(
+    "model",
+    [
+        "bedrock/amazon.titan-embed-text-v1",
+        "bedrock/amazon.titan-embed-image-v1",
+        "bedrock/amazon.titan-embed-text-v2:0",
+    ],
+)
+@pytest.mark.parametrize("sync_mode", [True])
+@pytest.mark.asyncio
+async def test_bedrock_embedding_titan(model, sync_mode):
     try:
         # this tests if we support str input for bedrock embedding
         litellm.set_verbose = True
@@ -320,16 +330,23 @@ def test_bedrock_embedding_titan():

         current_time = str(time.time())
         # DO NOT MAKE THE INPUT A LIST in this test
-        response = embedding(
-            model="bedrock/amazon.titan-embed-text-v1",
-            input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
-            aws_region_name="us-west-2",
-        )
-        print(f"response:", response)
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+                aws_region_name="us-west-2",
+            )
+        print("response:", response)
         assert isinstance(
             response["data"][0]["embedding"], list
         ), "Expected response to be a list"
-        print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
+        print("type of first embedding:", type(response["data"][0]["embedding"][0]))
         assert all(
             isinstance(x, float) for x in response["data"][0]["embedding"]
         ), "Expected response to be a list of floats"
@@ -339,13 +356,20 @@ def test_bedrock_embedding_titan():

         start_time = time.time()

-        response = embedding(
-            model="bedrock/amazon.titan-embed-text-v1",
-            input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
-        )
+        if sync_mode:
+            response = embedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+            )
+        else:
+            response = await litellm.aembedding(
+                model=model,
+                input=f"good morning from litellm, attempting to embed data {current_time}",  # input should always be a string in this test
+            )
         print(response)

         end_time = time.time()
+        print(response._hidden_params)
         print(f"Embedding 2 response time: {end_time - start_time} seconds")

         assert end_time - start_time < 0.1
@@ -392,13 +416,13 @@ def test_demo_tokens_as_input_to_embeddings_fails_for_titan():

     with pytest.raises(
         litellm.BadRequestError,
-        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+        match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: JSONArray, please reformat your input and try again."}',
     ):
         litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])

     with pytest.raises(
         litellm.BadRequestError,
-        match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
+        match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: Integer, please reformat your input and try again."}',
     ):
         litellm.embedding(
             model="amazon.titan-embed-text-v1",
diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py
index 652e20989..1cf374148 100644
--- a/litellm/tests/test_secret_manager.py
+++ b/litellm/tests/test_secret_manager.py
@@ -1,21 +1,25 @@
-import sys, os, uuid
+import os
+import sys
 import time
 import traceback
+import uuid
+
 from dotenv import load_dotenv

 load_dotenv()
 import os
-from uuid import uuid4
 import tempfile
+from uuid import uuid4

 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
 import pytest
+
 from litellm import get_secret
-from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
 from litellm.llms.azure import get_azure_ad_token_from_oidc
-from litellm.llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
+from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
+from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager


 @pytest.mark.skip(reason="AWS Suspended Account")
@@ -63,9 +67,7 @@ def test_oidc_github():
     reason="Cannot run without being in CircleCI Runner",
 )
 def test_oidc_circleci():
-    secret_val = get_secret(
-        "oidc/circleci/"
-    )
+    secret_val = get_secret("oidc/circleci/")

     print(f"secret_val: {redact_oidc_signature(secret_val)}")

@@ -103,9 +105,7 @@ def test_oidc_circle_v1_with_amazon():
     # The purpose of this test is to get logs using the older v1 of the CircleCI OIDC token

     # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
-    aws_role_name = (
-        "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
-    )
+    aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
     aws_web_identity_token = "oidc/circleci/"

     bllm = BedrockLLM()
@@ -116,6 +116,7 @@ def test_oidc_circle_v1_with_amazon():
         aws_session_name="assume-v1-session",
     )

+
 @pytest.mark.skipif(
     os.environ.get("CIRCLE_OIDC_TOKEN") is None,
     reason="Cannot run without being in CircleCI Runner",
@@ -124,9 +125,7 @@ def test_oidc_circle_v1_with_amazon_fips():
     # The purpose of this test is to validate that we can assume a role in a FIPS region

     # TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
-    aws_role_name = (
-        "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
-    )
+    aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
     aws_web_identity_token = "oidc/circleci/"

     bllm = BedrockConverseLLM()
@@ -143,9 +142,7 @@ def test_oidc_env_variable():
     # Create a unique environment variable name
     env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
     os.environ[env_var_name] = "secret-" + uuid4().hex
-    secret_val = get_secret(
-        f"oidc/env/{env_var_name}"
-    )
+    secret_val = get_secret(f"oidc/env/{env_var_name}")

     print(f"secret_val: {redact_oidc_signature(secret_val)}")

@@ -157,15 +154,13 @@ def test_oidc_env_variable():

 def test_oidc_file():
     # Create a temporary file
-    with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
+    with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
         secret_value = "secret-" + uuid4().hex
         temp_file.write(secret_value)
         temp_file.flush()
         temp_file_path = temp_file.name

-        secret_val = get_secret(
-            f"oidc/file/{temp_file_path}"
-        )
+        secret_val = get_secret(f"oidc/file/{temp_file_path}")

         print(f"secret_val: {redact_oidc_signature(secret_val)}")

@@ -174,7 +169,7 @@ def test_oidc_file():

 def test_oidc_env_path():
     # Create a temporary file
-    with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
+    with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
         secret_value = "secret-" + uuid4().hex
         temp_file.write(secret_value)
         temp_file.flush()
@@ -187,9 +182,7 @@ def test_oidc_env_path():
         os.environ[env_var_name] = temp_file_path

         # Test getting the secret using the environment variable
-        secret_val = get_secret(
-            f"oidc/env_path/{env_var_name}"
-        )
+        secret_val = get_secret(f"oidc/env_path/{env_var_name}")

         print(f"secret_val: {redact_oidc_signature(secret_val)}")

diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py
index cf0a4a84b..4fa0b06bb 100644
--- a/litellm/types/llms/bedrock.py
+++ b/litellm/types/llms/bedrock.py
@@ -208,3 +208,62 @@ class ServerSentEvent:
     @override
     def __repr__(self) -> str:
         return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
+
+
+class CohereEmbeddingRequest(TypedDict, total=False):
+    texts: Required[List[str]]
+    input_type: Required[
+        Literal["search_document", "search_query", "classification", "clustering"]
+    ]
+    truncate: Literal["NONE", "START", "END"]
+    embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]
+
+
+class CohereEmbeddingResponse(TypedDict):
+    embeddings: List[List[float]]
+    id: str
+    response_type: Literal["embedding_floats"]
+    texts: List[str]
+
+
+class AmazonTitanV2EmbeddingRequest(TypedDict):
+    inputText: str
+    dimensions: int
+    normalize: bool
+
+
+class AmazonTitanV2EmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+
+
+class AmazonTitanG1EmbeddingRequest(TypedDict):
+    inputText: str
+
+
+class AmazonTitanG1EmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+
+
+class AmazonTitanMultimodalEmbeddingConfig(TypedDict):
+    outputEmbeddingLength: Literal[256, 384, 1024]
+
+
+class AmazonTitanMultimodalEmbeddingRequest(TypedDict, total=False):
+    inputText: str
+    inputImage: str
+    embeddingConfig: AmazonTitanMultimodalEmbeddingConfig
+
+
+class AmazonTitanMultimodalEmbeddingResponse(TypedDict):
+    embedding: List[float]
+    inputTextTokenCount: int
+    message: str  # Specifies any errors that occur during generation.
+
+
+AmazonEmbeddingRequest = Union[
+    AmazonTitanMultimodalEmbeddingRequest,
+    AmazonTitanV2EmbeddingRequest,
+    AmazonTitanG1EmbeddingRequest,
+]
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index aadbdd22a..9e8c7be34 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -699,7 +699,7 @@ class ModelResponse(OpenAIObject):
 class Embedding(OpenAIObject):
     embedding: Union[list, str] = []
     index: int
-    object: str
+    object: Literal["embedding"]

     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -721,7 +721,7 @@ class EmbeddingResponse(OpenAIObject):
     data: Optional[List] = None
     """The actual embedding value"""

-    object: str
+    object: Literal["list"]
     """The object type, which is always "embedding" """

     usage: Optional[Usage] = None
@@ -732,11 +732,10 @@ class EmbeddingResponse(OpenAIObject):

     def __init__(
         self,
-        model=None,
-        usage=None,
-        stream=False,
+        model: Optional[str] = None,
+        usage: Optional[Usage] = None,
         response_ms=None,
-        data=None,
+        data: Optional[List] = None,
         hidden_params=None,
         _response_headers=None,
         **params,
@@ -760,7 +759,7 @@ class EmbeddingResponse(OpenAIObject):
             self._response_headers = _response_headers

         model = model
-        super().__init__(model=model, object=object, data=data, usage=usage)
+        super().__init__(model=model, object=object, data=data, usage=usage)  # type: ignore

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
diff --git a/litellm/utils.py b/litellm/utils.py
index facbc6a0a..5b8229d68 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -854,6 +854,7 @@ def client(original_function):
                    )
                    cached_result = litellm.cache.get_cache(*args, **kwargs)
                    if cached_result is not None:
+                        print_verbose("Cache Hit!")
                        if "detail" in cached_result:
                            # implies an error occurred
                            pass
@@ -935,7 +936,10 @@ def client(original_function):
                            args=(cached_result, start_time, end_time, cache_hit),
                        ).start()
                        return cached_result
-
+                    else:
+                        print_verbose(
+                            "Cache Miss! on key - {}".format(preset_cache_key)
+                        )
                    # CHECK MAX TOKENS
                    if (
                        kwargs.get("max_tokens", None) is not None
@@ -1005,7 +1009,7 @@ def client(original_function):
                litellm.cache is not None
                and str(original_function.__name__)
                in litellm.cache.supported_call_types
-            ) and (kwargs.get("cache", {}).get("no-store", False) != True):
+            ) and (kwargs.get("cache", {}).get("no-store", False) is not True):
                litellm.cache.add_cache(result, *args, **kwargs)

            # LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@@ -1404,10 +1408,10 @@ def client(original_function):
            # MODEL CALL
            result = await original_function(*args, **kwargs)
            end_time = datetime.datetime.now()
-            if "stream" in kwargs and kwargs["stream"] == True:
+            if "stream" in kwargs and kwargs["stream"] is True:
                if (
                    "complete_response" in kwargs
-                    and kwargs["complete_response"] == True
+                    and kwargs["complete_response"] is True
                ):
                    chunks = []
                    for idx, chunk in enumerate(result):
@@ -11734,3 +11738,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
             return True

     return False
+
+
+def is_base64_encoded(s: str) -> bool:
+    try:
+        # Try to decode the string
+        decoded_bytes = base64.b64decode(s, validate=True)
+        # Check if the original string can be re-encoded to the same string
+        return base64.b64encode(decoded_bytes).decode("utf-8") == s
+    except Exception:
+        return False
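Usage sketch (not part of the diff above): a minimal example of how the refactored Bedrock embedding path and the new is_base64_encoded helper might be exercised, assuming litellm is installed and AWS credentials plus region are configured in the environment. The Titan model name is taken from the test parametrization above; the exact response shape shown mirrors the assertions in test_bedrock_embedding_titan.

import litellm
from litellm.utils import is_base64_encoded

# A plain string input is accepted; it is wrapped into a list before the
# request is handed to the Bedrock embedding handler.
response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",
    input="good morning from litellm",
    aws_region_name="us-west-2",
)
print(response["data"][0]["embedding"][:5])  # first few floats of the vector

# Round-trip check: a string counts as base64 only if decoding and
# re-encoding reproduces the original string.
print(is_base64_encoded("aGVsbG8="))  # True
print(is_base64_encoded("hello"))  # False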