Support optional params for Bedrock Amazon Titan models

ishaan-jaff 2023-09-16 09:18:57 -07:00
parent e5fff9bada
commit 29e3b4fdd2
3 changed files with 19 additions and 9 deletions


@@ -94,14 +94,8 @@ def completion(
     else: # amazon titan
         data = json.dumps({
             "inputText": prompt,
-            "textGenerationConfig":{
-                "maxTokenCount":4096,
-                "stopSequences":[],
-                "temperature":0,
-                "topP":0.9
-            }
+            "textGenerationConfig": optional_params,
         })
     ## LOGGING
     logging_obj.pre_call(
         input=prompt,
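
Instead of hardcoding Titan's generation config, the request body is now built from whatever optional_params the caller supplied. A minimal standalone sketch of the resulting payload, assuming the caller passed max_tokens=200, temperature=0.2, and top_p=0.8 (renamed to Titan's key names by the get_optional_params change further down):

import json

# Illustrative values only: what optional_params would hold after
# get_optional_params renames max_tokens/temperature/top_p for "amazon" models.
optional_params = {"maxTokenCount": 200, "temperature": 0.2, "topP": 0.8}
prompt = "Hello, how are you?"

data = json.dumps({
    "inputText": prompt,
    "textGenerationConfig": optional_params,
})
print(data)
# {"inputText": "Hello, how are you?", "textGenerationConfig": {"maxTokenCount": 200, "temperature": 0.2, "topP": 0.8}}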


@@ -650,14 +650,15 @@ def test_completion_bedrock_titan():
             model="bedrock/amazon.titan-tg1-large",
             messages=messages,
             temperature=0.2,
-            max_tokens=20,
+            max_tokens=200,
+            top_p=0.8,
             logger_fn=logger_fn
         )
         # Add any assertions here to check the response
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_bedrock_titan()
+test_completion_bedrock_titan()
 def test_completion_bedrock_ai21():
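
The updated test is re-enabled rather than commented out, and now exercises the new parameter mapping end to end. A usage sketch of the call it makes, assuming AWS Bedrock credentials are configured in the environment:

import litellm

# Mirrors the test above; each OpenAI-style argument is forwarded to Titan
# under its native name (maxTokenCount, temperature, topP).
response = litellm.completion(
    model="bedrock/amazon.titan-tg1-large",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    temperature=0.2,
    max_tokens=200,
    top_p=0.8,
)
print(response)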


@@ -931,6 +931,21 @@ def get_optional_params( # use the openai defaults
             optional_params["temperature"] = temperature
         if top_p != 1:
             optional_params["top_p"] = top_p
+    elif custom_llm_provider == "bedrock":
+        if "ai21" in model or "anthropic" in model:
+            pass
+        elif "amazon" in model: # amazon titan llms
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens != float("inf"):
+                optional_params["maxTokenCount"] = max_tokens
+            if temperature != 1:
+                optional_params["temperature"] = temperature
+            if stop != None:
+                optional_params["stopSequences"] = stop
+            if top_p != 1:
+                optional_params["topP"] = top_p
     elif model in litellm.aleph_alpha_models:
         if max_tokens != float("inf"):
             optional_params["maximum_tokens"] = max_tokens