mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
bedrock.py fixes
This commit is contained in:
parent
7740fe7dda
commit
ccb0b7fc78
1 changed files with 26 additions and 21 deletions
|
@ -15,7 +15,9 @@ class BedrockError(Exception):
|
|||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
def init_bedrock_client(boto3, region_name):
|
||||
def init_bedrock_client(region_name):
|
||||
import sys
|
||||
import boto3
|
||||
import subprocess
|
||||
try:
|
||||
client = boto3.client(
|
||||
|
@ -23,7 +25,7 @@ def init_bedrock_client(boto3, region_name):
|
|||
region_name=region_name,
|
||||
endpoint_url=f'https://bedrock.{region_name}.amazonaws.com'
|
||||
)
|
||||
except:
|
||||
except Exception as e:
|
||||
try:
|
||||
command1 = "python3 -m pip install https://github.com/BerriAI/litellm/raw/main/cookbook/bedrock_resources/boto3-1.28.21-py3-none-any.whl"
|
||||
subprocess.run(command1, shell=True, check=True)
|
||||
|
@ -60,21 +62,16 @@ def completion(
|
|||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
import sys
|
||||
if 'boto3' not in sys.modules:
|
||||
try:
|
||||
import boto3
|
||||
except:
|
||||
raise Exception("Please Install boto3 to use bedrock with LiteLLM, run 'pip install boto3'")
|
||||
|
||||
region_name = (
|
||||
get_secret("AWS_REGION_NAME") or
|
||||
"us-west-2" # default to us-west-2
|
||||
)
|
||||
|
||||
client = init_bedrock_client(boto3, region_name)
|
||||
client = init_bedrock_client(region_name)
|
||||
|
||||
model = model
|
||||
provider = model.split(".")[0]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
if "role" in message:
|
||||
|
@ -88,18 +85,23 @@ def completion(
|
|||
)
|
||||
else:
|
||||
prompt += f"{message['content']}"
|
||||
|
||||
|
||||
data = json.dumps({
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig":{
|
||||
"maxTokenCount":4096,
|
||||
"stopSequences":[],
|
||||
"temperature":0,
|
||||
"topP":0.9
|
||||
}
|
||||
|
||||
if provider == "ai21":
|
||||
data = json.dumps({
|
||||
"prompt": prompt,
|
||||
})
|
||||
|
||||
else: # amazon titan
|
||||
data = json.dumps({
|
||||
"inputText": prompt,
|
||||
"textGenerationConfig":{
|
||||
"maxTokenCount":4096,
|
||||
"stopSequences":[],
|
||||
"temperature":0,
|
||||
"topP":0.9
|
||||
}
|
||||
})
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -130,8 +132,11 @@ def completion(
|
|||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
## RESPONSE OBJECT
|
||||
outputText = response_body.get('results')[0].get('outputText')
|
||||
print(outputText)
|
||||
outputText = "default"
|
||||
if provider == "ai21":
|
||||
outputText = response_body.get('completions')[0].get('data').get('text')
|
||||
else: # amazon titan
|
||||
outputText = response_body.get('results')[0].get('outputText')
|
||||
if "error" in outputText:
|
||||
raise BedrockError(
|
||||
message=outputText["error"],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue