Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
feat add support for aws_region_name
This commit is contained in:
parent b4ba12e22c
commit fa569aaf6f
2 changed files with 102 additions and 7 deletions
@@ -104,17 +104,11 @@ os.environ['AWS_SECRET_ACCESS_KEY'] = ""

# set os.environ['AWS_REGION_NAME'] = <your-region_name>

class SagemakerLLM(BaseAWSLLM):

    def _prepare_request(
    def _load_credentials(
        self,
        model: str,
        data: dict,
        optional_params: dict,
        extra_headers: Optional[dict] = None,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
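As a usage sketch for this change (not part of the diff), the region can come either from the AWS_REGION_NAME environment variable mentioned in the comment above or from an aws_region_name argument on the call itself; the endpoint name below is a placeholder.

```python
import os

import litellm

# Option 1: rely on the environment variable referenced in the comment above.
os.environ["AWS_REGION_NAME"] = "us-west-2"

# Option 2: pass the region explicitly per call (placeholder endpoint name).
response = litellm.completion(
    model="sagemaker/my-endpoint-name",  # hypothetical SageMaker endpoint
    messages=[{"role": "user", "content": "Hello"}],
    aws_region_name="us-west-2",  # forwarded into optional_params for SagemakerLLM
)
print(response.choices[0].message.content)
```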
@@ -163,6 +157,25 @@ class SagemakerLLM(BaseAWSLLM):
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    def _prepare_request(
        self,
        credentials,
        model: str,
        data: dict,
        optional_params: dict,
        aws_region_name: str,
        extra_headers: Optional[dict] = None,
    ):
        try:
            import boto3
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
            from botocore.credentials import Credentials
        except ImportError as e:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
        if optional_params.get("stream") is True:
            api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream"
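The signing logic above builds a raw SigV4-signed HTTP request instead of going through the boto3 runtime client. A minimal standalone sketch of that pattern, with placeholder endpoint, region, and payload values:

```python
import json

import boto3
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

# Placeholders for illustration only.
aws_region_name = "us-west-2"
model = "my-endpoint-name"
data = {"inputs": "Hello", "parameters": {"max_new_tokens": 64}}

# Resolve credentials from the default boto3 chain (env vars, profile, role, ...).
credentials = boto3.Session(region_name=aws_region_name).get_credentials()

sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
api_base = (
    f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com"
    f"/endpoints/{model}/invocations"
)

request = AWSRequest(
    method="POST",
    url=api_base,
    data=json.dumps(data).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
sigv4.add_auth(request)  # adds Authorization / X-Amz-Date headers in place
prepared_request = request.prepare()
# prepared_request.url and dict(prepared_request.headers) can now be handed to
# any HTTP client (litellm uses its own httpx-based handlers for the actual call).
```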
@@ -198,6 +211,7 @@ class SagemakerLLM(BaseAWSLLM):
    ):

        # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker
        credentials, aws_region_name = self._load_credentials(optional_params)
        inference_params = deepcopy(optional_params)

        ## Load Config
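completion() now resolves credentials and the region once, up front, via _load_credentials(optional_params). The helper below is only an illustration of the lookup order that implies (per-call parameter, then environment, then boto3's default region); it is not the method from this diff.

```python
import os
from typing import Optional, Tuple

import boto3


def resolve_region_and_credentials(optional_params: dict) -> Tuple[object, Optional[str]]:
    """Illustrative only: per-call param wins, then AWS_REGION_NAME, then boto3's default."""
    aws_region_name = (
        optional_params.pop("aws_region_name", None)
        or os.environ.get("AWS_REGION_NAME")
        or boto3.Session().region_name
    )
    credentials = boto3.Session(region_name=aws_region_name).get_credentials()
    return credentials, aws_region_name


credentials, aws_region_name = resolve_region_and_credentials(
    {"aws_region_name": "us-west-2", "temperature": 0.2}
)
```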
@@ -250,6 +264,8 @@ class SagemakerLLM(BaseAWSLLM):
                model=model,
                data=data,
                optional_params=optional_params,
                credentials=credentials,
                aws_region_name=aws_region_name,
            )
            if model_id is not None:
                # Add model_id as InferenceComponentName header
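The model_id branch above targets a specific inference component behind a shared endpoint. For comparison, the boto3 runtime client expresses the same idea through its InferenceComponentName parameter; all values below are placeholders.

```python
import json

import boto3

# Placeholders for illustration only.
client = boto3.client("sagemaker-runtime", region_name="us-west-2")
response = client.invoke_endpoint(
    EndpointName="my-endpoint-name",
    InferenceComponentName="my-inference-component",  # analogous to model_id above
    ContentType="application/json",
    Body=json.dumps({"inputs": "Hello"}),
)
print(response["Body"].read().decode("utf-8"))
```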
@@ -313,6 +329,8 @@ class SagemakerLLM(BaseAWSLLM):
                model=model,
                data=_data,
                optional_params=optional_params,
                credentials=credentials,
                aws_region_name=aws_region_name,
            )

        # Async completion
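The async path reuses the same signed request. A rough httpx-based sketch of what that POST looks like; litellm wraps this in its own async HTTP handler, so treat this as an approximation only.

```python
import httpx


async def invoke_async(prepared_request, timeout: float = 600.0) -> dict:
    # prepared_request is the signed AWSPreparedRequest from _prepare_request();
    # its URL and headers already carry the SigV4 signature.
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            prepared_request.url,
            headers=dict(prepared_request.headers),
            content=prepared_request.body,
        )
        response.raise_for_status()
        return response.json()


# Run with: asyncio.run(invoke_async(prepared_request)), given a prepared_request
# built as in the signing sketch above.
```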
@@ -357,6 +375,12 @@ class SagemakerLLM(BaseAWSLLM):
                json=_data,
                timeout=timeout,
            )

            if sync_response.status_code != 200:
                raise SagemakerError(
                    status_code=sync_response.status_code,
                    message=sync_response.text,
                )
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
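SagemakerError itself is defined elsewhere in litellm; the snippet below is only a minimal stand-in showing the shape this hunk relies on (an exception carrying status_code and message).

```python
class SagemakerError(Exception):
    """Minimal stand-in for litellm's SagemakerError: carries the HTTP status and body."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        self.message = message
        super().__init__(self.message)


# Mirrors the check in the hunk above (sync_response being an httpx.Response):
# if sync_response.status_code != 200:
#     raise SagemakerError(status_code=sync_response.status_code, message=sync_response.text)
```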
@@ -367,6 +391,7 @@ class SagemakerLLM(BaseAWSLLM):
            )
            raise e
        except Exception as e:
            verbose_logger.error("Sagemaker error %s", str(e))
            status_code = (
                getattr(e, "response", {})
                .get("ResponseMetadata", {})
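The status_code expression above is cut off by the hunk boundary. The pattern it starts reads the HTTP status out of a botocore-style error response, roughly as below; the trailing .get() and the 500 fallback are assumptions, not lines from this diff.

```python
def status_code_from_aws_error(e: Exception) -> int:
    # botocore ClientError instances expose e.response["ResponseMetadata"]["HTTPStatusCode"];
    # anything else falls through to a generic 500 (fallback chosen here for illustration).
    return (
        getattr(e, "response", {})
        .get("ResponseMetadata", {})
        .get("HTTPStatusCode", 500)
    )
```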
@@ -547,6 +572,11 @@ class SagemakerLLM(BaseAWSLLM):
                json=data,
                timeout=timeout,
            )

            if response.status_code != 200:
                raise SagemakerError(
                    status_code=response.status_code, message=response.text
                )
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(