extra headers

This commit is contained in:
Lucca Zenobio 2024-03-21 10:43:27 -03:00
parent 872ff6176d
commit 0c0780be83
4 changed files with 21 additions and 6 deletions

View file

@ -495,6 +495,15 @@ class AmazonStabilityConfig:
}
def add_custom_header(headers):
"""Closure to capture the headers and add them."""
def callback(request, **kwargs):
"""Actual callback function that Boto3 will call."""
for header_name, header_value in headers.items():
request.headers.add_header(header_name, header_value)
return callback
def init_bedrock_client(
region_name=None,
aws_access_key_id: Optional[str] = None,
@ -504,12 +513,12 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[int] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [
@ -618,6 +627,8 @@ def init_bedrock_client(
endpoint_url=endpoint_url,
config=config,
)
if extra_headers:
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
return client
@ -677,6 +688,7 @@ def completion(
litellm_params=None,
logger_fn=None,
timeout=None,
extra_headers: Optional[dict] = None,
):
exception_mapping_worked = False
try:
@ -704,6 +716,7 @@ def completion(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
extra_headers=extra_headers,
timeout=timeout,
)

View file

@ -604,13 +604,14 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
invokes = ""
for tool in tool_calls:
if tool.type != "function":
tool = dict(tool)
if tool["type"] != "function":
continue
tool_name = tool.function.name
tool_function = dict(tool["function"])
tool_name = tool_function["name"]
parameters = "".join(
f"<{param}>{val}</{param}>\n"
for param, val in json.loads(tool.function.arguments).items()
for param, val in json.loads(tool_function["arguments"]).items()
)
invokes += (
"<invoke>\n"

View file

@ -1749,6 +1749,7 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)

View file

@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"""
if custom_llm_provider == "bedrock":
if model.startswith("anthropic.claude-3"):
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"]
elif model.startswith("anthropic"):
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
elif model.startswith("ai21"):