Merge pull request #1442 from BerriAI/litellm_bedrock_provisioned_throughput

[Feat] Support Bedrock provisioned throughput LLMs
Ishaan Jaff 2024-01-14 10:51:28 +05:30 committed by GitHub
commit ad06b08a5e
3 changed files with 98 additions and 12 deletions


@@ -208,6 +208,32 @@ response = completion(
 )
 ```
+
+## Provisioned throughput models
+To use provisioned throughput Bedrock models, pass:
+- `model=bedrock/<base-model>`, e.g. `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models).
+- `model_id=provisioned-model-arn`, the ARN of your provisioned throughput model.
+
+**Completion**
+```python
+import litellm
+response = litellm.completion(
+    model="bedrock/anthropic.claude-instant-v1",
+    model_id="provisioned-model-arn",
+    messages=[{"content": "Hello, how are you?", "role": "user"}]
+)
+```
+
+**Embedding**
+```python
+import litellm
+response = litellm.embedding(
+    model="bedrock/amazon.titan-embed-text-v1",
+    model_id="provisioned-model-arn",
+    input=["hi"],
+)
+```
+
 ## Supported AWS Bedrock Models
 Here's an example of using a bedrock model with LiteLLM
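The provider change below also routes `invoke_model_with_response_stream` through the provisioned `modelId`, so streaming a provisioned-throughput model should only require `stream=True`. A minimal sketch (not taken from the PR's docs; streaming behavior here is an assumption based on the code change):

```python
# Sketch: streaming with a provisioned throughput model.
# Assumes stream=True works unchanged, since the diff below routes the
# streaming branch through the same modelId fallback.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-instant-v1",
    model_id="provisioned-model-arn",  # provisioned throughput ARN
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```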


@@ -436,6 +436,9 @@ def completion(
 )
 model = model
+modelId = (
+    optional_params.pop("model_id", None) or model
+)  # default to model if not passed
 provider = model.split(".")[0]
 prompt = convert_messages_to_prompt(
     model, messages, provider, custom_prompt_dict
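The `pop("model_id", None) or model` pattern both reads the override and strips `model_id` from the params forwarded to Bedrock. A standalone illustration (hypothetical values):

```python
# Minimal illustration of the fallback added above.
model = "anthropic.claude-instant-v1"
optional_params = {"model_id": "provisioned-model-arn", "max_tokens_to_sample": 256}

modelId = optional_params.pop("model_id", None) or model
assert modelId == "provisioned-model-arn"
assert "model_id" not in optional_params  # popped, so never sent in the request body

# Without a model_id, the base model name is used as the Bedrock modelId:
modelId = optional_params.pop("model_id", None) or model
assert modelId == "anthropic.claude-instant-v1"
```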
@@ -510,7 +513,7 @@ def completion(
 request_str = f"""
 response = client.invoke_model(
     body={data},
-    modelId={model},
+    modelId={modelId},
     accept=accept,
     contentType=contentType
 )
@@ -525,7 +528,7 @@ def completion(
 )
 response = client.invoke_model(
-    body=data, modelId=model, accept=accept, contentType=contentType
+    body=data, modelId=modelId, accept=accept, contentType=contentType
 )
 response = response.get("body").read()
@@ -535,7 +538,7 @@ def completion(
 request_str = f"""
 response = client.invoke_model_with_response_stream(
     body={data},
-    modelId={model},
+    modelId={modelId},
     accept=accept,
     contentType=contentType
 )
@@ -550,7 +553,7 @@ def completion(
 )
 response = client.invoke_model_with_response_stream(
-    body=data, modelId=model, accept=accept, contentType=contentType
+    body=data, modelId=modelId, accept=accept, contentType=contentType
 )
 response = response.get("body")
 return response
@@ -559,7 +562,7 @@ def completion(
 request_str = f"""
 response = client.invoke_model(
     body={data},
-    modelId={model},
+    modelId={modelId},
     accept=accept,
     contentType=contentType
 )
@@ -573,7 +576,7 @@ def completion(
     },
 )
 response = client.invoke_model(
-    body=data, modelId=model, accept=accept, contentType=contentType
+    body=data, modelId=modelId, accept=accept, contentType=contentType
 )
 except client.exceptions.ValidationException as e:
     if "The provided model identifier is invalid" in str(e):
@@ -664,6 +667,9 @@ def _embedding_func_single(
 inference_params.pop(
     "user", None
 )  # make sure user is not passed in for bedrock call
+modelId = (
+    optional_params.pop("model_id", None) or model
+)  # default to model if not passed
 if provider == "amazon":
     input = input.replace(os.linesep, " ")
     data = {"inputText": input, **inference_params}
@@ -678,7 +684,7 @@ def _embedding_func_single(
 request_str = f"""
 response = client.invoke_model(
     body={body},
-    modelId={model},
+    modelId={modelId},
     accept="*/*",
     contentType="application/json",
 )"""  # type: ignore
@@ -686,14 +692,14 @@ def _embedding_func_single(
     input=input,
     api_key="",  # boto3 is used for init.
     additional_args={
-        "complete_input_dict": {"model": model, "texts": input},
+        "complete_input_dict": {"model": modelId, "texts": input},
         "request_str": request_str,
     },
 )
 try:
     response = client.invoke_model(
         body=body,
-        modelId=model,
+        modelId=modelId,
         accept="*/*",
         contentType="application/json",
     )


@@ -63,7 +63,7 @@ def test_completion_bedrock_claude_completion_auth():
         pytest.fail(f"Error occurred: {e}")

-test_completion_bedrock_claude_completion_auth()
+# test_completion_bedrock_claude_completion_auth()


 def test_completion_bedrock_claude_2_1_completion_auth():
@@ -100,7 +100,7 @@ def test_completion_bedrock_claude_2_1_completion_auth():
         pytest.fail(f"Error occurred: {e}")

-test_completion_bedrock_claude_2_1_completion_auth()
+# test_completion_bedrock_claude_2_1_completion_auth()


 def test_completion_bedrock_claude_external_client_auth():
@@ -118,6 +118,8 @@ def test_completion_bedrock_claude_external_client_auth():
     try:
         import boto3

+        litellm.set_verbose = True
+
         bedrock = boto3.client(
             service_name="bedrock-runtime",
             region_name=aws_region_name,
@@ -145,4 +147,56 @@ def test_completion_bedrock_claude_external_client_auth():
         pytest.fail(f"Error occurred: {e}")

-test_completion_bedrock_claude_external_client_auth()
+# test_completion_bedrock_claude_external_client_auth()
+
+
+def test_provisioned_throughput():
+    try:
+        litellm.set_verbose = True
+        import botocore, json, io
+        import botocore.session
+        from botocore.stub import Stubber
+
+        bedrock_client = botocore.session.get_session().create_client(
+            "bedrock-runtime", region_name="us-east-1"
+        )
+
+        expected_params = {
+            "accept": "application/json",
+            "body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", '
+            '"max_tokens_to_sample": 256}',
+            "contentType": "application/json",
+            "modelId": "provisioned-model-arn",
+        }
+        response_from_bedrock = {
+            "body": io.StringIO(
+                json.dumps(
+                    {
+                        "completion": " Here is a short poem about the sky:",
+                        "stop_reason": "max_tokens",
+                        "stop": None,
+                    }
+                )
+            ),
+            "contentType": "contentType",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+
+        with Stubber(bedrock_client) as stubber:
+            stubber.add_response(
+                "invoke_model",
+                service_response=response_from_bedrock,
+                expected_params=expected_params,
+            )
+            response = litellm.completion(
+                model="bedrock/anthropic.claude-instant-v1",
+                model_id="provisioned-model-arn",
+                messages=[{"content": "Hello, how are you?", "role": "user"}],
+                aws_bedrock_client=bedrock_client,
+            )
+            print("response stubbed", response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_provisioned_throughput()
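The embedding path could be stubbed the same way. A sketch mirroring `test_provisioned_throughput` above (assumptions: the Titan response shape `{"embedding": [...], "inputTextTokenCount": ...}`, and that `litellm.embedding` accepts `aws_bedrock_client` the way `completion` does — neither is confirmed by this PR):

```python
# Sketch: stubbed embedding call, mirroring the new completion test.
import io, json
import botocore.session
from botocore.stub import Stubber
import litellm

client = botocore.session.get_session().create_client(
    "bedrock-runtime", region_name="us-east-1"
)
# Expected request matches the _embedding_func_single diff above:
# inputText body, */* accept, and the provisioned modelId.
expected_params = {
    "accept": "*/*",
    "body": '{"inputText": "hi"}',
    "contentType": "application/json",
    "modelId": "provisioned-model-arn",
}
stub_response = {
    # Assumed Titan embedding response shape.
    "body": io.StringIO(
        json.dumps({"embedding": [0.0] * 1536, "inputTextTokenCount": 1})
    ),
    "contentType": "application/json",
    "ResponseMetadata": {"HTTPStatusCode": 200},
}
with Stubber(client) as stubber:
    stubber.add_response("invoke_model", stub_response, expected_params)
    response = litellm.embedding(
        model="bedrock/amazon.titan-embed-text-v1",
        model_id="provisioned-model-arn",
        input=["hi"],
        aws_bedrock_client=client,  # assumed to be supported here
    )
    print(response)
```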