mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
extra headers
This commit is contained in:
parent
872ff6176d
commit
0c0780be83
4 changed files with 21 additions and 6 deletions
|
@ -495,6 +495,15 @@ class AmazonStabilityConfig:
|
|||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
def callback(request, **kwargs):
|
||||
"""Actual callback function that Boto3 will call."""
|
||||
for header_name, header_value in headers.items():
|
||||
request.headers.add_header(header_name, header_value)
|
||||
return callback
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
|
@ -504,12 +513,12 @@ def init_bedrock_client(
|
|||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
timeout: Optional[int] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [
|
||||
|
@ -618,6 +627,8 @@ def init_bedrock_client(
|
|||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
if extra_headers:
|
||||
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
|
||||
|
||||
return client
|
||||
|
||||
|
@ -677,6 +688,7 @@ def completion(
|
|||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
try:
|
||||
|
@ -704,6 +716,7 @@ def completion(
|
|||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
@ -604,13 +604,14 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
|
|||
def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
|
||||
invokes = ""
|
||||
for tool in tool_calls:
|
||||
if tool.type != "function":
|
||||
tool = dict(tool)
|
||||
if tool["type"] != "function":
|
||||
continue
|
||||
|
||||
tool_name = tool.function.name
|
||||
tool_function = dict(tool["function"])
|
||||
tool_name = tool_function["name"]
|
||||
parameters = "".join(
|
||||
f"<{param}>{val}</{param}>\n"
|
||||
for param, val in json.loads(tool.function.arguments).items()
|
||||
for param, val in json.loads(tool_function["arguments"]).items()
|
||||
)
|
||||
invokes += (
|
||||
"<invoke>\n"
|
||||
|
|
|
@ -1749,6 +1749,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
|
|||
"""
|
||||
if custom_llm_provider == "bedrock":
|
||||
if model.startswith("anthropic.claude-3"):
|
||||
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
|
||||
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"]
|
||||
elif model.startswith("anthropic"):
|
||||
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
|
||||
elif model.startswith("ai21"):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue