mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
# What is this?
|
|
## Initial implementation of calling bedrock via httpx client (allows for async calls).
|
|
## V0 - just covers cohere command-r support
|
|
|
|
import os, types
|
|
import json
|
|
from enum import Enum
|
|
import requests, copy # type: ignore
|
|
import time
|
|
from typing import Callable, Optional, List, Literal, Union
|
|
from litellm.utils import (
|
|
ModelResponse,
|
|
Usage,
|
|
map_finish_reason,
|
|
CustomStreamWrapper,
|
|
Message,
|
|
Choices,
|
|
get_secret,
|
|
)
|
|
import litellm
|
|
from .prompt_templates.factory import prompt_factory, custom_prompt
|
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
|
from .base import BaseLLM
|
|
import httpx # type: ignore
|
|
from .bedrock import BedrockError
|
|
|
|
|
|
class BedrockLLM(BaseLLM):
    r"""
    Bedrock handler intended to call the Bedrock runtime over plain HTTP
    rather than through the boto3 runtime client.

    Only credential resolution is implemented here; `completion` and
    `embedding` currently defer to the base class.

    Example call

    ```
    curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \
    --header 'Content-Type: application/json' \
    --header 'Accept: application/json' \
    --user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \
    --aws-sigv4 "aws:amz:us-east-1:bedrock" \
    --data-raw '{
    "prompt": "Hi",
    "temperature": 0,
    "p": 0.9,
    "max_tokens": 4096
    }'
    ```
    """

    def __init__(self) -> None:
        super().__init__()

    def get_credentials(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_role_name: Optional[str] = None,
    ):
        """
        Resolve AWS credentials and return them as a credentials object.

        Any argument given as an "os.environ/<NAME>" reference is first
        resolved through `get_secret`; if that lookup does not yield a string,
        the argument is left as it was passed in.

        Resolution order:
        1. `aws_role_name` and `aws_session_name` both set -> assume that role
           via STS (using the given access key / secret key, if any).
        2. `aws_profile_name` set -> load the named AWS profile.
        3. Otherwise -> build a session from the given access key, secret key
           and region (each may be None, in which case boto3 applies its own
           lookup).

        Returns:
            The object produced by `boto3.Session.get_credentials()` for the
            selected branch. Every branch returns the same kind of object, so
            callers can read `.access_key`, `.secret_key` and `.token`
            regardless of how the credentials were obtained.
            NOTE(review): boto3 may return None when it finds no credentials -
            confirm callers handle that.
        """
        import boto3

        def _resolve(value: Optional[str]) -> Optional[str]:
            # Swap an "os.environ/..." reference for the secret it names;
            # keep the original value when the lookup gives no string back.
            if value and value.startswith("os.environ/"):
                secret = get_secret(value)
                if isinstance(secret, str):
                    return secret
            return value

        ## Check if 'os.environ/' was passed in for any parameter
        aws_access_key_id = _resolve(aws_access_key_id)
        aws_secret_access_key = _resolve(aws_secret_access_key)
        aws_region_name = _resolve(aws_region_name)
        aws_session_name = _resolve(aws_session_name)
        aws_profile_name = _resolve(aws_profile_name)
        aws_role_name = _resolve(aws_role_name)

        ### CHECK STS ###
        if aws_role_name is not None and aws_session_name is not None:
            sts_client = boto3.client(
                "sts",
                aws_access_key_id=aws_access_key_id,  # [OPTIONAL]
                aws_secret_access_key=aws_secret_access_key,  # [OPTIONAL]
            )

            sts_response = sts_client.assume_role(
                RoleArn=aws_role_name, RoleSessionName=aws_session_name
            )

            # assume_role returns a plain dict; wrap the temporary keys in a
            # session so this branch returns the same credentials object type
            # as the branches below instead of the raw dict.
            # NOTE(review): nothing here refreshes these temporary
            # credentials once they expire.
            sts_credentials = sts_response["Credentials"]
            assumed_session = boto3.Session(
                aws_access_key_id=sts_credentials["AccessKeyId"],
                aws_secret_access_key=sts_credentials["SecretAccessKey"],
                aws_session_token=sts_credentials["SessionToken"],
            )
            return assumed_session.get_credentials()

        ### CHECK SESSION ###
        if aws_profile_name is not None:
            # uses auth values from AWS profile usually stored in ~/.aws/credentials
            profile_session = boto3.Session(profile_name=aws_profile_name)
            return profile_session.get_credentials()

        session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=aws_region_name,
        )
        return session.get_credentials()

    def completion(self, *args, **kwargs) -> Union[ModelResponse, CustomStreamWrapper]:
        """
        Placeholder completion entry point; forwards everything to the base
        class.

        The planned steps (not implemented yet) are:
        - get credentials
        - generate signature
        - make request
        """
        return super().completion(*args, **kwargs)

    def embedding(self, *args, **kwargs):
        """Forward the embedding call unchanged to the base class."""
        return super().embedding(*args, **kwargs)
|