mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
* refactor(utils.py): migrate amazon titan config to base config
* refactor(utils.py): refactor bedrock meta invoke model translation to use base config
* refactor(utils.py): move bedrock ai21 to base config
* refactor(utils.py): move bedrock cohere to base config
* refactor(utils.py): move bedrock mistral to use base config
* refactor(utils.py): move all provider optional param translations to using a config
* docs(clientside_auth.md): clarify how to pass vertex region to litellm proxy
* fix(utils.py): handle scenario where custom llm provider is none / empty
* fix: fix get config
* test(test_otel_load_tests.py): widen perf margin
* fix(utils.py): fix get provider config check to handle custom llm's
* fix(utils.py): fix check
936 lines
30 KiB
Python
"""
|
|
Common utilities used across bedrock chat/embedding/image generation
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import types
|
|
from enum import Enum
|
|
from typing import Any, List, Optional, Union
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm.llms.base_llm.chat.transformation import (
|
|
BaseConfig,
|
|
BaseLLMException,
|
|
LiteLLMLoggingObj,
|
|
)
|
|
from litellm.secret_managers.main import get_secret
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.utils import ModelResponse
|
|
|
|
|
|
class BedrockError(BaseLLMException):
    pass


class AmazonBedrockGlobalConfig:
    def __init__(self):
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx
        """
        return {"region_name": "aws_region_name"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        mapped_params = self.get_mapped_special_auth_params()
        for param, value in non_default_params.items():
            if param in mapped_params:
                optional_params[mapped_params[param]] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "eu-west-1",
            "eu-west-3",
            "eu-central-1",
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        return [
            "us-east-2",
            "us-east-1",
            "us-west-2",
            "us-gov-west-1",
        ]


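# Example (illustrative sketch, not part of the original module): the global
# config maps a provider-agnostic `region_name` kwarg onto the bedrock-specific
# `aws_region_name` key.
#
#   config = AmazonBedrockGlobalConfig()
#   optional_params = config.map_special_auth_params(
#       non_default_params={"region_name": "eu-central-1"}, optional_params={}
#   )
#   # optional_params == {"aws_region_name": "eu-central-1"}

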
class AmazonInvokeMixin:
    """
    Base class for bedrock models going through invoke_handler.py
    """

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BedrockError(
            message=error_message,
            status_code=status_code,
            headers=headers,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        raise NotImplementedError(
            "transform_request not implemented for config. Done in invoke_handler.py"
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        raise NotImplementedError(
            "transform_response not implemented for config. Done in invoke_handler.py"
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
    ) -> dict:
        raise NotImplementedError(
            "validate_environment not implemented for config. Done in invoke_handler.py"
        )


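# Example (illustrative sketch): request/response transformation for these invoke
# models happens in invoke_handler.py, so this mixin mainly standardizes error
# handling. Any config below that inherits from AmazonInvokeMixin surfaces
# bedrock failures as BedrockError.
#
#   error = AmazonTitanConfig().get_error_class(
#       error_message="ThrottlingException", status_code=429, headers={}
#   )
#   # isinstance(error, BedrockError) -> True

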
class AmazonTitanConfig(AmazonInvokeMixin, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _map_and_modify_arg(
        self,
        supported_params: dict,
        provider: str,
        model: str,
        stop: Union[List[str], str],
    ):
        """
        Filter params to fit the required provider format; drop those that don't fit if the user sets `litellm.drop_params = True`.
        """
        filtered_stop = None
        if "stop" in supported_params and litellm.drop_params:
            if provider == "bedrock" and "amazon" in model:
                filtered_stop = []
                if isinstance(stop, list):
                    for s in stop:
                        if re.match(r"^(\|+|User:)$", s):
                            filtered_stop.append(s)
        if filtered_stop is not None:
            supported_params["stop"] = filtered_stop

        return supported_params

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens" or k == "max_completion_tokens":
                optional_params["maxTokenCount"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "stop":
                filtered_stop = self._map_and_modify_arg(
                    {"stop": v}, provider="bedrock", model=model, stop=v
                )
                optional_params["stopSequences"] = filtered_stop["stop"]
            if k == "top_p":
                optional_params["topP"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params


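# Example (illustrative sketch): OpenAI-style params are renamed into Titan's
# invoke params by map_openai_params above.
#
#   titan_config = AmazonTitanConfig()
#   params = titan_config.map_openai_params(
#       non_default_params={"max_tokens": 256, "temperature": 0.2, "top_p": 0.9},
#       optional_params={},
#       model="amazon.titan-text-express-v1",
#       drop_params=False,
#   )
#   # params == {"maxTokenCount": 256, "temperature": 0.2, "topP": 0.9}

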
class AmazonAnthropicClaude3Config:
    """
    Reference:
    https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    https://docs.anthropic.com/claude/docs/models-overview#model-comparison

    Supported Params for the Amazon / Anthropic Claude 3 models:

    - `max_tokens` Required (integer) max tokens. Default is 4096
    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    - `system` Optional (string) the system prompt; conversion from the OpenAI format is handled in factory.py
    - `temperature` Optional (float) The amount of randomness injected into the response
    - `top_p` Optional (float) Use nucleus sampling.
    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
    """

    max_tokens: Optional[int] = 4096  # Opus, Sonnet, and Haiku default
    anthropic_version: Optional[str] = "bedrock-2023-05-31"
    system: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop_sequences: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "max_tokens",
            "max_completion_tokens",
            "tools",
            "tool_choice",
            "stream",
            "stop",
            "temperature",
            "top_p",
            "extra_headers",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            if param == "tools":
                optional_params["tools"] = value
            if param == "stream":
                optional_params["stream"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
        return optional_params


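# Example (illustrative sketch): Claude 3 on bedrock keeps most OpenAI parameter
# names as-is; only `stop` is renamed to `stop_sequences`.
#
#   claude3_config = AmazonAnthropicClaude3Config()
#   params = claude3_config.map_openai_params(
#       non_default_params={"max_tokens": 1024, "stop": ["Human:"]},
#       optional_params={},
#   )
#   # params == {"max_tokens": 1024, "stop_sequences": ["Human:"]}

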
class AmazonAnthropicConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(
        self,
    ):
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "top_p",
            "stream",
        ]

    def map_openai_params(self, non_default_params: dict, optional_params: dict):
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
        return optional_params


class AmazonCohereConfig(AmazonInvokeMixin, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        return [
            "max_tokens",
            "temperature",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "stream":
                optional_params["stream"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "max_tokens":
                optional_params["max_tokens"] = v
        return optional_params


class AmazonAI21Config(AmazonInvokeMixin, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
    - `frequencyPenalty` (object): Placeholder for frequency penalty object.
    - `presencePenalty` (object): Placeholder for presence penalty object.
    - `countPenalty` (object): Placeholder for count penalty object.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["maxTokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["topP"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params


class AnthropicConstants(Enum):
    HUMAN_PROMPT = "\n\nHuman: "
    AI_PROMPT = "\n\nAssistant: "


class AmazonLlamaConfig(AmazonInvokeMixin, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        max_gen_len: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List:
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_gen_len"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params


class AmazonMistralConfig(AmazonInvokeMixin, BaseConfig):
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html

    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` (string[]) list of stop sequences; if the model generates one of them, it stops generating further output.
    - `top_k` (float) top k for model
    """

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params


def add_custom_header(headers):
    """Closure to capture the headers and add them."""

    def callback(request, **kwargs):
        """Actual callback function that Boto3 will call."""
        for header_name, header_value in headers.items():
            request.headers.add_header(header_name, header_value)

    return callback


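# Example (illustrative sketch): the returned closure is meant to be registered on
# boto3's event system so the extra headers are attached before each
# bedrock-runtime request is signed; init_bedrock_client below does exactly this.
# "x-my-header" is a made-up header name.
#
#   client.meta.events.register(
#       "before-sign.bedrock-runtime.*", add_custom_header({"x-my-header": "1"})
#   )

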
def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    # Define the list of parameters to check
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]

    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore
    # Assign updated values back to parameters
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
    ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)

    ### SET REGION NAME
    if region_name:
        pass
    elif aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if aws_bedrock_runtime_endpoint:
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint:
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

    import boto3

    if isinstance(timeout, float):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        oidc_token = get_secret(aws_web_identity_token)

        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )

        sts_client = boto3.client("sts")

        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )

        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=sts_response["Credentials"]["AccessKeyId"],
            aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"],
            aws_session_token=sts_response["Credentials"]["SessionToken"],
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_access_key_id is not None:
        # uses auth params passed to completion
        # aws_access_key_id is not None, assume user is trying to auth using litellm.completion
        client = boto3.client(
            service_name="bedrock-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    else:
        # aws_access_key_id is None, assume user is trying to auth using env variables
        # boto3 automatically reads env variables
        client = boto3.client(
            service_name="bedrock-runtime",
            region_name=region_name,
            endpoint_url=endpoint_url,
            config=config,
            verify=ssl_verify,
        )
    if extra_headers:
        client.meta.events.register(
            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
        )

    return client


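# Example usage (a minimal sketch; assumes local AWS credentials under the
# "default" profile and that the model ID / request body are valid for your
# account; neither is part of this module):
#
#   import json
#
#   client = init_bedrock_client(
#       aws_profile_name="default", region_name="us-east-1", timeout=30.0
#   )
#   response = client.invoke_model(
#       modelId="amazon.titan-text-express-v1",
#       body=json.dumps(
#           {"inputText": "Hello", "textGenerationConfig": {"maxTokenCount": 64}}
#       ),
#   )

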
class ModelResponseIterator:
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response


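# Example (illustrative sketch): wraps a single, already-materialized response so
# callers expecting a sync or async stream can iterate over exactly one chunk.
# `completed_response` stands in for any already-built ModelResponse.
#
#   iterator = ModelResponseIterator(model_response=completed_response)
#   for chunk in iterator:
#       ...  # yields `completed_response` once, then stops

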
def get_bedrock_tool_name(response_tool_name: str) -> str:
    """
    If litellm formatted the input tool name, we need to convert it back to the original name.

    Args:
        response_tool_name (str): The name of the tool as received from the response.

    Returns:
        str: The original name of the tool.
    """
    if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict:
        response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[
            response_tool_name
        ]
    return response_tool_name
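
# Example (illustrative sketch; the cached mapping is normally populated elsewhere
# in litellm when tool names are rewritten for bedrock, and "get_wx" /
# "get_weather_for_city" are made-up names):
#
#   litellm.bedrock_tool_name_mappings.cache_dict["get_wx"] = "get_weather_for_city"
#   get_bedrock_tool_name("get_wx")  # -> "get_weather_for_city"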