Litellm stable dev (#5711)

* feat(aws_base_llm.py): prevents recreating boto3 credentials during high traffic

Leads to 100ms perf boost in local testing

* fix(base_aws_llm.py): fix credential caching check to see if token is set

* refactor(bedrock/chat): separate converse api and invoke api + isolate converse api transformation logic

Make it easier to see how requests are transformed for /converse

* fix: fix imports

* fix(bedrock/embed): fix reordering of headers

* fix(base_aws_llm.py): fix get credential logic

* fix(converse_handler.py): fix ai21 streaming response
This commit is contained in:
Krish Dholakia 2024-09-14 23:22:59 -07:00 committed by GitHub
parent 2efdd2a6a4
commit da77706c26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1073 additions and 1039 deletions

View file

@ -583,7 +583,7 @@ def init_bedrock_client(
# Iterate over parameters and update if needed
for i, param in enumerate(params_to_check):
if param and param.startswith("os.environ/"):
params_to_check[i] = get_secret(param)
params_to_check[i] = get_secret(param) # type: ignore
# Assign updated values back to parameters
(
aws_access_key_id,
@ -626,13 +626,13 @@ def init_bedrock_client(
import boto3
if isinstance(timeout, float):
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)
config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) # type: ignore
elif isinstance(timeout, httpx.Timeout):
config = boto3.session.Config(
config = boto3.session.Config( # type: ignore
connect_timeout=timeout.connect, read_timeout=timeout.read
)
else:
config = boto3.session.Config()
config = boto3.session.Config() # type: ignore
### CHECK STS ###
if (
@ -733,40 +733,6 @@ def init_bedrock_client(
return client
def get_runtime_endpoint(
api_base: Optional[str],
aws_bedrock_runtime_endpoint: Optional[str],
aws_region_name: str,
) -> Tuple[str, str]:
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
elif env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
endpoint_url = env_aws_bedrock_runtime_endpoint
else:
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
# Determine proxy_endpoint_url
if env_aws_bedrock_runtime_endpoint and isinstance(
env_aws_bedrock_runtime_endpoint, str
):
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
proxy_endpoint_url = aws_bedrock_runtime_endpoint
else:
proxy_endpoint_url = endpoint_url
return endpoint_url, proxy_endpoint_url
class ModelResponseIterator:
def __init__(self, model_response):
self.model_response = model_response
@ -791,3 +757,21 @@ class ModelResponseIterator:
raise StopAsyncIteration
self.is_done = True
return self.model_response
def get_bedrock_tool_name(response_tool_name: str) -> str:
"""
If litellm formatted the input tool name, we need to convert it back to the original name.
Args:
response_tool_name (str): The name of the tool as received from the response.
Returns:
str: The original name of the tool.
"""
if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict:
response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[
response_tool_name
]
return response_tool_name