fix(bedrock.py): adding provider-specific configs

Krrish Dholakia 2023-10-05 23:49:20 -07:00
parent 5364604ccc
commit 06f279807b
7 changed files with 128 additions and 11 deletions


@@ -323,7 +323,7 @@ from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
 from .llms.vertex_ai import VertexAIConfig
 from .llms.sagemaker import SagemakerConfig
-from .llms.bedrock import AmazonConfig
+from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
 from .main import * # type: ignore
 from .integrations import *
 from .exceptions import (
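
This hunk (the relative imports suggest it is litellm/__init__.py) swaps the single AmazonConfig export for one config class per Bedrock provider family. A minimal sketch of how a caller might set per-family defaults after this change, using only parameters that appear elsewhere in this diff:

import litellm

# set provider-family defaults once; Bedrock completion() calls pick them up
litellm.AmazonAnthropicConfig(max_tokens_to_sample=256, temperature=0.5)
litellm.AmazonAI21Config(maxTokens=16)
litellm.AmazonCohereConfig(max_tokens=10)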


@@ -25,7 +25,7 @@ class AnthropicConfig():
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """
-    max_tokens_to_sample: Optional[int]=256 # anthropic requires a default
+    max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
     stop_sequences: Optional[list]=None
     temperature: Optional[int]=None
     top_p: Optional[int]=None
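
The Anthropic default now follows the library-wide litellm.max_tokens value instead of a hardcoded 256. Since the class attribute is evaluated once at import time, reassigning litellm.max_tokens later does not retroactively change it; overrides go through the constructor. A sketch, assuming litellm.max_tokens defaults to 256 and AnthropicConfig uses the same constructor/get_config() pattern as the Bedrock classes below:

import litellm

print(litellm.AnthropicConfig.get_config())
# {'max_tokens_to_sample': 256}   <- read from litellm.max_tokens at import time

litellm.AnthropicConfig(max_tokens_to_sample=512)  # constructor writes class attrs
print(litellm.AnthropicConfig.get_config())
# {'max_tokens_to_sample': 512}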


@@ -13,7 +13,7 @@ class BedrockError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs
 
-class AmazonConfig():
+class AmazonTitanConfig():
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
@@ -46,6 +46,123 @@ class AmazonConfig():
                 and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                 and v is not None}
 
+class AmazonAnthropicConfig():
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
+
+    Supported Params for the Amazon / Anthropic models:
+
+    - `max_tokens_to_sample` (integer) max tokens,
+    - `temperature` (float) model temperature,
+    - `top_k` (integer) top k,
+    - `top_p` (integer) top p,
+    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
+    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
+    """
+    max_tokens_to_sample: Optional[int]=litellm.max_tokens
+    stop_sequences: Optional[list]=None
+    temperature: Optional[float]=None
+    top_k: Optional[int]=None
+    top_p: Optional[int]=None
+    anthropic_version: Optional[str]=None
+
+    def __init__(self,
+                 max_tokens_to_sample: Optional[int]=None,
+                 stop_sequences: Optional[list]=None,
+                 temperature: Optional[float]=None,
+                 top_k: Optional[int]=None,
+                 top_p: Optional[int]=None,
+                 anthropic_version: Optional[str]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
+class AmazonCohereConfig():
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
+
+    Supported Params for the Amazon / Cohere models:
+
+    - `max_tokens` (integer) max tokens,
+    - `temperature` (float) model temperature,
+    - `return_likelihood` (string) n/a
+    """
+    max_tokens: Optional[int]=None
+    temperature: Optional[float]=None
+    return_likelihood: Optional[str]=None
+
+    def __init__(self,
+                 max_tokens: Optional[int]=None,
+                 temperature: Optional[float]=None,
+                 return_likelihood: Optional[str]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
+class AmazonAI21Config():
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
+
+    Supported Params for the Amazon / AI21 models:
+
+    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
+    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
+    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
+    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
+    - `frequencyPenalty` (object): Placeholder for frequency penalty object.
+    - `presencePenalty` (object): Placeholder for presence penalty object.
+    - `countPenalty` (object): Placeholder for count penalty object.
+    """
+    maxTokens: Optional[int]=None
+    temperature: Optional[float]=None
+    topP: Optional[float]=None
+    stopSequences: Optional[list]=None
+    frequencyPenalty: Optional[dict]=None
+    presencePenalty: Optional[dict]=None
+    countPenalty: Optional[dict]=None
+
+    def __init__(self,
+                 maxTokens: Optional[int]=None,
+                 temperature: Optional[float]=None,
+                 topP: Optional[float]=None,
+                 stopSequences: Optional[list]=None,
+                 frequencyPenalty: Optional[dict]=None,
+                 presencePenalty: Optional[dict]=None,
+                 countPenalty: Optional[dict]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman:"
     AI_PROMPT = "\n\nAssistant:"
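
All four classes share one pattern: __init__ writes every non-None argument onto the class itself via setattr(self.__class__, ...), and get_config() reads the class __dict__ back, skipping dunder names, functions/methods, and unset (None) fields. A standalone sketch of the round trip, assuming litellm.max_tokens is 256:

import litellm

litellm.AmazonAnthropicConfig(temperature=0.2, top_k=40)
print(litellm.AmazonAnthropicConfig.get_config())
# {'max_tokens_to_sample': 256, 'temperature': 0.2, 'top_k': 40}

Because the values land on the class rather than an instance, they act as process-wide defaults: every later Bedrock Anthropic call sees them until they are overwritten.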
@@ -137,7 +254,7 @@ def completion(
     print(f"bedrock provider: {provider}")
     if provider == "anthropic":
         ## LOAD CONFIG
-        config = litellm.AnthropicConfig.get_config()
+        config = litellm.AmazonAnthropicConfig.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
@@ -147,7 +264,7 @@ def completion(
         })
     elif provider == "ai21":
         ## LOAD CONFIG
-        config = litellm.AI21Config.get_config()
+        config = litellm.AmazonAI21Config.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
@@ -158,7 +275,7 @@ def completion(
         })
     elif provider == "cohere":
         ## LOAD CONFIG
-        config = litellm.CohereConfig.get_config()
+        config = litellm.AmazonCohereConfig.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
@@ -168,7 +285,7 @@ def completion(
         })
     elif provider == "amazon": # amazon titan
         ## LOAD CONFIG
-        config = litellm.AmazonConfig.get_config()
+        config = litellm.AmazonTitanConfig.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
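
Each branch above applies the same merge rule: config values are defaults, and anything passed directly to completion() wins. The loop in isolation, as plain Python with names mirroring the diff:

config = {"temperature": 0.5, "top_k": 3}   # from Amazon*Config.get_config()
inference_params = {"top_k": 7}             # explicit completion() kwargs

for k, v in config.items():
    if k not in inference_params:  # user-supplied value takes precedence
        inference_params[k] = v

assert inference_params == {"top_k": 7, "temperature": 0.5}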


@@ -50,7 +50,7 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_claude()
+test_completion_claude()
 
 # def test_completion_oobabooga():
 #     try:
@@ -561,7 +561,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_azure()
+# test_completion_azure()
 
 # new azure test for using litellm. vars,
 # use the following vars in this test and make an azure_api_call
@@ -779,7 +779,7 @@ def test_completion_bedrock_claude():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_bedrock_claude()
+test_completion_bedrock_claude()
 
 def test_completion_bedrock_claude_stream():
     print("calling claude")


@@ -368,7 +368,7 @@ def sagemaker_test_completion():
 
 def bedrock_test_completion():
-    litellm.CohereConfig(max_tokens=10)
+    litellm.AmazonCohereConfig(max_tokens=10)
     # litellm.set_verbose=True
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS
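
This test exercises the override path named in the comment above: a config-level default (max_tokens=10) that a call-site value then beats. A hedged sketch of the kind of call it makes; the model id below is an assumption for illustration, not taken from this diff:

import litellm

litellm.AmazonCohereConfig(max_tokens=10)        # provider-family default
response = litellm.completion(
    model="bedrock/cohere.command-text-v14",     # assumed Bedrock Cohere model id
    messages=[{"role": "user", "content": "Hey there"}],
    max_tokens=100,                              # per-call value overrides the config default
)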