extra headers

Lucca Zenobio 2024-03-21 10:43:27 -03:00
parent 872ff6176d
commit 0c0780be83
4 changed files with 21 additions and 6 deletions

View file

@@ -495,6 +495,15 @@ class AmazonStabilityConfig:
 }
+def add_custom_header(headers):
+    """Closure to capture the headers and add them."""
+
+    def callback(request, **kwargs):
+        """Actual callback function that Boto3 will call."""
+        for header_name, header_value in headers.items():
+            request.headers.add_header(header_name, header_value)
+
+    return callback
 def init_bedrock_client(
     region_name=None,
     aws_access_key_id: Optional[str] = None,
@@ -504,12 +513,12 @@ def init_bedrock_client(
     aws_session_name: Optional[str] = None,
     aws_profile_name: Optional[str] = None,
     aws_role_name: Optional[str] = None,
+    extra_headers: Optional[dict] = None,
     timeout: Optional[int] = None,
 ):
     # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
     litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
     standard_aws_region_name = get_secret("AWS_REGION", None)
     ## CHECK IS 'os.environ/' passed in
     # Define the list of parameters to check
     params_to_check = [
@@ -618,6 +627,8 @@ def init_bedrock_client(
         endpoint_url=endpoint_url,
         config=config,
     )
+    if extra_headers:
+        client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
     return client
@@ -677,6 +688,7 @@ def completion(
     litellm_params=None,
     logger_fn=None,
     timeout=None,
+    extra_headers: Optional[dict] = None,
 ):
     exception_mapping_worked = False
     try:
@@ -704,6 +716,7 @@ def completion(
             aws_role_name=aws_role_name,
             aws_session_name=aws_session_name,
             aws_profile_name=aws_profile_name,
+            extra_headers=extra_headers,
             timeout=timeout,
         )
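
Note: the registration above uses boto3's generic event hooks rather than any Bedrock-specific API. A minimal standalone sketch of that mechanism, mirroring the add_custom_header closure from this hunk; the region, header name, and value are placeholders:

import boto3


def add_custom_header(headers):
    """Capture the headers in a closure and return a botocore event handler."""

    def callback(request, **kwargs):
        """Called with the outgoing AWSRequest before it is signed."""
        for header_name, header_value in headers.items():
            request.headers.add_header(header_name, header_value)

    return callback


client = boto3.client("bedrock-runtime", region_name="us-east-1")
# "before-sign.bedrock-runtime.*" fires for every bedrock-runtime operation,
# so the extra headers end up on each outgoing request.
client.meta.events.register(
    "before-sign.bedrock-runtime.*", add_custom_header({"x-custom-tag": "demo"})
)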

View file

@@ -604,13 +604,14 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
 def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
     invokes = ""
     for tool in tool_calls:
-        if tool.type != "function":
+        tool = dict(tool)
+        if tool["type"] != "function":
             continue
-        tool_name = tool.function.name
+        tool_function = dict(tool["function"])
+        tool_name = tool_function["name"]
         parameters = "".join(
             f"<{param}>{val}</{param}>\n"
-            for param, val in json.loads(tool.function.arguments).items()
+            for param, val in json.loads(tool_function["arguments"]).items()
         )
         invokes += (
             "<invoke>\n"

View file

@@ -1749,6 +1749,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
+                extra_headers=extra_headers,
                 timeout=timeout,
             )
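
With main.py forwarding the new argument into bedrock.completion, the feature is usable end to end. A rough usage sketch, assuming AWS credentials come from the environment; the model id, header name, and value are only examples:

import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello"}],
    # Forwarded to init_bedrock_client and attached to every bedrock-runtime request.
    extra_headers={"x-my-routing-header": "team-a"},
)
print(response.choices[0].message.content)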

View file

@@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
     """
     if custom_llm_provider == "bedrock":
         if model.startswith("anthropic.claude-3"):
-            return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
+            return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"]
         elif model.startswith("anthropic"):
             return litellm.AmazonAnthropicConfig().get_supported_openai_params()
         elif model.startswith("ai21"):
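
A quick way to confirm the utils.py change: get_supported_openai_params should now report "extra_headers" for Claude 3 models on Bedrock. The sketch below assumes the function is importable from the package root and uses an example model id:

import litellm

params = litellm.get_supported_openai_params(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    custom_llm_provider="bedrock",
)
print("extra_headers" in params)  # expected: True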