mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Bedrock Embeddings refactor + model support (#5462)
* refactor(bedrock): initial commit to refactor bedrock to a folder Improve code readability + maintainability * refactor: more refactor work * fix: fix imports * feat(bedrock/embeddings.py): support translating embedding into amazon embedding formats * fix: fix linting errors * test: skip test on end of life model * fix(cohere/embed.py): fix linting error * fix(cohere/embed.py): fix typing * fix(cohere/embed.py): fix post-call logging for cohere embedding call * test(test_embeddings.py): fix error message assertion in test
This commit is contained in:
parent
6fb82aaf75
commit
37f9705d6e
21 changed files with 1946 additions and 1659 deletions
773
litellm/llms/bedrock/common_utils.py
Normal file
773
litellm/llms/bedrock/common_utils.py
Normal file
|
@ -0,0 +1,773 @@
|
|||
"""
|
||||
Common utilities used across bedrock chat/embedding/image generation
|
||||
"""
|
||||
|
||||
import os
|
||||
import types
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret
|
||||
|
||||
|
||||
class BedrockError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class AmazonBedrockGlobalConfig:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
"""
|
||||
Mapping of common auth params across bedrock/vertex/azure/watsonx
|
||||
"""
|
||||
return {"region_name": "aws_region_name"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
mapped_params = self.get_mapped_special_auth_params()
|
||||
for param, value in non_default_params.items():
|
||||
if param in mapped_params:
|
||||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
|
||||
"""
|
||||
Source: https://www.aws-services.info/bedrock.html
|
||||
"""
|
||||
return [
|
||||
"eu-west-1",
|
||||
"eu-west-3",
|
||||
"eu-central-1",
|
||||
]
|
||||
|
||||
|
||||
class AmazonTitanConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
|
||||
|
||||
Supported Params for the Amazon Titan models:
|
||||
|
||||
- `maxTokenCount` (integer) max tokens,
|
||||
- `stopSequences` (string[]) list of stop sequence strings
|
||||
- `temperature` (float) temperature for model,
|
||||
- `topP` (int) top p for model
|
||||
"""
|
||||
|
||||
maxTokenCount: Optional[int] = None
|
||||
stopSequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonAnthropicClaude3Config:
|
||||
"""
|
||||
Reference:
|
||||
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
|
||||
|
||||
Supported Params for the Amazon / Anthropic Claude 3 models:
|
||||
|
||||
- `max_tokens` Required (integer) max tokens. Default is 4096
|
||||
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
|
||||
- `temperature` Optional (float) The amount of randomness injected into the response
|
||||
- `top_p` Optional (float) Use nucleus sampling.
|
||||
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
|
||||
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = 4096 # Opus, Sonnet, and Haiku default
|
||||
anthropic_version: Optional[str] = "bedrock-2023-05-31"
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"max_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"stream",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class AmazonAnthropicConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
|
||||
Supported Params for the Amazon / Anthropic models:
|
||||
|
||||
- `max_tokens_to_sample` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `top_k` (integer) top k,
|
||||
- `top_p` (integer) top p,
|
||||
- `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
|
||||
- `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
|
||||
"""
|
||||
|
||||
max_tokens_to_sample: Optional[int] = litellm.max_tokens
|
||||
stop_sequences: Optional[list] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
anthropic_version: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens_to_sample: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
anthropic_version: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(
|
||||
self,
|
||||
):
|
||||
return ["max_tokens", "temperature", "stop", "top_p", "stream"]
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens_to_sample"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "stream" and value == True:
|
||||
optional_params["stream"] = value
|
||||
return optional_params
|
||||
|
||||
|
||||
class AmazonCohereConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||
|
||||
Supported Params for the Amazon / Cohere models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `return_likelihood` (string) n/a
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
return_likelihood: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
return_likelihood: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonAI21Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
|
||||
|
||||
Supported Params for the Amazon / AI21 models:
|
||||
|
||||
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
|
||||
|
||||
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
|
||||
|
||||
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
|
||||
|
||||
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
|
||||
|
||||
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
|
||||
|
||||
- `presencePenalty` (object): Placeholder for presence penalty object.
|
||||
|
||||
- `countPenalty` (object): Placeholder for count penalty object.
|
||||
"""
|
||||
|
||||
maxTokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
stopSequences: Optional[list] = None
|
||||
frequencePenalty: Optional[dict] = None
|
||||
presencePenalty: Optional[dict] = None
|
||||
countPenalty: Optional[dict] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[float] = None,
|
||||
stopSequences: Optional[list] = None,
|
||||
frequencePenalty: Optional[dict] = None,
|
||||
presencePenalty: Optional[dict] = None,
|
||||
countPenalty: Optional[dict] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AnthropicConstants(Enum):
|
||||
HUMAN_PROMPT = "\n\nHuman: "
|
||||
AI_PROMPT = "\n\nAssistant: "
|
||||
|
||||
|
||||
class AmazonLlamaConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
|
||||
|
||||
Supported Params for the Amazon / Meta Llama models:
|
||||
|
||||
- `max_gen_len` (integer) max tokens,
|
||||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
"""
|
||||
|
||||
max_gen_len: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonMistralConfig:
|
||||
"""
|
||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||
Supported Params for the Amazon / Mistral models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
- `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
|
||||
- `top_k` (float) top k for model
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[float] = None
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
|
||||
def callback(request, **kwargs):
|
||||
"""Actual callback function that Boto3 will call."""
|
||||
for header_name, header_value in headers.items():
|
||||
request.headers.add_header(header_name, header_value)
|
||||
|
||||
return callback
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
aws_secret_access_key: Optional[str] = None,
|
||||
aws_region_name: Optional[str] = None,
|
||||
aws_bedrock_runtime_endpoint: Optional[str] = None,
|
||||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
aws_web_identity_token: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
]
|
||||
|
||||
# Iterate over parameters and update if needed
|
||||
for i, param in enumerate(params_to_check):
|
||||
if param and param.startswith("os.environ/"):
|
||||
params_to_check[i] = get_secret(param)
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
aws_region_name,
|
||||
aws_bedrock_runtime_endpoint,
|
||||
aws_session_name,
|
||||
aws_profile_name,
|
||||
aws_role_name,
|
||||
aws_web_identity_token,
|
||||
) = params_to_check
|
||||
|
||||
# SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
|
||||
ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
|
||||
|
||||
### SET REGION NAME
|
||||
if region_name:
|
||||
pass
|
||||
elif aws_region_name:
|
||||
region_name = aws_region_name
|
||||
elif litellm_aws_region_name:
|
||||
region_name = litellm_aws_region_name
|
||||
elif standard_aws_region_name:
|
||||
region_name = standard_aws_region_name
|
||||
else:
|
||||
raise BedrockError(
|
||||
message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
# check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if aws_bedrock_runtime_endpoint:
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint:
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"
|
||||
|
||||
import boto3
|
||||
|
||||
if isinstance(timeout, float):
|
||||
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
|
||||
elif isinstance(timeout, httpx.Timeout):
|
||||
config = boto3.session.Config(
|
||||
connect_timeout=timeout.connect, read_timeout=timeout.read
|
||||
)
|
||||
else:
|
||||
config = boto3.session.Config()
|
||||
|
||||
### CHECK STS ###
|
||||
if (
|
||||
aws_web_identity_token is not None
|
||||
and aws_role_name is not None
|
||||
and aws_session_name is not None
|
||||
):
|
||||
oidc_token = get_secret(aws_web_identity_token)
|
||||
|
||||
if oidc_token is None:
|
||||
raise BedrockError(
|
||||
message="OIDC token could not be retrieved from secret manager.",
|
||||
status_code=401,
|
||||
)
|
||||
|
||||
sts_client = boto3.client("sts")
|
||||
|
||||
# https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
|
||||
sts_response = sts_client.assume_role_with_web_identity(
|
||||
RoleArn=aws_role_name,
|
||||
RoleSessionName=aws_session_name,
|
||||
WebIdentityToken=oidc_token,
|
||||
DurationSeconds=3600,
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
elif aws_role_name is not None and aws_session_name is not None:
|
||||
# use sts if role name passed in
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
|
||||
sts_response = sts_client.assume_role(
|
||||
RoleArn=aws_role_name, RoleSessionName=aws_session_name
|
||||
)
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
|
||||
aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
|
||||
aws_session_token=sts_response["Credentials"]["SessionToken"],
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
elif aws_access_key_id is not None:
|
||||
# uses auth params passed to completion
|
||||
# aws_access_key_id is not None, assume user is trying to auth using litellm.completion
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
elif aws_profile_name is not None:
|
||||
# uses auth values from AWS profile usually stored in ~/.aws/credentials
|
||||
|
||||
client = boto3.Session(profile_name=aws_profile_name).client(
|
||||
service_name="bedrock-runtime",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
else:
|
||||
# aws_access_key_id is None, assume user is trying to auth using env variables
|
||||
# boto3 automatically reads env variables
|
||||
|
||||
client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
region_name=region_name,
|
||||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
verify=ssl_verify,
|
||||
)
|
||||
if extra_headers:
|
||||
client.meta.events.register(
|
||||
"before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
def get_runtime_endpoint(
|
||||
api_base: Optional[str],
|
||||
aws_bedrock_runtime_endpoint: Optional[str],
|
||||
aws_region_name: str,
|
||||
) -> str:
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if api_base is not None:
|
||||
endpoint_url = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
return endpoint_url
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(self, model_response):
|
||||
self.model_response = model_response
|
||||
self.is_done = False
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.is_done:
|
||||
raise StopIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self.is_done:
|
||||
raise StopAsyncIteration
|
||||
self.is_done = True
|
||||
return self.model_response
|
Loading…
Add table
Add a link
Reference in a new issue