mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Litellm stable dev (#5711)
* feat(aws_base_llm.py): prevents recreating boto3 credentials during high traffic Leads to 100ms perf boost in local testing * fix(base_aws_llm.py): fix credential caching check to see if token is set * refactor(bedrock/chat): separate converse api and invoke api + isolate converse api transformation logic Make it easier to see how requests are transformed for /converse * fix: fix imports * fix(bedrock/embed): fix reordering of headers * fix(base_aws_llm.py): fix get credential logic * fix(converse_handler.py): fix ai21 streaming response
This commit is contained in:
parent
2efdd2a6a4
commit
da77706c26
14 changed files with 1073 additions and 1039 deletions
|
@ -1,5 +1,7 @@
|
|||
import hashlib
|
||||
import json
|
||||
from typing import List, Optional
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
|
@ -28,6 +30,14 @@ class BaseAWSLLM(BaseLLM):
|
|||
self.iam_cache = DualCache()
|
||||
super().__init__()
|
||||
|
||||
def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
|
||||
"""
|
||||
Generate a unique cache key based on the credential arguments.
|
||||
"""
|
||||
# Convert credential arguments to a JSON string and hash it to create a unique key
|
||||
credential_str = json.dumps(credential_args, sort_keys=True)
|
||||
return hashlib.sha256(credential_str.encode()).hexdigest()
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
aws_access_key_id: Optional[str] = None,
|
||||
|
@ -43,9 +53,22 @@ class BaseAWSLLM(BaseLLM):
|
|||
"""
|
||||
Return a boto3.Credentials object
|
||||
"""
|
||||
|
||||
import boto3
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
## CHECK IS 'os.environ/' passed in
|
||||
param_names = [
|
||||
"aws_access_key_id",
|
||||
"aws_secret_access_key",
|
||||
"aws_session_token",
|
||||
"aws_region_name",
|
||||
"aws_session_name",
|
||||
"aws_profile_name",
|
||||
"aws_role_name",
|
||||
"aws_web_identity_token",
|
||||
"aws_sts_endpoint",
|
||||
]
|
||||
params_to_check: List[Optional[str]] = [
|
||||
aws_access_key_id,
|
||||
aws_secret_access_key,
|
||||
|
@ -64,6 +87,11 @@ class BaseAWSLLM(BaseLLM):
|
|||
_v = get_secret(param)
|
||||
if _v is not None and isinstance(_v, str):
|
||||
params_to_check[i] = _v
|
||||
elif param is None: # check if uppercase value in env
|
||||
key = param_names[i]
|
||||
if key.upper() in os.environ:
|
||||
params_to_check[i] = os.getenv(key)
|
||||
|
||||
# Assign updated values back to parameters
|
||||
(
|
||||
aws_access_key_id,
|
||||
|
@ -77,6 +105,10 @@ class BaseAWSLLM(BaseLLM):
|
|||
aws_sts_endpoint,
|
||||
) = params_to_check
|
||||
|
||||
# create cache key for non-expiring auth flows
|
||||
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
|
||||
cache_key = self.get_cache_key(args)
|
||||
|
||||
verbose_logger.debug(
|
||||
"in get credentials\n"
|
||||
"aws_access_key_id=%s\n"
|
||||
|
@ -186,7 +218,6 @@ class BaseAWSLLM(BaseLLM):
|
|||
|
||||
# Extract the credentials from the response and convert to Session Credentials
|
||||
sts_credentials = sts_response["Credentials"]
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
credentials = Credentials(
|
||||
access_key=sts_credentials["AccessKeyId"],
|
||||
|
@ -211,12 +242,72 @@ class BaseAWSLLM(BaseLLM):
|
|||
secret_key=aws_secret_access_key,
|
||||
token=aws_session_token,
|
||||
)
|
||||
|
||||
return credentials
|
||||
else:
|
||||
elif (
|
||||
aws_access_key_id is not None
|
||||
and aws_secret_access_key is not None
|
||||
and aws_region_name is not None
|
||||
):
|
||||
# Check if credentials are already in cache. These credentials have no expiry time.
|
||||
cached_credentials: Optional[Credentials] = self.iam_cache.get_cache(
|
||||
cache_key
|
||||
)
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
region_name=aws_region_name,
|
||||
)
|
||||
|
||||
return session.get_credentials()
|
||||
credentials = session.get_credentials()
|
||||
|
||||
if (
|
||||
credentials.token is None
|
||||
): # don't cache if session token exists. The expiry time for that is not known.
|
||||
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60)
|
||||
|
||||
return credentials
|
||||
else:
|
||||
# check env var. Do not cache the response from this.
|
||||
session = boto3.Session()
|
||||
|
||||
credentials = session.get_credentials()
|
||||
|
||||
return credentials
|
||||
|
||||
def get_runtime_endpoint(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
aws_bedrock_runtime_endpoint: Optional[str],
|
||||
aws_region_name: str,
|
||||
) -> Tuple[str, str]:
|
||||
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
|
||||
if api_base is not None:
|
||||
endpoint_url = api_base
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = aws_bedrock_runtime_endpoint
|
||||
elif env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
|
||||
|
||||
# Determine proxy_endpoint_url
|
||||
if env_aws_bedrock_runtime_endpoint and isinstance(
|
||||
env_aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
proxy_endpoint_url = env_aws_bedrock_runtime_endpoint
|
||||
elif aws_bedrock_runtime_endpoint is not None and isinstance(
|
||||
aws_bedrock_runtime_endpoint, str
|
||||
):
|
||||
proxy_endpoint_url = aws_bedrock_runtime_endpoint
|
||||
else:
|
||||
proxy_endpoint_url = endpoint_url
|
||||
|
||||
return endpoint_url, proxy_endpoint_url
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue