mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 19:24:27 +00:00
extra headers
This commit is contained in:
parent
872ff6176d
commit
0c0780be83
4 changed files with 21 additions and 6 deletions
|
@ -495,6 +495,15 @@ class AmazonStabilityConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_custom_header(headers):
|
||||||
|
"""Closure to capture the headers and add them."""
|
||||||
|
def callback(request, **kwargs):
|
||||||
|
"""Actual callback function that Boto3 will call."""
|
||||||
|
for header_name, header_value in headers.items():
|
||||||
|
request.headers.add_header(header_name, header_value)
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
def init_bedrock_client(
|
def init_bedrock_client(
|
||||||
region_name=None,
|
region_name=None,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
@ -504,12 +513,12 @@ def init_bedrock_client(
|
||||||
aws_session_name: Optional[str] = None,
|
aws_session_name: Optional[str] = None,
|
||||||
aws_profile_name: Optional[str] = None,
|
aws_profile_name: Optional[str] = None,
|
||||||
aws_role_name: Optional[str] = None,
|
aws_role_name: Optional[str] = None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
|
||||||
## CHECK IS 'os.environ/' passed in
|
## CHECK IS 'os.environ/' passed in
|
||||||
# Define the list of parameters to check
|
# Define the list of parameters to check
|
||||||
params_to_check = [
|
params_to_check = [
|
||||||
|
@ -618,6 +627,8 @@ def init_bedrock_client(
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
if extra_headers:
|
||||||
|
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@ -677,6 +688,7 @@ def completion(
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
try:
|
try:
|
||||||
|
@ -704,6 +716,7 @@ def completion(
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
aws_profile_name=aws_profile_name,
|
aws_profile_name=aws_profile_name,
|
||||||
|
extra_headers=extra_headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -604,13 +604,14 @@ def convert_to_anthropic_tool_result(message: dict) -> str:
|
||||||
def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
|
def convert_to_anthropic_tool_invoke(tool_calls: list) -> str:
|
||||||
invokes = ""
|
invokes = ""
|
||||||
for tool in tool_calls:
|
for tool in tool_calls:
|
||||||
if tool.type != "function":
|
tool = dict(tool)
|
||||||
|
if tool["type"] != "function":
|
||||||
continue
|
continue
|
||||||
|
tool_function = dict(tool["function"])
|
||||||
tool_name = tool.function.name
|
tool_name = tool_function["name"]
|
||||||
parameters = "".join(
|
parameters = "".join(
|
||||||
f"<{param}>{val}</{param}>\n"
|
f"<{param}>{val}</{param}>\n"
|
||||||
for param, val in json.loads(tool.function.arguments).items()
|
for param, val in json.loads(tool_function["arguments"]).items()
|
||||||
)
|
)
|
||||||
invokes += (
|
invokes += (
|
||||||
"<invoke>\n"
|
"<invoke>\n"
|
||||||
|
|
|
@ -1749,6 +1749,7 @@ def completion(
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
|
extra_headers=extra_headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5146,7 +5146,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
|
||||||
"""
|
"""
|
||||||
if custom_llm_provider == "bedrock":
|
if custom_llm_provider == "bedrock":
|
||||||
if model.startswith("anthropic.claude-3"):
|
if model.startswith("anthropic.claude-3"):
|
||||||
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params()
|
return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() + ["extra_headers"]
|
||||||
elif model.startswith("anthropic"):
|
elif model.startswith("anthropic"):
|
||||||
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
|
return litellm.AmazonAnthropicConfig().get_supported_openai_params()
|
||||||
elif model.startswith("ai21"):
|
elif model.startswith("ai21"):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue