fix(bedrock.py): add claude 3 support

Krrish Dholakia 2024-03-04 16:22:44 -08:00
parent 5e93bad4af
commit 0ac652a771
4 changed files with 90 additions and 16 deletions
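In effect, this lets callers reach the new Bedrock Claude 3 models through litellm's usual completion interface. A minimal usage sketch (not part of the diff; assumes AWS credentials for Bedrock are configured, and that the model ID matches the pricing entries added below):

import litellm

# "bedrock/" routes the call through the Bedrock provider; max_tokens is
# forwarded into the inference params via the new AmazonAnthropicClaude3Config.
response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello from Bedrock Claude 3"}],
    max_tokens=256,
)
print(response.choices[0].message.content)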

litellm/__init__.py

@@ -591,10 +591,11 @@ from .llms.bedrock import (
    AmazonTitanConfig,
    AmazonAI21Config,
    AmazonAnthropicConfig,
    AmazonAnthropicClaude3Config,
    AmazonCohereConfig,
    AmazonLlamaConfig,
    AmazonStabilityConfig,
    AmazonMistralConfig
    AmazonMistralConfig,
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError

litellm/llms/bedrock.py

@@ -70,6 +70,48 @@ class AmazonTitanConfig:
        }


class AmazonAnthropicClaude3Config:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic Claude 3 models:

    - `max_tokens` (integer) max tokens,
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    max_tokens: Optional[int] = litellm.max_tokens
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }


class AmazonAnthropicConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
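The new class follows the same pattern as the other Amazon*Config classes in this file: values passed to __init__ are written onto the class itself, and get_config() reads back every non-private, non-None class attribute. A small sketch of that behavior (the version string is just an example value):

import litellm

# __init__ stores params on the class object itself, not on the instance.
litellm.AmazonAnthropicClaude3Config(anthropic_version="bedrock-2023-05-31")

# get_config() returns the non-None class attributes as a plain dict,
# e.g. {"max_tokens": 256, "anthropic_version": "bedrock-2023-05-31"}.
print(litellm.AmazonAnthropicClaude3Config.get_config())

Because setattr targets self.__class__, values set this way persist globally until overwritten.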
@@ -330,7 +372,8 @@ class AmazonMistralConfig:
            )
            and v is not None
        }


class AmazonStabilityConfig:
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
@@ -542,7 +585,9 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
            model=model, messages=messages, custom_llm_provider="bedrock"
        )
    elif provider == "mistral":
        prompt = prompt_factory(model=model, messages=messages, custom_llm_provider="bedrock")
        prompt = prompt_factory(
            model=model, messages=messages, custom_llm_provider="bedrock"
        )
    else:
        prompt = ""
        for message in messages:
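For context, prompt_factory is what flattens OpenAI-style chat messages into a single prompt string for Bedrock's text-completion models; the reformatting above doesn't change its behavior. A rough sketch of the call the mistral branch makes (the module path and model ID are assumptions for illustration):

from litellm.llms.prompt_templates.factory import prompt_factory

messages = [{"role": "user", "content": "Say hi"}]
# Flatten chat messages into one Mistral-formatted prompt string.
prompt = prompt_factory(
    model="mistral.mistral-7b-instruct-v0:2",  # illustrative Bedrock model ID
    messages=messages,
    custom_llm_provider="bedrock",
)
print(prompt)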
@@ -619,14 +664,24 @@ def completion(
        inference_params = copy.deepcopy(optional_params)
        stream = inference_params.pop("stream", False)
        if provider == "anthropic":
            ## LOAD CONFIG
            config = litellm.AmazonAnthropicConfig.get_config()
            for k, v in config.items():
                if (
                    k not in inference_params
                ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                    inference_params[k] = v
            data = json.dumps({"prompt": prompt, **inference_params})
            if model.startswith("anthropic.claude-3"):  # claude-3 model IDs carry a version suffix, so prefix-match
                ## LOAD CONFIG
                config = litellm.AmazonAnthropicClaude3Config.get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                data = json.dumps({"prompt": prompt, **inference_params})
            else:
                ## LOAD CONFIG
                config = litellm.AmazonAnthropicConfig.get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "ai21":
            ## LOAD CONFIG
            config = litellm.AmazonAI21Config.get_config()
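The inline comment encodes the precedence rule used throughout completion(): params passed by the caller win, and config values only fill in missing keys. A standalone sketch of that merge:

# Defaults, e.g. from AmazonAnthropicClaude3Config.get_config().
config = {"max_tokens": 256, "anthropic_version": "bedrock-2023-05-31"}
# Params the caller passed to completion().
inference_params = {"max_tokens": 1024}

for k, v in config.items():
    if k not in inference_params:  # only fill keys the caller didn't set
        inference_params[k] = v

# The caller's max_tokens survives; the config's anthropic_version is filled in.
assert inference_params == {"max_tokens": 1024, "anthropic_version": "bedrock-2023-05-31"}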
@@ -646,9 +701,9 @@ def completion(
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                inference_params[k] = v
            if optional_params.get("stream", False) == True:
                inference_params[
                    "stream"
                ] = True  # cohere requires stream = True in inference params
                inference_params["stream"] = (
                    True  # cohere requires stream = True in inference params
                )
            data = json.dumps({"prompt": prompt, **inference_params})
        elif provider == "meta":
            ## LOAD CONFIG
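The change above is layout-only; behaviorally, cohere's streaming flag still travels inside the serialized request body. A small sketch of the resulting payload:

import json

inference_params = {"max_tokens": 256, "stream": True}
data = json.dumps({"prompt": "Hello", **inference_params})
# -> {"prompt": "Hello", "max_tokens": 256, "stream": true}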
@@ -674,7 +729,7 @@ def completion(
                    "textGenerationConfig": inference_params,
                }
            )
        elif provider == "mistral":
        elif provider == "mistral":
            ## LOAD CONFIG
            config = litellm.AmazonMistralConfig.get_config()
            for k, v in config.items():
@@ -1118,4 +1173,4 @@ def image_generation(
            image_dict = {"url": artifact["base64"]}
        model_response.data = image_dict
        return model_response
        return model_response

litellm/model_prices_and_context_window_backup.json

@@ -1266,6 +1266,15 @@
        "litellm_provider": "bedrock",
        "mode": "completion"
    },
    "anthropic.claude-3-sonnet-20240229-v1:0": {
        "max_tokens": 200000,
        "max_input_tokens": 200000,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000015,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "anthropic.claude-v1": {
        "max_tokens": 100000,
        "max_output_tokens": 8191,

model_prices_and_context_window.json

@@ -1266,6 +1266,15 @@
        "litellm_provider": "bedrock",
        "mode": "completion"
    },
    "anthropic.claude-3-sonnet-20240229-v1:0": {
        "max_tokens": 200000,
        "max_input_tokens": 200000,
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000015,
        "litellm_provider": "bedrock",
        "mode": "chat"
    },
    "anthropic.claude-v1": {
        "max_tokens": 100000,
        "max_output_tokens": 8191,