Bedrock Embeddings refactor + model support (#5462)

* refactor(bedrock): initial commit to refactor bedrock to a folder

Improve code readability + maintainability

* refactor: more refactor work

* fix: fix imports

* feat(bedrock/embeddings.py): support translating embedding requests into Amazon embedding formats

* fix: fix linting errors

* test: skip test on end-of-life model

* fix(cohere/embed.py): fix linting error

* fix(cohere/embed.py): fix typing

* fix(cohere/embed.py): fix post-call logging for cohere embedding call

* test(test_embeddings.py): fix error message assertion in test
Krish Dholakia 2024-09-01 13:29:58 -07:00 committed by GitHub
parent 6fb82aaf75
commit 37f9705d6e
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
21 changed files with 1946 additions and 1659 deletions

View file

@@ -1,12 +1,12 @@
repos:
- repo: local
hooks:
- id: mypy
name: mypy
entry: python3 -m mypy --ignore-missing-imports
language: system
types: [python]
files: ^litellm/
# - id: mypy
# name: mypy
# entry: python3 -m mypy --ignore-missing-imports
# language: system
# types: [python]
# files: ^litellm/
- id: isort
name: isort
entry: isort

View file

@@ -118,6 +118,8 @@ in_memory_llm_clients_cache: dict = {}
safe_memory_mode: bool = False
### DEFAULT AZURE API VERSION ###
AZURE_DEFAULT_API_VERSION = "2024-07-01-preview" # this is updated to the latest
### COHERE EMBEDDINGS DEFAULT TYPE ###
COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document"
### GUARDRAILS ###
llamaguard_model_name: Optional[str] = None
openai_moderations_model_name: Optional[str] = None
@@ -880,13 +882,13 @@ from .llms.sagemaker.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock_httpx import (
from .llms.bedrock.chat import (
AmazonCohereChatConfig,
AmazonConverseConfig,
BEDROCK_CONVERSE_MODELS,
bedrock_tool_name_mappings,
)
from .llms.bedrock import (
from .llms.bedrock.common_utils import (
AmazonTitanConfig,
AmazonAI21Config,
AmazonAnthropicConfig,
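
The new COHERE_DEFAULT_EMBEDDING_INPUT_TYPE flag above is consumed by the Bedrock/Cohere embedding transformation later in this diff. A minimal sketch of overriding it, assuming litellm's usual module-level config pattern:

    import litellm

    # Module-level default applied when the caller does not pass input_type;
    # overriding it affects subsequent Bedrock/Cohere embedding requests.
    litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_query"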

View file

@@ -608,17 +608,17 @@ class Logging:
self.model_call_details["litellm_params"]["metadata"][
"hidden_params"
] = result._hidden_params
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj=result,
start_time=start_time,
end_time=end_time,
logging_obj=self,
)
)
else: # streaming chunks + image gen.
self.model_call_details["response_cost"] = None

File diff suppressed because it is too large.

View file

@@ -1,6 +1,7 @@
# What is this?
## Initial implementation of calling bedrock via httpx client (allows for async calls).
## V1 - covers cohere + anthropic claude-3 support
"""
Manages calling Bedrock's `/converse` API + `/invoke` API
"""
import copy
import json
import os
@@ -28,7 +29,7 @@ import requests  # type: ignore
import litellm
from litellm import verbose_logger
from litellm.caching import DualCache, InMemoryCache
from litellm.caching import InMemoryCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
@@ -39,21 +40,16 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.types.llms.bedrock import *
from litellm.types.llms.openai import (
ChatCompletionDeltaChunk,
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.utils import Choices
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import Message
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
from .base import BaseLLM
from .base_aws_llm import BaseAWSLLM
from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt
from .prompt_templates.factory import (
from ..base_aws_llm import BaseAWSLLM
from ..prompt_templates.factory import (
_bedrock_converse_messages_pt,
_bedrock_tools_pt,
cohere_message_pt,
@@ -64,6 +60,7 @@ from .prompt_templates.factory import (
parse_xml_params,
prompt_factory,
)
from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint
BEDROCK_CONVERSE_MODELS = [
"anthropic.claude-3-5-sonnet-20240620-v1:0",
@@ -727,22 +724,13 @@ class BedrockLLM(BaseAWSLLM):
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
endpoint_url = get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=aws_region_name,
)
if (stream is not None and stream == True) and provider != "ai21":
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
@@ -1561,21 +1549,11 @@ class BedrockConverseLLM(BaseAWSLLM):
)
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
endpoint_url = get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=aws_region_name,
)
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
else:

View file

@@ -0,0 +1,773 @@
"""
Common utilities used across bedrock chat/embedding/image generation
"""
import os
import types
from enum import Enum
from typing import List, Optional, Union
import httpx
import litellm
from litellm import get_secret
class BedrockError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class AmazonBedrockGlobalConfig:
def __init__(self):
pass
def get_mapped_special_auth_params(self) -> dict:
"""
Mapping of common auth params across bedrock/vertex/azure/watsonx
"""
return {"region_name": "aws_region_name"}
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
mapped_params = self.get_mapped_special_auth_params()
for param, value in non_default_params.items():
if param in mapped_params:
optional_params[mapped_params[param]] = value
return optional_params
def get_eu_regions(self) -> List[str]:
"""
Source: https://www.aws-services.info/bedrock.html
"""
return [
"eu-west-1",
"eu-west-3",
"eu-central-1",
]
class AmazonTitanConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
Supported Params for the Amazon Titan models:
- `maxTokenCount` (integer) max tokens,
- `stopSequences` (string[]) list of stop sequence strings
- `temperature` (float) temperature for model,
- `topP` (int) top p for model
"""
maxTokenCount: Optional[int] = None
stopSequences: Optional[list] = None
temperature: Optional[float] = None
topP: Optional[int] = None
def __init__(
self,
maxTokenCount: Optional[int] = None,
stopSequences: Optional[list] = None,
temperature: Optional[float] = None,
topP: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAnthropicClaude3Config:
"""
Reference:
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
Supported Params for the Amazon / Anthropic Claude 3 models:
- `max_tokens` Required (integer) max tokens. Default is 4096
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
"""
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
anthropic_version: Optional[str] = "bedrock-2023-05-31"
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"extra_headers",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
return optional_params
class AmazonAnthropicConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
Supported Params for the Amazon / Anthropic models:
- `max_tokens_to_sample` (integer) max tokens,
- `temperature` (float) model temperature,
- `top_k` (integer) top k,
- `top_p` (integer) top p,
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
"""
max_tokens_to_sample: Optional[int] = litellm.max_tokens
stop_sequences: Optional[list] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[int] = None
anthropic_version: Optional[str] = None
def __init__(
self,
max_tokens_to_sample: Optional[int] = None,
stop_sequences: Optional[list] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(
self,
):
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens_to_sample"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "stream" and value == True:
optional_params["stream"] = value
return optional_params
class AmazonCohereConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
Supported Params for the Amazon / Cohere models:
- `max_tokens` (integer) max tokens,
- `temperature` (float) model temperature,
- `return_likelihood` (string) n/a
"""
max_tokens: Optional[int] = None
temperature: Optional[float] = None
return_likelihood: Optional[str] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
return_likelihood: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonAI21Config:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
Supported Params for the Amazon / AI21 models:
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object.
"""
maxTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
maxTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AnthropicConstants(Enum):
HUMAN_PROMPT = "\n\nHuman: "
AI_PROMPT = "\n\nAssistant: "
class AmazonLlamaConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
Supported Params for the Amazon / Meta Llama models:
- `max_gen_len` (integer) max tokens,
- `temperature` (float) temperature for model,
- `top_p` (float) top p for model
"""
max_gen_len: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
def __init__(
self,
max_gen_len: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonMistralConfig:
"""
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
Supported Params for the Amazon / Mistral models:
- `max_tokens` (integer) max tokens,
- `temperature` (float) temperature for model,
- `top_p` (float) top p for model
- `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
- `top_k` (float) top k for model
"""
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[float] = None
stop: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[int] = None,
top_k: Optional[float] = None,
stop: Optional[List[str]] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
class AmazonStabilityConfig:
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
Supported Params for the Amazon / Stable Diffusion models:
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
Engine-specific dimension validation:
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
- SDXL v1.0: same as SDXL v0.9
- SD v1.6: must be between 320x320 and 1536x1536
"""
cfg_scale: Optional[int] = None
seed: Optional[float] = None
steps: Optional[List[str]] = None
width: Optional[int] = None
height: Optional[int] = None
def __init__(
self,
cfg_scale: Optional[int] = None,
seed: Optional[float] = None,
steps: Optional[List[str]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def add_custom_header(headers):
"""Closure to capture the headers and add them."""
def callback(request, **kwargs):
"""Actual callback function that Boto3 will call."""
for header_name, header_value in headers.items():
request.headers.add_header(header_name, header_value)
return callback
def init_bedrock_client(
region_name=None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_region_name: Optional[str] = None,
aws_bedrock_runtime_endpoint: Optional[str] = None,
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
aws_web_identity_token: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IF 'os.environ/' is passed in
# Define the list of parameters to check
params_to_check = [
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
]
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
params_to_check[i] = get_secret(param)
# Assign updated values back to parameters
(
aws_access_key_id,
aws_secret_access_key,
aws_region_name,
aws_bedrock_runtime_endpoint,
aws_session_name,
aws_profile_name,
aws_role_name,
aws_web_identity_token,
) = params_to_check
# SSL certificates (a.k.a. CA bundle) used to verify the identity of requested hosts.
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
### SET REGION NAME
if region_name:
pass
elif aws_region_name:
region_name = aws_region_name
elif litellm_aws_region_name:
region_name = litellm_aws_region_name
elif standard_aws_region_name:
region_name = standard_aws_region_name
else:
raise BedrockError(
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
status_code=401,
)
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint:
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint:
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
import boto3
if isinstance(timeout, float):
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
elif isinstance(timeout, httpx.Timeout):
config = boto3.session.Config(
connect_timeout=timeout.connect, read_timeout=timeout.read
)
else:
config = boto3.session.Config()
### CHECK STS ###
if (
aws_web_identity_token is not None
and aws_role_name is not None
and aws_session_name is not None
):
oidc_token = get_secret(aws_web_identity_token)
if oidc_token is None:
raise BedrockError(
message="OIDC token could not be retrieved from secret manager.",
status_code=401,
)
sts_client = boto3.client("sts")
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
sts_response = sts_client.assume_role_with_web_identity(
RoleArn=aws_role_name,
RoleSessionName=aws_session_name,
WebIdentityToken=oidc_token,
DurationSeconds=3600,
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
elif aws_role_name is not None and aws_session_name is not None:
# use sts if role name passed in
sts_client = boto3.client(
"sts",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
sts_response = sts_client.assume_role(
RoleArn=aws_role_name, RoleSessionName=aws_session_name
)
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
aws_session_token=sts_response["Credentials"]["SessionToken"],
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
elif aws_access_key_id is not None:
# uses auth params passed to completion
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
elif aws_profile_name is not None:
# uses auth values from AWS profile usually stored in ~/.aws/credentials
client = boto3.Session(profile_name=aws_profile_name).client(
service_name="bedrock-runtime",
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
else:
# aws_access_key_id is None, assume user is trying to auth using env variables
# boto3 automatically reads env variables
client = boto3.client(
service_name="bedrock-runtime",
region_name=region_name,
endpoint_url=endpoint_url,
config=config,
verify=ssl_verify,
)
if extra_headers:
client.meta.events.register(
"before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
)
return client
def get_runtime_endpoint(
api_base: Optional[str],
aws_bedrock_runtime_endpoint: Optional[str],
aws_region_name: str,
) -> str:
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
return endpoint_url
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
self.is_done = False
# Sync iterator
def __iter__(self):
return self
def __next__(self):
if self.is_done:
raise StopIteration
self.is_done = True
return self.model_response
# Async iterator
def __aiter__(self):
return self
async def __anext__(self):
if self.is_done:
raise StopAsyncIteration
self.is_done = True
return self.model_response
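
Two hedged usage sketches for the helpers above; values are illustrative and assume the AWS_BEDROCK_RUNTIME_ENDPOINT env var is unset:

    # Endpoint resolution precedence: api_base > aws_bedrock_runtime_endpoint
    # param > AWS_BEDROCK_RUNTIME_ENDPOINT env var > regional default.
    url = get_runtime_endpoint(
        api_base=None,
        aws_bedrock_runtime_endpoint=None,
        aws_region_name="us-west-2",
    )
    # url == "https://bedrock-runtime.us-west-2.amazonaws.com"

    # OpenAI -> Claude-3-on-Bedrock param mapping: "stop" becomes "stop_sequences".
    mapped = AmazonAnthropicClaude3Config().map_openai_params(
        non_default_params={"max_tokens": 256, "stop": ["\n\nHuman:"]},
        optional_params={},
    )
    # mapped == {"max_tokens": 256, "stop_sequences": ["\n\nHuman:"]}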

View file

@@ -0,0 +1,149 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
Why a separate file? It makes it easy to see how the transformation works.
Covers:
- G1 request format
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""
import types
from typing import List, Optional
from litellm.types.llms.bedrock import (
AmazonTitanG1EmbeddingRequest,
AmazonTitanG1EmbeddingResponse,
AmazonTitanV2EmbeddingRequest,
AmazonTitanV2EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
class AmazonTitanG1Config:
"""
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""
def __init__(
self,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def _transform_request(
self, input: str, inference_params: dict
) -> AmazonTitanG1EmbeddingRequest:
return AmazonTitanG1EmbeddingRequest(inputText=input)
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):
_parsed_response = AmazonTitanG1EmbeddingResponse(**response) # type: ignore
transformed_responses.append(
Embedding(
embedding=_parsed_response["embedding"],
index=index,
object="embedding",
)
)
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
usage = Usage(
prompt_tokens=total_prompt_tokens,
completion_tokens=0,
total_tokens=total_prompt_tokens,
)
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
class AmazonTitanV2Config:
"""
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
"""
normalize: Optional[bool] = None
dimensions: Optional[int] = None
def __init__(
self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def _transform_request(
self, input: str, inference_params: dict
) -> AmazonTitanV2EmbeddingRequest:
return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params) # type: ignore
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):
_parsed_response = AmazonTitanV2EmbeddingResponse(**response) # type: ignore
transformed_responses.append(
Embedding(
embedding=_parsed_response["embedding"],
index=index,
object="embedding",
)
)
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
usage = Usage(
prompt_tokens=total_prompt_tokens,
completion_tokens=0,
total_tokens=total_prompt_tokens,
)
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
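
A small round-trip sketch for the G1 transformation above; the fake response dict mirrors the embedding/inputTextTokenCount fields from the Titan /invoke docs:

    cfg = AmazonTitanG1Config()
    req = cfg._transform_request(input="hello world", inference_params={})
    # req == {"inputText": "hello world"}

    fake_responses = [{"embedding": [0.1, 0.2, 0.3], "inputTextTokenCount": 2}]
    resp = cfg._transform_response(
        response_list=fake_responses, model="amazon.titan-embed-text-v1"
    )
    # resp.usage.prompt_tokens == 2, with one Embedding entry per response dict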

View file

@@ -0,0 +1,54 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format.
Why a separate file? It makes it easy to see how the transformation works.
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
"""
from typing import List
from litellm.types.llms.bedrock import (
AmazonTitanMultimodalEmbeddingConfig,
AmazonTitanMultimodalEmbeddingRequest,
AmazonTitanMultimodalEmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
from litellm.utils import is_base64_encoded
def _transform_request(
input: str, inference_params: dict
) -> AmazonTitanMultimodalEmbeddingRequest:
## check if b64 encoded str or not ##
is_encoded = is_base64_encoded(input)
if is_encoded: # check if string is b64 encoded image or not
transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputImage=input)
else:
transformed_request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)
for k, v in inference_params.items():
transformed_request[k] = v # type: ignore
return transformed_request
def _transform_response(response_list: List[dict], model: str) -> EmbeddingResponse:
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):
_parsed_response = AmazonTitanMultimodalEmbeddingResponse(**response) # type: ignore
transformed_responses.append(
Embedding(
embedding=_parsed_response["embedding"], index=index, object="embedding"
)
)
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
usage = Usage(
prompt_tokens=total_prompt_tokens,
completion_tokens=0,
total_tokens=total_prompt_tokens,
)
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
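
A sketch of the text-vs-image routing above, with hypothetical inputs (assuming is_base64_encoded() returns False for plain prose and True for a valid base64 image payload):

    _transform_request(input="a red bicycle", inference_params={})
    # -> {"inputText": "a red bicycle"}

    _transform_request(input="iVBORw0KGgoAAAANSUhEUgAA...", inference_params={})
    # -> {"inputImage": "iVBORw0KGgoAAAANSUhEUgAA..."}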

View file

@@ -0,0 +1,86 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format.
Why a separate file? It makes it easy to see how the transformation works.
Covers:
- V2 request format
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
"""
import types
from typing import List, Optional
from litellm.types.llms.bedrock import (
AmazonTitanV2EmbeddingRequest,
AmazonTitanV2EmbeddingResponse,
)
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
class AmazonTitanV2Config:
"""
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
"""
normalize: Optional[bool] = None
dimensions: Optional[int] = None
def __init__(
self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def _transform_request(
self, input: str, inference_params: dict
) -> AmazonTitanV2EmbeddingRequest:
return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params) # type: ignore
def _transform_response(
self, response_list: List[dict], model: str
) -> EmbeddingResponse:
total_prompt_tokens = 0
transformed_responses: List[Embedding] = []
for index, response in enumerate(response_list):
_parsed_response = AmazonTitanV2EmbeddingResponse(**response) # type: ignore
transformed_responses.append(
Embedding(
embedding=_parsed_response["embedding"],
index=index,
object="embedding",
)
)
total_prompt_tokens += _parsed_response["inputTextTokenCount"]
usage = Usage(
prompt_tokens=total_prompt_tokens,
completion_tokens=0,
total_tokens=total_prompt_tokens,
)
return EmbeddingResponse(model=model, usage=usage, data=transformed_responses)
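
A sketch of the V2 transform, which (unlike G1) forwards inference params such as dimensions and normalize straight into the request body:

    req = AmazonTitanV2Config()._transform_request(
        input="hello", inference_params={"dimensions": 512, "normalize": True}
    )
    # req == {"inputText": "hello", "dimensions": 512, "normalize": True}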

View file

@@ -0,0 +1,25 @@
"""
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
Why a separate file? It makes it easy to see how the transformation works.
"""
from typing import List
import litellm
from litellm.types.llms.bedrock import CohereEmbeddingRequest, CohereEmbeddingResponse
from litellm.types.utils import Embedding, EmbeddingResponse
def _transform_request(
input: List[str], inference_params: dict
) -> CohereEmbeddingRequest:
transformed_request = CohereEmbeddingRequest(
texts=input,
input_type=litellm.COHERE_DEFAULT_EMBEDDING_INPUT_TYPE, # type: ignore
)
for k, v in inference_params.items():
transformed_request[k] = v # type: ignore
return transformed_request
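
A sketch of the Cohere request transform above: list input maps to texts, and the module-level default input_type applies unless the caller overrides it:

    _transform_request(input=["doc one", "doc two"], inference_params={})
    # -> {"texts": ["doc one", "doc two"], "input_type": "search_document"}

    _transform_request(input=["find docs"], inference_params={"input_type": "search_query"})
    # -> {"texts": ["find docs"], "input_type": "search_query"}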

View file

@@ -0,0 +1,498 @@
"""
Handles embedding calls to Bedrock's `/invoke` endpoint
"""
import copy
import json
import os
from copy import deepcopy
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
import httpx
import litellm
from litellm import get_secret
from litellm.llms.cohere.embed import embedding as cohere_embedding
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
)
from litellm.types.llms.bedrock import AmazonEmbeddingRequest, CohereEmbeddingRequest
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
from ...base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError, get_runtime_endpoint
from .amazon_titan_g1_transformation import AmazonTitanG1Config
from .amazon_titan_multimodal_transformation import (
_transform_request as amazon_multimodal_transform_request,
)
from .amazon_titan_multimodal_transformation import (
_transform_response as amazon_multimodal_transform_response,
)
from .amazon_titan_v2_transformation import AmazonTitanV2Config
from .cohere_transformation import _transform_request as cohere_transform_request
class BedrockEmbedding(BaseAWSLLM):
def _load_credentials(
self,
optional_params: dict,
) -> Tuple[Any, str]:
try:
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_session_token = optional_params.pop("aws_session_token", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
### SET REGION NAME ###
if aws_region_name is None:
# check env #
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
if litellm_aws_region_name is not None and isinstance(
litellm_aws_region_name, str
):
aws_region_name = litellm_aws_region_name
standard_aws_region_name = get_secret("AWS_REGION", None)
if standard_aws_region_name is not None and isinstance(
standard_aws_region_name, str
):
aws_region_name = standard_aws_region_name
if aws_region_name is None:
aws_region_name = "us-west-2"
credentials: Credentials = self.get_credentials(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_session_token=aws_session_token,
aws_region_name=aws_region_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
aws_role_name=aws_role_name,
aws_web_identity_token=aws_web_identity_token,
aws_sts_endpoint=aws_sts_endpoint,
)
return credentials, aws_region_name
async def async_embeddings(self):
pass
def _make_sync_call(
self,
client: Optional[HTTPHandler],
timeout: Optional[Union[float, httpx.Timeout]],
api_base: str,
headers: dict,
data: dict,
) -> dict:
if client is None or not isinstance(client, HTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
timeout = httpx.Timeout(timeout)
_params["timeout"] = timeout
client = _get_httpx_client(_params) # type: ignore
else:
client = client
try:
response = client.post(url=api_base, headers=headers, data=json.dumps(data)) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return response.json()
def _single_func_embeddings(
self,
client: Optional[HTTPHandler],
timeout: Optional[Union[float, httpx.Timeout]],
batch_data: List[dict],
credentials: Any,
extra_headers: Optional[dict],
endpoint_url: str,
aws_region_name: str,
model: str,
logging_obj: Any,
):
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
responses: List[dict] = []
for data in batch_data:
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
## LOGGING
logging_obj.pre_call(
input=data,
api_key="",
additional_args={
"complete_input_dict": data,
"api_base": prepped.url,
"headers": prepped.headers,
},
)
response = self._make_sync_call(
client=client,
timeout=timeout,
api_base=prepped.url,
headers=prepped.headers,
data=data,
)
## LOGGING
logging_obj.post_call(
input=data,
api_key="",
original_response=response,
additional_args={"complete_input_dict": data},
)
responses.append(response)
returned_response: Optional[EmbeddingResponse] = None
## TRANSFORM RESPONSE ##
if model == "amazon.titan-embed-image-v1":
returned_response = amazon_multimodal_transform_response(
response_list=responses, model=model
)
elif model == "amazon.titan-embed-text-v1":
returned_response = AmazonTitanG1Config()._transform_response(
response_list=responses, model=model
)
elif model == "amazon.titan-embed-text-v2:0":
returned_response = AmazonTitanV2Config()._transform_response(
response_list=responses, model=model
)
if returned_response is None:
raise Exception(
"Unable to map model response to known provider format. model={}".format(
model
)
)
return returned_response
def embeddings(
self,
model: str,
input: List[str],
api_base: Optional[str],
model_response: EmbeddingResponse,
print_verbose: Callable,
encoding,
logging_obj,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
timeout: Optional[Union[float, httpx.Timeout]],
aembedding: Optional[bool],
extra_headers: Optional[dict],
optional_params=None,
litellm_params=None,
) -> EmbeddingResponse:
try:
import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
credentials, aws_region_name = self._load_credentials(optional_params)
### TRANSFORMATION ###
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
modelId = (
optional_params.pop("model_id", None) or model
) # default to model if not passed
data: Optional[CohereEmbeddingRequest] = None
batch_data: Optional[List] = None
if provider == "cohere":
data = cohere_transform_request(
input=input, inference_params=inference_params
)
elif provider == "amazon" and model in [
"amazon.titan-embed-image-v1",
"amazon.titan-embed-text-v1",
"amazon.titan-embed-text-v2:0",
]:
batch_data = []
for i in input:
if model == "amazon.titan-embed-image-v1":
transformed_request: AmazonEmbeddingRequest = (
amazon_multimodal_transform_request(
input=i, inference_params=inference_params
)
)
elif model == "amazon.titan-embed-text-v1":
transformed_request = AmazonTitanG1Config()._transform_request(
input=i, inference_params=inference_params
)
elif model == "amazon.titan-embed-text-v2:0":
transformed_request = AmazonTitanV2Config()._transform_request(
input=i, inference_params=inference_params
)
batch_data.append(transformed_request)
### SET RUNTIME ENDPOINT ###
endpoint_url = get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=optional_params.pop(
"aws_bedrock_runtime_endpoint", None
),
aws_region_name=aws_region_name,
)
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
if batch_data is not None:
return self._single_func_embeddings(
client=(
client
if client is not None and isinstance(client, HTTPHandler)
else None
),
timeout=timeout,
batch_data=batch_data,
credentials=credentials,
extra_headers=extra_headers,
endpoint_url=endpoint_url,
aws_region_name=aws_region_name,
model=model,
logging_obj=logging_obj,
)
elif data is None:
raise Exception("Unable to map request to provider")
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
headers = {"Content-Type": "application/json"}
if extra_headers is not None:
headers = {"Content-Type": "application/json", **extra_headers}
request = AWSRequest(
method="POST", url=endpoint_url, data=json.dumps(data), headers=headers
)
sigv4.add_auth(request)
if (
extra_headers is not None and "Authorization" in extra_headers
): # prevent sigv4 from overwriting the auth header
request.headers["Authorization"] = extra_headers["Authorization"]
prepped = request.prepare()
## ROUTING ##
return cohere_embedding(
model=model,
input=input,
model_response=model_response,
logging_obj=logging_obj,
optional_params=optional_params,
encoding=encoding,
data=data, # type: ignore
complete_api_base=prepped.url,
api_key=None,
aembedding=aembedding,
timeout=timeout,
client=client,
headers=prepped.headers,
)
# def _embedding_func_single(
# model: str,
# input: str,
# client: Any,
# optional_params=None,
# encoding=None,
# logging_obj=None,
# ):
# if isinstance(input, str) is False:
# raise BedrockError(
# message="Bedrock Embedding API input must be type str | List[str]",
# status_code=400,
# )
# # logic for parsing in - calling - parsing out model embedding calls
# ## FORMAT EMBEDDING INPUT ##
# provider = model.split(".")[0]
# inference_params = copy.deepcopy(optional_params)
# inference_params.pop(
# "user", None
# ) # make sure user is not passed in for bedrock call
# modelId = (
# optional_params.pop("model_id", None) or model
# ) # default to model if not passed
# if provider == "amazon":
# input = input.replace(os.linesep, " ")
# data = {"inputText": input, **inference_params}
# # data = json.dumps(data)
# elif provider == "cohere":
# inference_params["input_type"] = inference_params.get(
# "input_type", "search_document"
# ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3
# data = {"texts": [input], **inference_params} # type: ignore
# body = json.dumps(data).encode("utf-8") # type: ignore
# ## LOGGING
# request_str = f"""
# response = client.invoke_model(
# body={body},
# modelId={modelId},
# accept="*/*",
# contentType="application/json",
# )""" # type: ignore
# logging_obj.pre_call(
# input=input,
# api_key="", # boto3 is used for init.
# additional_args={
# "complete_input_dict": {"model": modelId, "texts": input},
# "request_str": request_str,
# },
# )
# try:
# response = client.invoke_model(
# body=body,
# modelId=modelId,
# accept="*/*",
# contentType="application/json",
# )
# response_body = json.loads(response.get("body").read())
# ## LOGGING
# logging_obj.post_call(
# input=input,
# api_key="",
# additional_args={"complete_input_dict": data},
# original_response=json.dumps(response_body),
# )
# if provider == "cohere":
# response = response_body.get("embeddings")
# # flatten list
# response = [item for sublist in response for item in sublist]
# return response
# elif provider == "amazon":
# return response_body.get("embedding")
# except Exception as e:
# raise BedrockError(
# message=f"Embedding Error with model {model}: {e}", status_code=500
# )
# def embedding(
# model: str,
# input: Union[list, str],
# model_response: litellm.EmbeddingResponse,
# api_key: Optional[str] = None,
# logging_obj=None,
# optional_params=None,
# encoding=None,
# ):
# ### BOTO3 INIT ###
# # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
# aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
# aws_access_key_id = optional_params.pop("aws_access_key_id", None)
# aws_region_name = optional_params.pop("aws_region_name", None)
# aws_role_name = optional_params.pop("aws_role_name", None)
# aws_session_name = optional_params.pop("aws_session_name", None)
# aws_bedrock_runtime_endpoint = optional_params.pop(
# "aws_bedrock_runtime_endpoint", None
# )
# aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# # use passed in BedrockRuntime.Client if provided, otherwise create a new one
# client = init_bedrock_client(
# aws_access_key_id=aws_access_key_id,
# aws_secret_access_key=aws_secret_access_key,
# aws_region_name=aws_region_name,
# aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
# aws_web_identity_token=aws_web_identity_token,
# aws_role_name=aws_role_name,
# aws_session_name=aws_session_name,
# )
# if isinstance(input, str):
# ## Embedding Call
# embeddings = [
# _embedding_func_single(
# model,
# input,
# optional_params=optional_params,
# client=client,
# logging_obj=logging_obj,
# )
# ]
# elif isinstance(input, list):
# ## Embedding Call - assuming this is a List[str]
# embeddings = [
# _embedding_func_single(
# model,
# i,
# optional_params=optional_params,
# client=client,
# logging_obj=logging_obj,
# )
# for i in input
# ] # [TODO]: make these parallel calls
# else:
# # enters this branch if input = int, ex. input=2
# raise BedrockError(
# message="Bedrock Embedding API input must be type str | List[str]",
# status_code=400,
# )
# ## Populate OpenAI compliant dictionary
# embedding_response = []
# for idx, embedding in enumerate(embeddings):
# embedding_response.append(
# {
# "object": "embedding",
# "index": idx,
# "embedding": embedding,
# }
# )
# model_response.object = "list"
# model_response.data = embedding_response
# model_response.model = model
# input_tokens = 0
# input_str = "".join(input)
# input_tokens += len(encoding.encode(input_str))
# usage = Usage(
# prompt_tokens=input_tokens,
# completion_tokens=0,
# total_tokens=input_tokens + 0,
# )
# model_response.usage = usage
# return model_response
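
End-to-end, the new handler is reached through litellm.embedding() via the main.py routing shown below, which also normalizes a bare string input to a list. A hedged sketch, assuming AWS credentials in the environment and access to the Titan V2 model:

    import litellm

    response = litellm.embedding(
        model="bedrock/amazon.titan-embed-text-v2:0",
        input=["good morning from litellm"],
        dimensions=512,  # assumed to flow through optional_params to AmazonTitanV2Config
        aws_region_name="us-west-2",
    )
    print(response.usage.prompt_tokens)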

View file

@@ -0,0 +1,127 @@
"""
Handles image gen calls to Bedrock's `/invoke` endpoint
"""
import copy
import json
import os
from typing import List
from openai.types.image import Image
import litellm
from litellm.types.utils import ImageResponse
from .common_utils import BedrockError, init_bedrock_client
def image_generation(
model: str,
prompt: str,
model_response: ImageResponse,
optional_params: dict,
timeout=None,
logging_obj=None,
aimg_generation=False,
):
"""
Bedrock Image Gen endpoint support
"""
### BOTO3 INIT ###
# pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
aws_access_key_id = optional_params.pop("aws_access_key_id", None)
aws_region_name = optional_params.pop("aws_region_name", None)
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
"aws_bedrock_runtime_endpoint", None
)
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
# use passed in BedrockRuntime.Client if provided, otherwise create a new one
client = init_bedrock_client(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
aws_region_name=aws_region_name,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_web_identity_token=aws_web_identity_token,
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
timeout=timeout,
)
### FORMAT IMAGE GENERATION INPUT ###
modelId = model
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
)
body = json.dumps(data).encode("utf-8")
## LOGGING
request_str = f"""
response = client.invoke_model(
body={body}, # type: ignore
modelId={modelId},
accept="application/json",
contentType="application/json",
)""" # type: ignore
logging_obj.pre_call(
input=prompt,
api_key="", # boto3 is used for init.
additional_args={
"complete_input_dict": {"model": modelId, "texts": prompt},
"request_str": request_str,
},
)
try:
response = client.invoke_model(
body=body,
modelId=modelId,
accept="application/json",
contentType="application/json",
)
response_body = json.loads(response.get("body").read())
## LOGGING
logging_obj.post_call(
input=prompt,
api_key="",
additional_args={"complete_input_dict": data},
original_response=json.dumps(response_body),
)
except Exception as e:
raise BedrockError(
message=f"Embedding Error with model {model}: {e}", status_code=500
)
### FORMAT RESPONSE TO OPENAI FORMAT ###
if response_body is None:
raise Exception("Error in response object format")
if model_response is None:
model_response = ImageResponse()
image_list: List[Image] = []
for artifact in response_body["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)
model_response.data = image_list
return model_response
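
The image-gen path keeps the boto3 invoke_model flow. A hedged caller-level sketch (model name and region are illustrative; assumes AWS credentials in the environment):

    import litellm

    img = litellm.image_generation(
        model="bedrock/stability.stable-diffusion-xl-v0",
        prompt="a watercolor fox",
        aws_region_name="us-west-2",
    )
    # img.data[0].b64_json holds the base64-encoded artifact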

View file

@@ -76,7 +76,7 @@ async def async_embedding(
data: dict,
input: list,
model_response: litellm.utils.EmbeddingResponse,
timeout: Union[float, httpx.Timeout],
timeout: Optional[Union[float, httpx.Timeout]],
logging_obj: LiteLLMLoggingObj,
optional_params: dict,
api_base: str,
@@ -98,16 +98,35 @@
)
## COMPLETION CALL
if client is None:
client = AsyncHTTPHandler(concurrent_limit=1)
client = AsyncHTTPHandler(concurrent_limit=1, timeout=timeout)
response = await client.post(api_base, headers=headers, data=json.dumps(data))
try:
response = await client.post(api_base, headers=headers, data=json.dumps(data))
except httpx.HTTPStatusError as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=e.response.text,
)
raise e
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=response,
original_response=response.text,
)
embeddings = response.json()["embeddings"]
@@ -130,27 +149,22 @@ def embedding(
optional_params: dict,
headers: dict,
encoding: Any,
data: Optional[dict] = None,
complete_api_base: Optional[str] = None,
api_key: Optional[str] = None,
aembedding: Optional[bool] = None,
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
):
headers = validate_environment(api_key, headers=headers)
embed_url = "https://api.cohere.ai/v1/embed"
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
model = model
data = {"model": model, "texts": input, **optional_params}
data = data or {"model": model, "texts": input, **optional_params}
if "3" in model and "input_type" not in data:
# cohere v3 embedding models require input_type, if no input_type is provided, default to "search_document"
data["input_type"] = "search_document"
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## ROUTING
if aembedding is True:
return async_embedding(
@@ -166,9 +180,18 @@ def embedding(
headers=headers,
encoding=encoding,
)
## LOGGING
logging_obj.pre_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = HTTPHandler(concurrent_limit=1)
response = client.post(embed_url, headers=headers, data=json.dumps(data))
## LOGGING
logging_obj.post_call(

View file

@@ -78,7 +78,6 @@ from .llms import (
ai21,
aleph_alpha,
baseten,
bedrock,
clarifai,
cloudflare,
maritalk,
@@ -96,7 +95,9 @@ from .llms.anthropic.chat import AnthropicChatCompletion
from .llms.anthropic.completion import AnthropicTextCompletion
from .llms.azure import AzureChatCompletion, _check_dynamic_azure_params
from .llms.azure_text import AzureTextCompletion
from .llms.bedrock_httpx import BedrockConverseLLM, BedrockLLM
from .llms.bedrock import image_generation as bedrock_image_generation # type: ignore
from .llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from .llms.bedrock.embed.embedding import BedrockEmbedding
from .llms.cohere import chat as cohere_chat
from .llms.cohere import completion as cohere_completion # type: ignore
from .llms.cohere import embed as cohere_embed
@ -176,6 +177,7 @@ codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
vertex_chat_completion = VertexLLM()
vertex_multimodal_embedding = VertexMultimodalEmbedding()
google_batch_embeddings = GoogleBatchEmbeddings()
@ -3151,6 +3153,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "watsonx"
or custom_llm_provider == "cohere"
or custom_llm_provider == "huggingface"
or custom_llm_provider == "bedrock"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
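A short sketch of the now-natively-awaited Bedrock embedding path (region and model are illustrative; AWS credentials are assumed in the environment):

import asyncio

import litellm

async def main():
    response = await litellm.aembedding(
        model="bedrock/amazon.titan-embed-text-v1",
        input=["good morning from litellm"],
        aws_region_name="us-west-2",
    )
    print(response.usage)

asyncio.run(main())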
@ -3519,13 +3522,24 @@ def embedding(
aembedding=aembedding,
)
elif custom_llm_provider == "bedrock":
response = bedrock.embedding(
if isinstance(input, str):
transformed_input = [input]
else:
transformed_input = input
response = bedrock_embedding.embeddings(
model=model,
input=input,
input=transformed_input,
encoding=encoding,
logging_obj=logging,
optional_params=optional_params,
model_response=EmbeddingResponse(),
client=client,
timeout=timeout,
aembedding=aembedding,
litellm_params=litellm_params,
api_base=api_base,
print_verbose=print_verbose,
extra_headers=extra_headers,
)
elif custom_llm_provider == "triton":
if api_base is None:
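The wrapping above is what makes bare-string input work for Bedrock embeddings. A minimal sketch (model and region illustrative):

import litellm

response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",
    input="good morning from litellm",  # str is wrapped into a one-element list
    aws_region_name="us-west-2",
)
assert isinstance(response.data[0]["embedding"], list)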
@ -4493,7 +4507,7 @@ def image_generation(
elif custom_llm_provider == "bedrock":
if model is None:
raise Exception("Model needs to be set for bedrock")
model_response = bedrock.image_generation(
model_response = bedrock_image_generation.image_generation(
model=model,
prompt=prompt,
timeout=timeout,

View file

@ -178,7 +178,7 @@ async def bedrock_proxy_route(
updated_url = base_url.copy_with(path=encoded_endpoint)
# Add or update query parameters
from litellm.llms.bedrock_httpx import BedrockConverseLLM
from litellm.llms.bedrock.chat import BedrockConverseLLM
credentials: Credentials = BedrockConverseLLM().get_credentials()
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
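For context, a sketch of how those credentials are then used to sign an outbound request with botocore (the endpoint URL and region below are illustrative):

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

from litellm.llms.bedrock.chat import BedrockConverseLLM

credentials = BedrockConverseLLM().get_credentials()
sigv4 = SigV4Auth(credentials, "bedrock", "us-west-2")
request = AWSRequest(
    method="POST",
    url="https://bedrock-runtime.us-west-2.amazonaws.com/model/amazon.titan-embed-text-v1/invoke",
    data=b"{}",  # placeholder body
    headers={"Content-Type": "application/json"},
)
sigv4.add_auth(request)  # adds the SigV4 Authorization header in place
print(request.headers["Authorization"][:48])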

View file

@ -25,7 +25,7 @@ from litellm import (
completion_cost,
embedding,
)
from litellm.llms.bedrock_httpx import BedrockLLM, ToolBlock
from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt

View file

@ -311,7 +311,17 @@ async def test_cohere_embedding3(custom_llm_provider):
# test_cohere_embedding3()
def test_bedrock_embedding_titan():
@pytest.mark.parametrize(
"model",
[
"bedrock/amazon.titan-embed-text-v1",
"bedrock/amazon.titan-embed-image-v1",
"bedrock/amazon.titan-embed-text-v2:0",
],
)
@pytest.mark.parametrize("sync_mode", [True])
@pytest.mark.asyncio
async def test_bedrock_embedding_titan(model, sync_mode):
try:
# this tests if we support str input for bedrock embedding
litellm.set_verbose = True
@ -320,16 +330,23 @@ def test_bedrock_embedding_titan():
current_time = str(time.time())
# DO NOT MAKE THE INPUT A LIST in this test
response = embedding(
model="bedrock/amazon.titan-embed-text-v1",
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
aws_region_name="us-west-2",
)
print(f"response:", response)
if sync_mode:
response = embedding(
model=model,
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
aws_region_name="us-west-2",
)
else:
response = await litellm.aembedding(
model=model,
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
aws_region_name="us-west-2",
)
print("response:", response)
assert isinstance(
response["data"][0]["embedding"], list
), "Expected response to be a list"
print(f"type of first embedding:", type(response["data"][0]["embedding"][0]))
print("type of first embedding:", type(response["data"][0]["embedding"][0]))
assert all(
isinstance(x, float) for x in response["data"][0]["embedding"]
), "Expected response to be a list of floats"
@ -339,13 +356,20 @@ def test_bedrock_embedding_titan():
start_time = time.time()
response = embedding(
model="bedrock/amazon.titan-embed-text-v1",
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
)
if sync_mode:
response = embedding(
model=model,
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
)
else:
response = await litellm.aembedding(
model=model,
input=f"good morning from litellm, attempting to embed data {current_time}", # input should always be a string in this test
)
print(response)
end_time = time.time()
print(response._hidden_params)
print(f"Embedding 2 response time: {end_time - start_time} seconds")
assert end_time - start_time < 0.1
@ -392,13 +416,13 @@ def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
with pytest.raises(
litellm.BadRequestError,
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: JSONArray, please reformat your input and try again."}',
):
litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])
with pytest.raises(
litellm.BadRequestError,
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
match='litellm.BadRequestError: BedrockException - {"message":"Malformed input request: expected type: String, found: Integer, please reformat your input and try again."}',
):
litellm.embedding(
model="amazon.titan-embed-text-v1",

View file

@ -1,21 +1,25 @@
import sys, os, uuid
import os
import sys
import time
import traceback
import uuid
from dotenv import load_dotenv
load_dotenv()
import os
from uuid import uuid4
import tempfile
from uuid import uuid4
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm import get_secret
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM
from litellm.proxy.secret_managers.aws_secret_manager import load_aws_secret_manager
@pytest.mark.skip(reason="AWS Suspended Account")
@ -63,9 +67,7 @@ def test_oidc_github():
reason="Cannot run without being in CircleCI Runner",
)
def test_oidc_circleci():
secret_val = get_secret(
"oidc/circleci/"
)
secret_val = get_secret("oidc/circleci/")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@ -103,9 +105,7 @@ def test_oidc_circle_v1_with_amazon():
# The purpose of this test is to get logs using the older v1 of the CircleCI OIDC token
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = (
"arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
)
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
aws_web_identity_token = "oidc/circleci/"
bllm = BedrockLLM()
@ -116,6 +116,7 @@ def test_oidc_circle_v1_with_amazon():
aws_session_name="assume-v1-session",
)
@pytest.mark.skipif(
os.environ.get("CIRCLE_OIDC_TOKEN") is None,
reason="Cannot run without being in CircleCI Runner",
@ -124,9 +125,7 @@ def test_oidc_circle_v1_with_amazon_fips():
# The purpose of this test is to validate that we can assume a role in a FIPS region
# TODO: This is using ai.moda's IAM role, we should use LiteLLM's IAM role eventually
aws_role_name = (
"arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
)
aws_role_name = "arn:aws:iam::335785316107:role/litellm-github-unit-tests-circleci-v1-assume-only"
aws_web_identity_token = "oidc/circleci/"
bllm = BedrockConverseLLM()
@ -143,9 +142,7 @@ def test_oidc_env_variable():
# Create a unique environment variable name
env_var_name = "OIDC_TEST_PATH_" + uuid4().hex
os.environ[env_var_name] = "secret-" + uuid4().hex
secret_val = get_secret(
f"oidc/env/{env_var_name}"
)
secret_val = get_secret(f"oidc/env/{env_var_name}")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@ -157,15 +154,13 @@ def test_oidc_env_variable():
def test_oidc_file():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
secret_value = "secret-" + uuid4().hex
temp_file.write(secret_value)
temp_file.flush()
temp_file_path = temp_file.name
secret_val = get_secret(
f"oidc/file/{temp_file_path}"
)
secret_val = get_secret(f"oidc/file/{temp_file_path}")
print(f"secret_val: {redact_oidc_signature(secret_val)}")
@ -174,7 +169,7 @@ def test_oidc_file():
def test_oidc_env_path():
# Create a temporary file
with tempfile.NamedTemporaryFile(mode='w+') as temp_file:
with tempfile.NamedTemporaryFile(mode="w+") as temp_file:
secret_value = "secret-" + uuid4().hex
temp_file.write(secret_value)
temp_file.flush()
@ -187,9 +182,7 @@ def test_oidc_env_path():
os.environ[env_var_name] = temp_file_path
# Test getting the secret using the environment variable
secret_val = get_secret(
f"oidc/env_path/{env_var_name}"
)
secret_val = get_secret(f"oidc/env_path/{env_var_name}")
print(f"secret_val: {redact_oidc_signature(secret_val)}")

View file

@ -208,3 +208,62 @@ class ServerSentEvent:
@override
def __repr__(self) -> str:
return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})"
class CohereEmbeddingRequest(TypedDict, total=False):
    texts: Required[List[str]]
    input_type: Required[
        Literal["search_document", "search_query", "classification", "clustering"]
    ]
    truncate: Literal["NONE", "START", "END"]
    embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary"]


class CohereEmbeddingResponse(TypedDict):
    embeddings: List[List[float]]
    id: str
    response_type: Literal["embedding_floats"]
    texts: List[str]


class AmazonTitanV2EmbeddingRequest(TypedDict):
    inputText: str
    dimensions: int
    normalize: bool


class AmazonTitanV2EmbeddingResponse(TypedDict):
    embedding: List[float]
    inputTextTokenCount: int


class AmazonTitanG1EmbeddingRequest(TypedDict):
    inputText: str


class AmazonTitanG1EmbeddingResponse(TypedDict):
    embedding: List[float]
    inputTextTokenCount: int


class AmazonTitanMultimodalEmbeddingConfig(TypedDict):
    outputEmbeddingLength: Literal[256, 384, 1024]


class AmazonTitanMultimodalEmbeddingRequest(TypedDict, total=False):
    inputText: str
    inputImage: str
    embeddingConfig: AmazonTitanMultimodalEmbeddingConfig


class AmazonTitanMultimodalEmbeddingResponse(TypedDict):
    embedding: List[float]
    inputTextTokenCount: int
    message: str  # Specifies any errors that occur during generation.


AmazonEmbeddingRequest = Union[
    AmazonTitanMultimodalEmbeddingRequest,
    AmazonTitanV2EmbeddingRequest,
    AmazonTitanG1EmbeddingRequest,
]
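Because these are TypedDicts, request bodies stay plain dicts that type-checkers can verify. A usage sketch (the import path is assumed to be the module this hunk edits):

from litellm.types.llms.bedrock import (
    AmazonTitanV2EmbeddingRequest,
    CohereEmbeddingRequest,
)

titan_v2: AmazonTitanV2EmbeddingRequest = {
    "inputText": "hello world",
    "dimensions": 512,
    "normalize": True,
}
cohere_req: CohereEmbeddingRequest = {
    "texts": ["hello world"],
    "input_type": "search_document",  # a Required key on this TypedDict
}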

View file

@ -699,7 +699,7 @@ class ModelResponse(OpenAIObject):
class Embedding(OpenAIObject):
embedding: Union[list, str] = []
index: int
object: str
object: Literal["embedding"]
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
@ -721,7 +721,7 @@ class EmbeddingResponse(OpenAIObject):
data: Optional[List] = None
"""The actual embedding value"""
object: str
object: Literal["list"]
"""The object type, which is always "embedding" """
usage: Optional[Usage] = None
@ -732,11 +732,10 @@ class EmbeddingResponse(OpenAIObject):
def __init__(
self,
model=None,
usage=None,
stream=False,
model: Optional[str] = None,
usage: Optional[Usage] = None,
response_ms=None,
data=None,
data: Optional[List] = None,
hidden_params=None,
_response_headers=None,
**params,
@ -760,7 +759,7 @@ class EmbeddingResponse(OpenAIObject):
self._response_headers = _response_headers
model = model
super().__init__(model=model, object=object, data=data, usage=usage)
super().__init__(model=model, object=object, data=data, usage=usage) # type: ignore
def __contains__(self, key):
# Define custom behavior for the 'in' operator
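A quick sketch of constructing the response with the tightened signature (field values illustrative):

from litellm.utils import EmbeddingResponse

resp = EmbeddingResponse(
    model="amazon.titan-embed-text-v1",
    data=[{"embedding": [0.1, 0.2], "index": 0, "object": "embedding"}],
)
assert resp.object == "list"  # `object` is now pinned to the literal "list"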

View file

@ -854,6 +854,7 @@ def client(original_function):
)
cached_result = litellm.cache.get_cache(*args, **kwargs)
if cached_result is not None:
print_verbose("Cache Hit!")
if "detail" in cached_result:
# implies an error occurred
pass
@ -935,7 +936,10 @@ def client(original_function):
args=(cached_result, start_time, end_time, cache_hit),
).start()
return cached_result
else:
print_verbose(
"Cache Miss! on key - {}".format(preset_cache_key)
)
# CHECK MAX TOKENS
if (
kwargs.get("max_tokens", None) is not None
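With a cache configured and verbose logging enabled, the new messages make hits and misses visible. A hedged sketch, assuming the default in-memory cache:

import litellm
from litellm.caching import Cache

litellm.set_verbose = True
litellm.cache = Cache()  # in-memory by default

for _ in range(2):  # the second call should print "Cache Hit!"
    litellm.embedding(
        model="cohere/embed-english-v3.0",
        input=["hello world"],
        caching=True,
    )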
@ -1005,7 +1009,7 @@ def client(original_function):
litellm.cache is not None
and str(original_function.__name__)
in litellm.cache.supported_call_types
) and (kwargs.get("cache", {}).get("no-store", False) != True):
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object, remove `handle_success` once it's deprecated
@ -1404,10 +1408,10 @@ def client(original_function):
# MODEL CALL
result = await original_function(*args, **kwargs)
end_time = datetime.datetime.now()
if "stream" in kwargs and kwargs["stream"] == True:
if "stream" in kwargs and kwargs["stream"] is True:
if (
"complete_response" in kwargs
and kwargs["complete_response"] == True
and kwargs["complete_response"] is True
):
chunks = []
for idx, chunk in enumerate(result):
@ -11734,3 +11738,13 @@ def is_cached_message(message: AllMessageValues) -> bool:
return True
return False
def is_base64_encoded(s: str) -> bool:
try:
# Try to decode the string
decoded_bytes = base64.b64decode(s, validate=True)
# Check if the original string can be re-encoded to the same string
return base64.b64encode(decoded_bytes).decode("utf-8") == s
except Exception:
return False
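A usage sketch of the round-trip check above (inputs illustrative):

import base64

valid = base64.b64encode(b"hello").decode("utf-8")  # "aGVsbG8="
assert is_base64_encoded(valid) is True
assert is_base64_encoded("not base64!!") is False  # non-alphabet chars fail the strict decode
assert is_base64_encoded("aGVsbG8") is False  # missing padding fails the round-trip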