forked from phoenix/litellm-mirror
Merge pull request #3299 from themrzmaster/main
Allowing extra headers for bedrock
This commit is contained in:
commit
9f58583888
3 changed files with 36 additions and 1 deletions
|
@ -163,8 +163,10 @@ class AmazonAnthropicClaude3Config:
|
|||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers"
|
||||
]
|
||||
|
||||
|
||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "max_tokens":
|
||||
|
@ -530,6 +532,15 @@ class AmazonStabilityConfig:
|
|||
}
|
||||
|
||||
|
||||
def add_custom_header(headers):
|
||||
"""Closure to capture the headers and add them."""
|
||||
def callback(request, **kwargs):
|
||||
"""Actual callback function that Boto3 will call."""
|
||||
for header_name, header_value in headers.items():
|
||||
request.headers.add_header(header_name, header_value)
|
||||
return callback
|
||||
|
||||
|
||||
def init_bedrock_client(
|
||||
region_name=None,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
|
@ -539,12 +550,12 @@ def init_bedrock_client(
|
|||
aws_session_name: Optional[str] = None,
|
||||
aws_profile_name: Optional[str] = None,
|
||||
aws_role_name: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
):
|
||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
# Define the list of parameters to check
|
||||
params_to_check = [
|
||||
|
@ -660,6 +671,8 @@ def init_bedrock_client(
|
|||
endpoint_url=endpoint_url,
|
||||
config=config,
|
||||
)
|
||||
if extra_headers:
|
||||
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
|
||||
|
||||
return client
|
||||
|
||||
|
@ -723,6 +736,7 @@ def completion(
|
|||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
timeout=None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
_is_function_call = False
|
||||
|
@ -752,6 +766,7 @@ def completion(
|
|||
aws_role_name=aws_role_name,
|
||||
aws_session_name=aws_session_name,
|
||||
aws_profile_name=aws_profile_name,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
@ -1868,6 +1868,7 @@ def completion(
|
|||
logger_fn=logger_fn,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
extra_headers=extra_headers,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
@ -207,6 +207,25 @@ def test_completion_bedrock_claude_sts_client_auth():
|
|||
# test_completion_bedrock_claude_sts_client_auth()
|
||||
|
||||
|
||||
def test_bedrock_extra_headers():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
response: ModelResponse = completion(
|
||||
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.78,
|
||||
extra_headers={"x-key": "x_key_value"}
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
assert len(response.choices) > 0
|
||||
assert len(response.choices[0].message.content) > 0
|
||||
except RateLimitError:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
def test_bedrock_claude_3():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue