# Forked from phoenix/litellm-mirror (773 lines, 26 KiB, Python)
"""
Common utilities used across bedrock chat/embedding/image generation
"""
|
|
|
|
import os
|
|
import types
|
|
from enum import Enum
|
|
from typing import List, Optional, Union
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm.secret_managers.main import get_secret
|
|
|
|
|
|
class BedrockError(Exception):
    """
    Exception type used for Bedrock failures.

    Attributes:
        status_code: the status code supplied by the raiser.
        message: the error text; also passed to the base ``Exception``.
        request: a synthetic ``httpx.Request`` (POST to the Bedrock console URL).
        response: an ``httpx.Response`` built from ``status_code`` and ``request``.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message

        # The request is not a real outgoing call; it is constructed here so
        # the exception carries request/response objects alongside the code.
        synthetic_request = httpx.Request(
            url="https://us-west-2.console.aws.amazon.com/bedrock",
            method="POST",
        )
        self.request = synthetic_request
        self.response = httpx.Response(
            request=synthetic_request,
            status_code=status_code,
        )

        # Hand the message to Exception so str(error) yields it.
        super().__init__(message)
|
|
|
|
|
|
class AmazonBedrockGlobalConfig:
    """Provider-wide Bedrock settings: auth-parameter aliases and EU regions."""

    def __init__(self):
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx

        Keys are the common names, values are the Bedrock-specific names.
        """
        return {"region_name": "aws_region_name"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy recognised common auth params into ``optional_params`` under
        their Bedrock names.

        ``optional_params`` is updated in place and also returned; entries of
        ``non_default_params`` with no mapping are ignored.
        """
        aliases = self.get_mapped_special_auth_params()
        for common_name, supplied in non_default_params.items():
            if common_name not in aliases:
                continue
            bedrock_name = aliases[common_name]
            optional_params[bedrock_name] = supplied
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return ["eu-west-1", "eu-west-3", "eu-central-1"]
|
|
|
|
|
|
class AmazonTitanConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (float) top p for model

    Values are stored on the class itself: constructing an instance with
    non-None arguments overwrites the class-level defaults, and
    ``get_config()`` reads them back from the class.
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    # top-p is a probability mass, so it is typed as float rather than int.
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
    ) -> None:
        """Store every non-None argument as a class-level attribute."""
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Written to the class, not the instance, so the value is
                # visible to get_config() and to every other instance.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that are currently set.

        Dunder names, functions/classmethods/staticmethods and attributes
        whose value is None are left out.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
|
class AmazonAnthropicClaude3Config:
    """
    Reference:
        https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
        https://docs.anthropic.com/claude/docs/models-overview#model-comparison

    Supported Params for the Amazon / Anthropic Claude 3 models:

    - `max_tokens` Required (integer) max tokens. Default is 4096
    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
    - `temperature` Optional (float) The amount of randomness injected into the response
    - `top_p` Optional (float) Use nucleus sampling.
    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.
    """

    max_tokens: Optional[int] = 4096  # Opus, Sonnet, and Haiku default
    anthropic_version: Optional[str] = "bedrock-2023-05-31"
    system: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        """Overwrite the class-level default for each argument that is not None."""
        supplied = {
            "max_tokens": max_tokens,
            "anthropic_version": anthropic_version,
        }
        owner = self.__class__
        for name, setting in supplied.items():
            if setting is not None:
                setattr(owner, name, setting)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable, non-dunder class attributes as a dict."""
        callables = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, setting in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(setting, callables):
                continue
            if setting is None:
                continue
            config[name] = setting
        return config

    def get_supported_openai_params(self):
        """List the OpenAI-style parameter names this config reports as supported."""
        return [
            "max_tokens",
            "tools",
            "tool_choice",
            "stream",
            "stop",
            "temperature",
            "top_p",
            "extra_headers",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy OpenAI-style params into ``optional_params`` under the names used
        for Claude 3 (only ``stop`` is renamed, to ``stop_sequences``).

        Params without an entry in the table below are skipped.
        ``optional_params`` is modified in place and returned.
        """
        renamed = {
            "max_tokens": "max_tokens",
            "tools": "tools",
            "stream": "stream",
            "stop": "stop_sequences",
            "temperature": "temperature",
            "top_p": "top_p",
        }
        for openai_name, setting in non_default_params.items():
            target_name = renamed.get(openai_name)
            if target_name is not None:
                optional_params[target_name] = setting
        return optional_params
|
|
|
|
|
|
class AmazonAnthropicConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (float) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.
    """

    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    # top-p is a probability mass, so it is typed as float rather than int.
    top_p: Optional[float] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        """Store every non-None argument as a class-level attribute."""
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Written to the class, not the instance, so get_config()
                # (a classmethod) sees the value.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that are currently set.

        Dunder names, functions/classmethods/staticmethods and attributes
        whose value is None are left out.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(
        self,
    ):
        """List the OpenAI-style parameter names this config reports as supported."""
        return ["max_tokens", "temperature", "stop", "top_p", "stream"]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy OpenAI-style params into ``optional_params`` under the names used
        for these models (``max_tokens`` -> ``max_tokens_to_sample``,
        ``stop`` -> ``stop_sequences``).

        ``stream`` is forwarded only when its value equals True; unknown
        params are skipped. ``optional_params`` is modified in place and
        returned.
        """
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            # `==` rather than `is`: values that compare equal to True
            # (e.g. 1) are forwarded too, exactly as before.
            if param == "stream" and value == True:  # noqa: E712
                optional_params["stream"] = value
        return optional_params
|
|
|
|
|
|
class AmazonCohereConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        """Overwrite the class-level default for each argument that is not None."""
        supplied = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "return_likelihood": return_likelihood,
        }
        owner = self.__class__
        for name, setting in supplied.items():
            if setting is not None:
                setattr(owner, name, setting)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable, non-dunder class attributes as a dict."""
        callables = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, setting in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(setting, callables):
                continue
            if setting is None:
                continue
            config[name] = setting
        return config
|
|
|
|
|
|
class AmazonAI21Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

    - `frequencyPenalty` (object): Placeholder for frequency penalty object.

    - `presencePenalty` (object): Placeholder for presence penalty object.

    - `countPenalty` (object): Placeholder for count penalty object.

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    # NOTE(review): spelled `frequencePenalty` here and in __init__, while the
    # docstring above lists `frequencyPenalty` - confirm which key the API
    # expects before renaming (the name is part of the public interface).
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        """Overwrite the class-level default for each argument that is not None."""
        supplied = {
            "maxTokens": maxTokens,
            "temperature": temperature,
            "topP": topP,
            "stopSequences": stopSequences,
            "frequencePenalty": frequencePenalty,
            "presencePenalty": presencePenalty,
            "countPenalty": countPenalty,
        }
        owner = self.__class__
        for name, setting in supplied.items():
            if setting is not None:
                setattr(owner, name, setting)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable, non-dunder class attributes as a dict."""
        callables = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, setting in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(setting, callables):
                continue
            if setting is None:
                continue
            config[name] = setting
        return config
|
|
|
|
|
|
class AnthropicConstants(Enum):
    """Turn-marker strings for Anthropic-style text prompts."""

    # Each marker starts with a blank line and ends with ": ".
    HUMAN_PROMPT = "\n\nHuman: "
    AI_PROMPT = "\n\nAssistant: "
|
|
|
|
|
|
class AmazonLlamaConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.

    The constructor still accepts the Titan-style names ``maxTokenCount`` and
    ``topP`` for backward compatibility, but stores them under the Llama
    names ``max_gen_len`` and ``top_p`` so that ``get_config()`` only ever
    produces the keys listed above.
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        max_gen_len: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> None:
        """
        Store every non-None setting as a class-level attribute.

        Args:
            maxTokenCount: legacy alias for ``max_gen_len``.
            temperature: sampling temperature.
            topP: legacy alias for ``top_p``.
            max_gen_len: maximum number of tokens to generate; takes
                precedence over ``maxTokenCount`` when both are given.
            top_p: nucleus-sampling probability; takes precedence over
                ``topP`` when both are given.
        """
        # Fall back to the legacy aliases only when the proper name is unset.
        if max_gen_len is None:
            max_gen_len = maxTokenCount
        if top_p is None:
            top_p = topP

        settings = {
            "max_gen_len": max_gen_len,
            "temperature": temperature,
            "top_p": top_p,
        }
        for key, value in settings.items():
            if value is not None:
                # Written to the class, not the instance, so get_config()
                # (a classmethod) sees the value.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that are currently set.

        Dunder names, functions/classmethods/staticmethods and attributes
        whose value is None are left out.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
|
class AmazonMistralConfig:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
    - `top_k` (float) top k for model

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # NOTE(review): top_k is typed as float here; top-k is usually an integer
    # count - confirm against the Bedrock Mistral reference before changing.
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,  # float, matching the class attribute
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        """Store every non-None argument as a class-level attribute."""
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Written to the class, not the instance, so get_config()
                # (a classmethod) sees the value.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """
        Return the class-level settings that are currently set.

        Dunder names, functions/classmethods/staticmethods and attributes
        whose value is None are left out.
        """
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
|
|
|
|
|
|
class AmazonStabilityConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0

    Supported Params for the Amazon / Stable Diffusion models:

    - `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)

    - `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)

    - `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.

    - `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divisible by 64.
        Engine-specific dimension validation:

        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536

    - `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divisible by 64.
        Engine-specific dimension validation:

        - SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
        - SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
        - SDXL v1.0: same as SDXL v0.9
        - SD v1.6: must be between 320x320 and 1536x1536

    Settings live on the class: the constructor overwrites class attributes
    and ``get_config()`` reads them back.
    """

    cfg_scale: Optional[int] = None
    # NOTE(review): `seed` and `steps` are described above with numeric
    # ranges, yet are typed float / List[str] - confirm the intended types.
    seed: Optional[float] = None
    steps: Optional[List[str]] = None
    width: Optional[int] = None
    height: Optional[int] = None

    def __init__(
        self,
        cfg_scale: Optional[int] = None,
        seed: Optional[float] = None,
        steps: Optional[List[str]] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
    ) -> None:
        """Overwrite the class-level default for each argument that is not None."""
        supplied = {
            "cfg_scale": cfg_scale,
            "seed": seed,
            "steps": steps,
            "width": width,
            "height": height,
        }
        owner = self.__class__
        for name, setting in supplied.items():
            if setting is not None:
                setattr(owner, name, setting)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable, non-dunder class attributes as a dict."""
        callables = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, setting in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(setting, callables):
                continue
            if setting is None:
                continue
            config[name] = setting
        return config
|
|
|
|
|
|
def add_custom_header(headers):
    """
    Build a Boto3 event callback that adds ``headers`` to a request.

    ``headers`` is captured by reference and iterated each time the callback
    runs; every name/value pair is passed to ``request.headers.add_header``.
    """

    def callback(request, **kwargs):
        """Actual callback function that Boto3 will call."""
        target = request.headers
        for name in headers:
            target.add_header(name, headers[name])

    return callback
|
|
|
|
|
|
def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Create a boto3 ``bedrock-runtime`` client.

    Region precedence: ``region_name``, then ``aws_region_name``, then the
    ``AWS_REGION_NAME`` secret, then the ``AWS_REGION`` secret.

    Endpoint precedence: ``aws_bedrock_runtime_endpoint``, then the
    ``AWS_BEDROCK_RUNTIME_ENDPOINT`` secret, then
    ``https://bedrock-runtime.<region>.amazonaws.com``.

    Credential precedence:
      1. web identity token + role + session name -> STS
         ``assume_role_with_web_identity``
      2. role + session name -> STS ``assume_role``
      3. explicit access key / secret key
      4. named AWS profile
      5. whatever boto3 discovers on its own

    Args:
        region_name: explicit region; wins over every other region source.
        aws_access_key_id, aws_secret_access_key, aws_region_name,
        aws_bedrock_runtime_endpoint, aws_session_name, aws_profile_name,
        aws_role_name, aws_web_identity_token: string settings; a value
            starting with ``"os.environ/"`` is resolved through ``get_secret``.
        extra_headers: headers added to every request before signing.
        timeout: a number (used for both connect and read timeouts) or an
            ``httpx.Timeout`` (its ``connect`` / ``read`` fields are used).

    Returns:
        The configured boto3 client.

    Raises:
        BedrockError: (status 401) when no region can be determined, or when
            the web identity token cannot be retrieved via ``get_secret``.
    """
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    def _resolve_env_reference(param):
        # Only strings can be "os.environ/<NAME>" references; anything else
        # (None, empty string, non-string) is passed through untouched.
        if isinstance(param, str) and param.startswith("os.environ/"):
            return get_secret(param)
        return param

    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = [
        _resolve_env_reference(param)
        for param in (
            aws_access_key_id,
            aws_secret_access_key,
            aws_region_name,
            aws_bedrock_runtime_endpoint,
            aws_session_name,
            aws_profile_name,
            aws_role_name,
            aws_web_identity_token,
        )
    ]

    # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
    # NOTE(review): a value read from the SSL_VERIFY env var is always a
    # string (e.g. "False"), which is truthy - confirm how callers expect
    # that to be interpreted.
    ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)

    ### SET REGION NAME
    region_name = (
        region_name
        or aws_region_name
        or litellm_aws_region_name
        or standard_aws_region_name
    )
    if not region_name:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    endpoint_url = (
        aws_bedrock_runtime_endpoint
        or env_aws_bedrock_runtime_endpoint
        or f"https://bedrock-runtime.{region_name}.amazonaws.com"
    )

    # Imported lazily so boto3 is only required when a client is built.
    import boto3

    if isinstance(timeout, (int, float)) and not isinstance(timeout, bool):
        # Accept ints as well as floats: `timeout=30` previously fell through
        # to the default config and was silently ignored.
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()

    # Keyword arguments shared by every bedrock-runtime client built below.
    client_kwargs = dict(
        service_name="bedrock-runtime",
        region_name=region_name,
        endpoint_url=endpoint_url,
        config=config,
        verify=ssl_verify,
    )

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )
        credentials = sts_response["Credentials"]

        client = boto3.client(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            **client_kwargs,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )
        credentials = sts_response["Credentials"]

        client = boto3.client(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            **client_kwargs,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            **client_kwargs,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(**client_kwargs)
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(**client_kwargs)

    if extra_headers:
        client.meta.events.register(
            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
        )

    return client
|
|
|
|
|
|
def get_runtime_endpoint(
    api_base: Optional[str],
    aws_bedrock_runtime_endpoint: Optional[str],
    aws_region_name: str,
) -> str:
    """
    Pick the Bedrock runtime endpoint URL.

    Precedence:
      1. ``api_base`` when it is not None
      2. ``aws_bedrock_runtime_endpoint`` when it is a string
      3. the ``AWS_BEDROCK_RUNTIME_ENDPOINT`` secret when it is a non-empty string
      4. ``https://bedrock-runtime.<aws_region_name>.amazonaws.com``
    """
    # Looked up first, before any of the explicit arguments are examined.
    secret_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")

    if api_base is not None:
        return api_base

    # isinstance() already rules out None.
    if isinstance(aws_bedrock_runtime_endpoint, str):
        return aws_bedrock_runtime_endpoint

    if secret_endpoint and isinstance(secret_endpoint, str):
        return secret_endpoint

    return f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
|
|
|
|
|
class ModelResponseIterator:
    """
    Single-shot iterator around one response object.

    Yields ``model_response`` exactly once, then is exhausted. The sync and
    async protocols share the same ``is_done`` flag, so consuming the value
    through one also exhausts the other.
    """

    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if not self.is_done:
            self.is_done = True
            return self.model_response
        raise StopIteration

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.is_done:
            self.is_done = True
            return self.model_response
        raise StopAsyncIteration
|