Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-24 18:24:20 +00:00
fix(bedrock.py): adding provider-specific configs
This commit is contained in:
parent 5364604ccc
commit 06f279807b
7 changed files with 128 additions and 11 deletions
litellm/__init__.py
@@ -323,7 +323,7 @@ from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
 from .llms.vertex_ai import VertexAIConfig
 from .llms.sagemaker import SagemakerConfig
-from .llms.bedrock import AmazonConfig
+from .llms.bedrock import AmazonTitanConfig, AmazonAI21Config, AmazonAnthropicConfig, AmazonCohereConfig
 from .main import * # type: ignore
 from .integrations import *
 from .exceptions import (
Binary file not shown.
Binary file not shown.
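After the litellm/__init__.py change above, each Bedrock provider has its own exported config class instead of the single AmazonConfig. A quick smoke check of the new exports (a hypothetical REPL snippet, assuming the package is importable):

import litellm

# each class should exist; get_config() returns whatever defaults are non-None
for cfg in (litellm.AmazonTitanConfig, litellm.AmazonAnthropicConfig,
            litellm.AmazonAI21Config, litellm.AmazonCohereConfig):
    print(cfg.__name__, cfg.get_config())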
litellm/llms/anthropic.py
@@ -25,7 +25,7 @@ class AnthropicConfig():
 
     to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
     """
-    max_tokens_to_sample: Optional[int]=256 # anthropic requires a default
+    max_tokens_to_sample: Optional[int]=litellm.max_tokens # anthropic requires a default
     stop_sequences: Optional[list]=None
     temperature: Optional[int]=None
     top_p: Optional[int]=None
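The hard-coded 256 above becomes the module-level litellm.max_tokens, so the Anthropic default now tracks the library-wide setting (the class attribute is captured once at import time). A minimal check, assuming litellm.max_tokens keeps its usual default of 256:

import litellm

# assumption: litellm.max_tokens defaults to 256, so the required
# Anthropic default is unchanged in practice
print(litellm.AnthropicConfig.get_config()["max_tokens_to_sample"])  # 256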
litellm/llms/bedrock.py
@@ -13,7 +13,7 @@ class BedrockError(Exception):
             self.message
         ) # Call the base class constructor with the parameters it needs
 
-class AmazonConfig():
+class AmazonTitanConfig():
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
 
@@ -46,6 +46,123 @@ class AmazonConfig():
                 and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                 and v is not None}
 
+class AmazonAnthropicConfig():
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
+
+    Supported Params for the Amazon / Anthropic models:
+
+    - `max_tokens_to_sample` (integer) max tokens,
+    - `temperature` (float) model temperature,
+    - `top_k` (integer) top k,
+    - `top_p` (integer) top p,
+    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
+    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
+    """
+    max_tokens_to_sample: Optional[int]=litellm.max_tokens
+    stop_sequences: Optional[list]=None
+    temperature: Optional[float]=None
+    top_k: Optional[int]=None
+    top_p: Optional[int]=None
+    anthropic_version: Optional[str]=None
+
+    def __init__(self,
+                 max_tokens_to_sample: Optional[int]=None,
+                 stop_sequences: Optional[list]=None,
+                 temperature: Optional[float]=None,
+                 top_k: Optional[int]=None,
+                 top_p: Optional[int]=None,
+                 anthropic_version: Optional[str]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
+class AmazonCohereConfig():
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
+
+    Supported Params for the Amazon / Cohere models:
+
+    - `max_tokens` (integer) max tokens,
+    - `temperature` (float) model temperature,
+    - `return_likelihood` (string) n/a
+    """
+    max_tokens: Optional[int]=None
+    temperature: Optional[float]=None
+    return_likelihood: Optional[str]=None
+
+    def __init__(self,
+                 max_tokens: Optional[int]=None,
+                 temperature: Optional[float]=None,
+                 return_likelihood: Optional[str]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
+class AmazonAI21Config():
+    """
+    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
+
+    Supported Params for the Amazon / AI21 models:
+
+    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
+
+    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
+
+    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
+
+    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
+
+    - `frequencyPenalty` (object): Placeholder for frequency penalty object.
+
+    - `presencePenalty` (object): Placeholder for presence penalty object.
+
+    - `countPenalty` (object): Placeholder for count penalty object.
+    """
+    maxTokens: Optional[int]=None
+    temperature: Optional[float]=None
+    topP: Optional[float]=None
+    stopSequences: Optional[list]=None
+    frequencePenalty: Optional[dict]=None
+    presencePenalty: Optional[dict]=None
+    countPenalty: Optional[dict]=None
+
+    def __init__(self,
+                 maxTokens: Optional[int]=None,
+                 temperature: Optional[float]=None,
+                 topP: Optional[float]=None,
+                 stopSequences: Optional[list]=None,
+                 frequencePenalty: Optional[dict]=None,
+                 presencePenalty: Optional[dict]=None,
+                 countPenalty: Optional[dict]=None) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != 'self' and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {k: v for k, v in cls.__dict__.items()
+                if not k.startswith('__')
+                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
+                and v is not None}
+
 class AnthropicConstants(Enum):
     HUMAN_PROMPT = "\n\nHuman:"
     AI_PROMPT = "\n\nAssistant:"
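All of the classes added above share one pattern: __init__ writes any non-None argument onto the class itself (not the instance), and get_config() reads back every explicitly-set, non-callable attribute. A minimal standalone sketch of that pattern, with a hypothetical two-field config standing in for the real classes:

import types
from typing import Optional

class ProviderConfig():  # hypothetical stand-in for e.g. AmazonCohereConfig
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None

    def __init__(self,
                 max_tokens: Optional[int] = None,
                 temperature: Optional[float] = None) -> None:
        # non-None args become *class* attributes, so they persist
        # across calls until overwritten
        for key, value in locals().items():
            if key != 'self' and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # collect every explicitly-set, non-dunder, non-callable attribute
        return {k: v for k, v in cls.__dict__.items()
                if not k.startswith('__')
                and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                and v is not None}

ProviderConfig(max_tokens=10)       # set once, anywhere in the process
print(ProviderConfig.get_config())  # {'max_tokens': 10}

Because the state lives on the class, a value set in one place is visible to every later get_config() call in the process; that is what lets completion() pick the values up without being handed a config object.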
@@ -137,7 +254,7 @@ def completion(
     print(f"bedrock provider: {provider}")
     if provider == "anthropic":
         ## LOAD CONFIG
-        config = litellm.AnthropicConfig.get_config()
+        config = litellm.AmazonAnthropicConfig.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
@@ -147,7 +264,7 @@ def completion(
         })
     elif provider == "ai21":
         ## LOAD CONFIG
-        config = litellm.AI21Config.get_config()
+        config = litellm.AmazonAI21Config.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
@@ -158,7 +275,7 @@ def completion(
         })
     elif provider == "cohere":
         ## LOAD CONFIG
-        config = litellm.CohereConfig.get_config()
+        config = litellm.AmazonCohereConfig.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
@@ -168,7 +285,7 @@ def completion(
         })
     elif provider == "amazon": # amazon titan
         ## LOAD CONFIG
-        config = litellm.AmazonConfig.get_config()
+        config = litellm.AmazonTitanConfig.get_config()
         for k, v in config.items():
             if k not in inference_params: # completion(top_k=3) > amazon_config(top_k=3) <- allows for dynamic variables to be passed in
                 inference_params[k] = v
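Each provider branch above merges the stored config into the request the same way: a configured key is applied only if the caller didn't pass it, so per-call arguments always win. The rule in isolation, with made-up values:

config = {"max_tokens_to_sample": 256, "temperature": 0.5}  # from get_config()
inference_params = {"temperature": 0.9}                     # caller's kwargs

for k, v in config.items():
    if k not in inference_params:  # completion(...) args take priority
        inference_params[k] = v

print(inference_params)  # {'temperature': 0.9, 'max_tokens_to_sample': 256}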
litellm/tests/test_completion.py
@@ -50,7 +50,7 @@ def test_completion_claude():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-# test_completion_claude()
+test_completion_claude()
 
 # def test_completion_oobabooga():
 #     try:
@@ -561,7 +561,7 @@ def test_completion_azure():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
-test_completion_azure()
+# test_completion_azure()
 
 # new azure test for using litellm. vars,
 # use the following vars in this test and make an azure_api_call
@@ -779,7 +779,7 @@ def test_completion_bedrock_claude():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_bedrock_claude()
+test_completion_bedrock_claude()
 
 def test_completion_bedrock_claude_stream():
     print("calling claude")
@@ -368,7 +368,7 @@ def sagemaker_test_completion():
 
 
 def bedrock_test_completion():
-    litellm.CohereConfig(max_tokens=10)
+    litellm.AmazonCohereConfig(max_tokens=10)
     # litellm.set_verbose=True
     try:
         # OVERRIDE WITH DYNAMIC MAX TOKENS
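Putting the commit together, the test above sets a Cohere-on-Bedrock default once and then overrides it per call. A sketch of that flow, assuming valid AWS credentials; the model id and prompt are illustrative:

import litellm

# store a provider-level default; later Bedrock/Cohere calls inherit it
litellm.AmazonCohereConfig(max_tokens=10)

# per-call kwargs still override the stored config inside completion()
response = litellm.completion(
    model="bedrock/cohere.command-text-v14",  # illustrative model id
    messages=[{"role": "user", "content": "Hey!"}],
    max_tokens=100,
)
print(response)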