Merge pull request #3299 from themrzmaster/main

Allowing extra headers for bedrock
This commit is contained in:
Krish Dholakia 2024-05-06 07:45:53 -07:00 committed by GitHub
commit 9f58583888
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 36 additions and 1 deletions

View file

@ -163,8 +163,10 @@ class AmazonAnthropicClaude3Config:
"stop",
"temperature",
"top_p",
"extra_headers"
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
@ -530,6 +532,15 @@ class AmazonStabilityConfig:
}
def add_custom_header(headers):
"""Closure to capture the headers and add them."""
def callback(request, **kwargs):
"""Actual callback function that Boto3 will call."""
for header_name, header_value in headers.items():
request.headers.add_header(header_name, header_value)
return callback
def init_bedrock_client(
region_name=None,
aws_access_key_id: Optional[str] = None,
@ -539,12 +550,12 @@ def init_bedrock_client(
aws_session_name: Optional[str] = None,
aws_profile_name: Optional[str] = None,
aws_role_name: Optional[str] = None,
extra_headers: Optional[dict] = None,
timeout: Optional[Union[float, httpx.Timeout]] = None,
):
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
standard_aws_region_name = get_secret("AWS_REGION", None)
## CHECK IS 'os.environ/' passed in
# Define the list of parameters to check
params_to_check = [
@ -660,6 +671,8 @@ def init_bedrock_client(
endpoint_url=endpoint_url,
config=config,
)
if extra_headers:
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
return client
@ -723,6 +736,7 @@ def completion(
litellm_params=None,
logger_fn=None,
timeout=None,
extra_headers: Optional[dict] = None,
):
exception_mapping_worked = False
_is_function_call = False
@ -752,6 +766,7 @@ def completion(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
aws_profile_name=aws_profile_name,
extra_headers=extra_headers,
timeout=timeout,
)

View file

@ -1868,6 +1868,7 @@ def completion(
logger_fn=logger_fn,
encoding=encoding,
logging_obj=logging,
extra_headers=extra_headers,
timeout=timeout,
)

View file

@ -207,6 +207,25 @@ def test_completion_bedrock_claude_sts_client_auth():
# test_completion_bedrock_claude_sts_client_auth()
def test_bedrock_extra_headers():
try:
litellm.set_verbose = True
response: ModelResponse = completion(
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
messages=messages,
max_tokens=10,
temperature=0.78,
extra_headers={"x-key": "x_key_value"}
)
# Add any assertions here to check the response
assert len(response.choices) > 0
assert len(response.choices[0].message.content) > 0
except RateLimitError:
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")
def test_bedrock_claude_3():
try:
litellm.set_verbose = True