diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index f7b39acea..d435d224d 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -94,14 +94,8 @@ def completion(
         else: # amazon titan
             data = json.dumps({
                 "inputText": prompt,
-                "textGenerationConfig":{
-                    "maxTokenCount":4096,
-                    "stopSequences":[],
-                    "temperature":0,
-                    "topP":0.9
-                }
+                "textGenerationConfig": optional_params,
             })
-
         ## LOGGING
         logging_obj.pre_call(
             input=prompt,
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index ab69ecfd4..3d34df605 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -650,14 +650,15 @@ def test_completion_bedrock_titan():
             model="bedrock/amazon.titan-tg1-large",
             messages=messages,
             temperature=0.2,
-            max_tokens=20,
+            max_tokens=200,
+            top_p=0.8,
             logger_fn=logger_fn
         )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_bedrock_titan()
+test_completion_bedrock_titan()


def test_completion_bedrock_ai21():
diff --git a/litellm/utils.py b/litellm/utils.py
index 6ee25ac48..e8e0af423 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -931,6 +931,21 @@ def get_optional_params(  # use the openai defaults
         optional_params["temperature"] = temperature
         if top_p != 1:
             optional_params["top_p"] = top_p
+    elif custom_llm_provider == "bedrock":
+        if "ai21" in model or "anthropic" in model:
+            pass
+
+        elif "amazon" in model: # amazon titan llms
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens != float("inf"):
+                optional_params["maxTokenCount"] = max_tokens
+            if temperature != 1:
+                optional_params["temperature"] = temperature
+            if stop != None:
+                optional_params["stopSequences"] = stop
+            if top_p != 1:
+                optional_params["topP"] = top_p
+
    elif model in litellm.aleph_alpha_models:
        if max_tokens != float("inf"):
            optional_params["maximum_tokens"] = max_tokens
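For reference, here is a minimal standalone sketch of the behavior this patch wires together: OpenAI-style parameters get remapped to Titan's `textGenerationConfig` keys (only when they differ from the OpenAI defaults), and the Bedrock request body in `bedrock.py` now embeds that mapped dict instead of hardcoded values. The helper name `map_titan_params` and the sample values are illustrative only, not part of litellm's API.

```python
import json

def map_titan_params(max_tokens=float("inf"), temperature=1, stop=None, top_p=1):
    # Mirrors the new "amazon" branch in get_optional_params: each param is
    # included only if the caller overrode the OpenAI default sentinel.
    optional_params = {}
    if max_tokens != float("inf"):
        optional_params["maxTokenCount"] = max_tokens
    if temperature != 1:
        optional_params["temperature"] = temperature
    if stop is not None:
        optional_params["stopSequences"] = stop
    if top_p != 1:
        optional_params["topP"] = top_p
    return optional_params

# With the patch applied, the Titan request body no longer pins
# maxTokenCount=4096 / temperature=0 / topP=0.9; it carries through
# whatever the caller passed (here, the values from the updated test):
data = json.dumps({
    "inputText": "Hello, world",
    "textGenerationConfig": map_titan_params(max_tokens=200, temperature=0.2, top_p=0.8),
})
print(data)
# {"inputText": "Hello, world", "textGenerationConfig": {"maxTokenCount": 200, "temperature": 0.2, "topP": 0.8}}
```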