Merge pull request #1442 from BerriAI/litellm_bedrock_provisioned_throughput
[Feat] Support Bedrock provisioned throughput LLMs
commit ad06b08a5e
3 changed files with 98 additions and 12 deletions
@@ -208,6 +208,32 @@ response = completion(
 )
 ```
+
+## Provisioned throughput models
+
+To use provisioned throughput Bedrock models, pass
+
+- `model=bedrock/<base-model>`, e.g. `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
+- `model_id=provisioned-model-arn`
+
+Completion
+```python
+import litellm
+response = litellm.completion(
+    model="bedrock/anthropic.claude-instant-v1",
+    model_id="provisioned-model-arn",
+    messages=[{"content": "Hello, how are you?", "role": "user"}]
+)
+```
+
+Embedding
+```python
+import litellm
+response = litellm.embedding(
+    model="bedrock/amazon.titan-embed-text-v1",
+    model_id="provisioned-model-arn",
+    input=["hi"],
+)
+```
+
+
 ## Supported AWS Bedrock Models
 Here's an example of using a bedrock model with LiteLLM
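The same PR also threads `model_id` through the streaming path (see the `invoke_model_with_response_stream` hunks below), so provisioned throughput should work with `stream=True` as well. A minimal sketch, not taken from the PR's docs; `provisioned-model-arn` is a placeholder:

```python
import litellm

# Sketch only: streaming against a provisioned throughput model.
# The ARN below is a placeholder, not a real provisioned model.
response = litellm.completion(
    model="bedrock/anthropic.claude-instant-v1",
    model_id="provisioned-model-arn",
    messages=[{"content": "Hello, how are you?", "role": "user"}],
    stream=True,
)
for chunk in response:
    print(chunk)
```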
@@ -436,6 +436,9 @@ def completion(
         )

     model = model
+    modelId = (
+        optional_params.pop("model_id", None) or model
+    )  # default to model if not passed
     provider = model.split(".")[0]
     prompt = convert_messages_to_prompt(
         model, messages, provider, custom_prompt_dict
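The fallback added above is plain pop-with-default. A standalone sketch with hypothetical values, showing both branches:

```python
# Hypothetical values illustrating the model_id fallback above.
model = "anthropic.claude-instant-v1"

optional_params = {"model_id": "provisioned-model-arn"}
modelId = optional_params.pop("model_id", None) or model
assert modelId == "provisioned-model-arn"  # explicit ARN wins when passed

optional_params = {}
modelId = optional_params.pop("model_id", None) or model
assert modelId == "anthropic.claude-instant-v1"  # defaults to the base model
```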
@@ -510,7 +513,7 @@ def completion(
             request_str = f"""
             response = client.invoke_model(
                 body={data},
-                modelId={model},
+                modelId={modelId},
                 accept=accept,
                 contentType=contentType
             )
@@ -525,7 +528,7 @@ def completion(
             )

             response = client.invoke_model(
-                body=data, modelId=model, accept=accept, contentType=contentType
+                body=data, modelId=modelId, accept=accept, contentType=contentType
             )

             response = response.get("body").read()
@@ -535,7 +538,7 @@ def completion(
             request_str = f"""
             response = client.invoke_model_with_response_stream(
                 body={data},
-                modelId={model},
+                modelId={modelId},
                 accept=accept,
                 contentType=contentType
             )
@@ -550,7 +553,7 @@ def completion(
             )

             response = client.invoke_model_with_response_stream(
-                body=data, modelId=model, accept=accept, contentType=contentType
+                body=data, modelId=modelId, accept=accept, contentType=contentType
             )
             response = response.get("body")
             return response
@@ -559,7 +562,7 @@ def completion(
             request_str = f"""
             response = client.invoke_model(
                 body={data},
-                modelId={model},
+                modelId={modelId},
                 accept=accept,
                 contentType=contentType
             )
@@ -573,7 +576,7 @@ def completion(
                 },
             )
             response = client.invoke_model(
-                body=data, modelId=model, accept=accept, contentType=contentType
+                body=data, modelId=modelId, accept=accept, contentType=contentType
             )
         except client.exceptions.ValidationException as e:
             if "The provided model identifier is invalid" in str(e):
@@ -664,6 +667,9 @@ def _embedding_func_single(
     inference_params.pop(
         "user", None
     )  # make sure user is not passed in for bedrock call
+    modelId = (
+        optional_params.pop("model_id", None) or model
+    )  # default to model if not passed
     if provider == "amazon":
         input = input.replace(os.linesep, " ")
         data = {"inputText": input, **inference_params}
@@ -678,7 +684,7 @@ def _embedding_func_single(
     request_str = f"""
     response = client.invoke_model(
         body={body},
-        modelId={model},
+        modelId={modelId},
         accept="*/*",
         contentType="application/json",
     )"""  # type: ignore
@@ -686,14 +692,14 @@ def _embedding_func_single(
         input=input,
         api_key="",  # boto3 is used for init.
         additional_args={
-            "complete_input_dict": {"model": model, "texts": input},
+            "complete_input_dict": {"model": modelId, "texts": input},
             "request_str": request_str,
         },
     )
     try:
         response = client.invoke_model(
             body=body,
-            modelId=model,
+            modelId=modelId,
             accept="*/*",
             contentType="application/json",
         )
@@ -63,7 +63,7 @@ def test_completion_bedrock_claude_completion_auth():
         pytest.fail(f"Error occurred: {e}")


-test_completion_bedrock_claude_completion_auth()
+# test_completion_bedrock_claude_completion_auth()


 def test_completion_bedrock_claude_2_1_completion_auth():
@@ -100,7 +100,7 @@ def test_completion_bedrock_claude_2_1_completion_auth():
         pytest.fail(f"Error occurred: {e}")


-test_completion_bedrock_claude_2_1_completion_auth()
+# test_completion_bedrock_claude_2_1_completion_auth()


 def test_completion_bedrock_claude_external_client_auth():
@@ -118,6 +118,8 @@ def test_completion_bedrock_claude_external_client_auth():
     try:
         import boto3

+        litellm.set_verbose = True
+
         bedrock = boto3.client(
             service_name="bedrock-runtime",
             region_name=aws_region_name,
@@ -145,4 +147,56 @@ def test_completion_bedrock_claude_external_client_auth():
         pytest.fail(f"Error occurred: {e}")


-test_completion_bedrock_claude_external_client_auth()
+# test_completion_bedrock_claude_external_client_auth()
+
+
+def test_provisioned_throughput():
+    try:
+        litellm.set_verbose = True
+        import botocore, json, io
+        import botocore.session
+        from botocore.stub import Stubber
+
+        bedrock_client = botocore.session.get_session().create_client(
+            "bedrock-runtime", region_name="us-east-1"
+        )
+
+        expected_params = {
+            "accept": "application/json",
+            "body": '{"prompt": "\\n\\nHuman: Hello, how are you?\\n\\nAssistant: ", '
+            '"max_tokens_to_sample": 256}',
+            "contentType": "application/json",
+            "modelId": "provisioned-model-arn",
+        }
+        response_from_bedrock = {
+            "body": io.StringIO(
+                json.dumps(
+                    {
+                        "completion": " Here is a short poem about the sky:",
+                        "stop_reason": "max_tokens",
+                        "stop": None,
+                    }
+                )
+            ),
+            "contentType": "contentType",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+
+        with Stubber(bedrock_client) as stubber:
+            stubber.add_response(
+                "invoke_model",
+                service_response=response_from_bedrock,
+                expected_params=expected_params,
+            )
+            response = litellm.completion(
+                model="bedrock/anthropic.claude-instant-v1",
+                model_id="provisioned-model-arn",
+                messages=[{"content": "Hello, how are you?", "role": "user"}],
+                aws_bedrock_client=bedrock_client,
+            )
+            print("response stubbed", response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# test_provisioned_throughput()
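The new test is built on botocore's `Stubber`, which queues a canned response and, via `expected_params`, fails the call if the outgoing request differs, which is how it verifies that `modelId="provisioned-model-arn"` reaches `invoke_model`. A minimal standalone sketch of that pattern, with hypothetical payloads and region:

```python
import io
import json

import botocore.session
from botocore.stub import Stubber

# Real botocore client; the Stubber intercepts its calls instead of AWS.
client = botocore.session.get_session().create_client(
    "bedrock-runtime", region_name="us-east-1"
)

with Stubber(client) as stubber:
    # Queue one canned invoke_model response; expected_params makes the
    # stub raise if the request differs (e.g. a wrong modelId).
    stubber.add_response(
        "invoke_model",
        service_response={
            "body": io.StringIO(json.dumps({"completion": "ok"})),
            "contentType": "application/json",
        },
        expected_params={
            "body": "{}",
            "modelId": "provisioned-model-arn",
            "accept": "application/json",
            "contentType": "application/json",
        },
    )
    out = client.invoke_model(
        body="{}",
        modelId="provisioned-model-arn",
        accept="application/json",
        contentType="application/json",
    )
    print(out["body"].read())
```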