Add AmazonMistralConfig

Tim Xia 2024-03-01 23:14:00 -05:00
parent 8a0385a51b
commit 78a93e40ed
2 changed files with 51 additions and 1 deletion


@@ -282,6 +282,55 @@ class AmazonLlamaConfig:
         }
 
 
+class AmazonMistralConfig:
+    """
+    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
+    Supported Params for the Amazon / Mistral models:
+    - `max_tokens` (integer) max tokens,
+    - `temperature` (float) temperature for model,
+    - `top_p` (float) top p for model
+    - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
+    - `top_k` (float) top k for model
+    """
+
+    max_tokens: Optional[int] = None
+    temperature: Optional[float] = None
+    topP: Optional[float] = None
+    topK: Optional[float] = None
+    stop: Optional[list[str]] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+        topP: Optional[float] = None,
+        topK: Optional[float] = None,
+        stop: Optional[list[str]] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+
 class AmazonStabilityConfig:
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
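
A note on the pattern in the hunk above: like the other Amazon Bedrock config classes in this file, the constructor writes any non-None arguments onto the class itself, and get_config() then returns only the explicitly set, non-callable class attributes. Below is a minimal standalone sketch of that mechanic, with an illustrative class name that is not part of the commit:

import types
from typing import Optional


class ConfigSketch:
    # Class-level defaults; None means "not set by the user".
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None

    def __init__(self, max_tokens: Optional[int] = None, temperature: Optional[float] = None) -> None:
        # Overrides are stored on the class, so later get_config() calls see them.
        for key, value in locals().items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Keep only non-dunder, non-callable, non-None class attributes.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
            and v is not None
        }


ConfigSketch(max_tokens=256)
print(ConfigSketch.get_config())  # {'max_tokens': 256}
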
@@ -627,7 +676,7 @@ def completion(
         )
     elif provider == "mistral":
         ## LOAD CONFIG
-        config = litellm.AmazonLlamaConfig.get_config()
+        config = litellm.AmazonMistralConfig.get_config()
         for k, v in config.items():
             if (
                 k not in inference_params
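
The second hunk updates the mistral branch to load its defaults from the new AmazonMistralConfig instead of AmazonLlamaConfig. The loop that follows (its assignment is cut off by the hunk boundary) presumably copies each config value into inference_params only when the caller has not already supplied that key, so explicit kwargs keep precedence. A small sketch of that merge, using stand-in dicts rather than the real litellm variables:

# Stand-in values; in bedrock.py `config` comes from get_config() and
# `inference_params` from the caller's keyword arguments.
config = {"max_tokens": 256, "temperature": 0.7, "topP": 0.9}
inference_params = {"temperature": 0.2}

for k, v in config.items():
    if k not in inference_params:  # caller-supplied values win
        inference_params[k] = v

print(inference_params)
# {'temperature': 0.2, 'max_tokens': 256, 'topP': 0.9}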