forked from phoenix/litellm-mirror
Merge pull request #3299 from themrzmaster/main
Allowing extra headers for bedrock
This commit is contained in:
commit
9f58583888
3 changed files with 36 additions and 1 deletions
|
@ -163,8 +163,10 @@ class AmazonAnthropicClaude3Config:
|
||||||
"stop",
|
"stop",
|
||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
|
"extra_headers"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
def map_openai_params(self, non_default_params: dict, optional_params: dict):
|
||||||
for param, value in non_default_params.items():
|
for param, value in non_default_params.items():
|
||||||
if param == "max_tokens":
|
if param == "max_tokens":
|
||||||
|
@ -530,6 +532,15 @@ class AmazonStabilityConfig:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def add_custom_header(headers):
|
||||||
|
"""Closure to capture the headers and add them."""
|
||||||
|
def callback(request, **kwargs):
|
||||||
|
"""Actual callback function that Boto3 will call."""
|
||||||
|
for header_name, header_value in headers.items():
|
||||||
|
request.headers.add_header(header_name, header_value)
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
def init_bedrock_client(
|
def init_bedrock_client(
|
||||||
region_name=None,
|
region_name=None,
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
|
@ -539,12 +550,12 @@ def init_bedrock_client(
|
||||||
aws_session_name: Optional[str] = None,
|
aws_session_name: Optional[str] = None,
|
||||||
aws_profile_name: Optional[str] = None,
|
aws_profile_name: Optional[str] = None,
|
||||||
aws_role_name: Optional[str] = None,
|
aws_role_name: Optional[str] = None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||||
):
|
):
|
||||||
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
# check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
|
||||||
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
|
||||||
standard_aws_region_name = get_secret("AWS_REGION", None)
|
standard_aws_region_name = get_secret("AWS_REGION", None)
|
||||||
|
|
||||||
## CHECK IS 'os.environ/' passed in
|
## CHECK IS 'os.environ/' passed in
|
||||||
# Define the list of parameters to check
|
# Define the list of parameters to check
|
||||||
params_to_check = [
|
params_to_check = [
|
||||||
|
@ -660,6 +671,8 @@ def init_bedrock_client(
|
||||||
endpoint_url=endpoint_url,
|
endpoint_url=endpoint_url,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
if extra_headers:
|
||||||
|
client.meta.events.register('before-sign.bedrock-runtime.*', add_custom_header(extra_headers))
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
@ -723,6 +736,7 @@ def completion(
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
|
extra_headers: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
exception_mapping_worked = False
|
exception_mapping_worked = False
|
||||||
_is_function_call = False
|
_is_function_call = False
|
||||||
|
@ -752,6 +766,7 @@ def completion(
|
||||||
aws_role_name=aws_role_name,
|
aws_role_name=aws_role_name,
|
||||||
aws_session_name=aws_session_name,
|
aws_session_name=aws_session_name,
|
||||||
aws_profile_name=aws_profile_name,
|
aws_profile_name=aws_profile_name,
|
||||||
|
extra_headers=extra_headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1868,6 +1868,7 @@ def completion(
|
||||||
logger_fn=logger_fn,
|
logger_fn=logger_fn,
|
||||||
encoding=encoding,
|
encoding=encoding,
|
||||||
logging_obj=logging,
|
logging_obj=logging,
|
||||||
|
extra_headers=extra_headers,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -207,6 +207,25 @@ def test_completion_bedrock_claude_sts_client_auth():
|
||||||
# test_completion_bedrock_claude_sts_client_auth()
|
# test_completion_bedrock_claude_sts_client_auth()
|
||||||
|
|
||||||
|
|
||||||
|
def test_bedrock_extra_headers():
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
response: ModelResponse = completion(
|
||||||
|
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=10,
|
||||||
|
temperature=0.78,
|
||||||
|
extra_headers={"x-key": "x_key_value"}
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
assert len(response.choices) > 0
|
||||||
|
assert len(response.choices[0].message.content) > 0
|
||||||
|
except RateLimitError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_bedrock_claude_3():
|
def test_bedrock_claude_3():
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = True
|
litellm.set_verbose = True
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue