diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md
index 7902275ab..047be5395 100644
--- a/docs/my-website/docs/completion/input.md
+++ b/docs/my-website/docs/completion/input.md
@@ -40,7 +40,7 @@ This list is constantly being updated.
 |AI21| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | |
 |VertexAI| ✅ | ✅ | | ✅ | | | | | | |
 |Bedrock| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
-|Sagemaker| ✅ | ✅ (only `jumpstart llama2`) | | ✅ | | | | | | |
+|Sagemaker| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
 |TogetherAI| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
 |AlephAlpha| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
 |Palm| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | |
@@ -185,6 +185,25 @@ def completion(
 
 - `metadata`: *dict (optional)* - Any additional data you want to be logged when the call is made (sent to logging integrations, eg. promptlayer and accessible via custom callback function)
 
+**CUSTOM MODEL COST**
+- `input_cost_per_token`: *float (optional)* - The cost per input token for the completion call
+
+- `output_cost_per_token`: *float (optional)* - The cost per output token for the completion call
+
+**CUSTOM PROMPT TEMPLATE** (See [prompt formatting](./prompt_formatting.md#format-prompt-yourself) for more info)
+- `initial_prompt_value`: *string (optional)* - Initial string applied at the start of the input messages
+
+- `roles`: *dict (optional)* - Dictionary specifying how to format the prompt based on the role + message passed in via `messages`.
+
+- `final_prompt_value`: *string (optional)* - Final string applied at the end of the input messages
+
+- `bos_token`: *string (optional)* - Initial string applied at the start of a sequence
+
+- `eos_token`: *string (optional)* - Final string applied at the end of a sequence
+
+- `hf_model_name`: *string (optional)* - [Sagemaker Only] The corresponding Hugging Face name of the model, used to pull the right chat template for the model.
+
+
 ## Provider-specific Params
 Providers might offer params not supported by OpenAI (e.g. top_k). You can pass those in 2 ways:
 - via completion(): We'll pass the non-openai param, straight to the provider as part of the request body.
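For illustration, a minimal sketch (not part of this diff) of how the newly documented params could be passed to `litellm.completion()`. The endpoint name is borrowed from the test below; the per-token costs and the Llama-2 role markers are placeholder values, not canonical ones.

```python
import litellm

response = litellm.completion(
    model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    # CUSTOM MODEL COST: placeholder prices so litellm can track spend
    # for an endpoint it has no built-in pricing for
    input_cost_per_token=0.000001,
    output_cost_per_token=0.000002,
    # CUSTOM PROMPT TEMPLATE: wrap each role's messages in Llama-2-style
    # chat markup (illustrative markers, assuming the documented dict shape)
    roles={
        "system": {"pre_message": "[INST] <<SYS>>\n", "post_message": "\n<</SYS>>\n [/INST]\n"},
        "user": {"pre_message": "[INST] ", "post_message": " [/INST]\n"},
        "assistant": {"pre_message": "\n", "post_message": "\n"},
    },
    bos_token="<s>",
    eos_token="</s>",
    # [Sagemaker Only] resolve the matching chat template on Hugging Face
    hf_model_name="meta-llama/Llama-2-7b-chat-hf",
)
print(response.choices[0].message.content)
```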
diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py
index ca71461cf..2482c5457 100644
--- a/litellm/llms/sagemaker.py
+++ b/litellm/llms/sagemaker.py
@@ -149,12 +149,15 @@ def completion(
         additional_args={"complete_input_dict": data, "request_str": request_str},
     )
     ## COMPLETION CALL
-    response = client.invoke_endpoint(
-        EndpointName=model,
-        ContentType="application/json",
-        Body=data,
-        CustomAttributes="accept_eula=true",
-    )
+    try:
+        response = client.invoke_endpoint(
+            EndpointName=model,
+            ContentType="application/json",
+            Body=data,
+            CustomAttributes="accept_eula=true",
+        )
+    except Exception as e:
+        raise SagemakerError(status_code=500, message=f"{str(e)}")
     response = response["Body"].read().decode("utf8")
     ## LOGGING
     logging_obj.post_call(
diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py
index d0cda9335..0f0bff1cc 100644
--- a/litellm/tests/test_completion.py
+++ b/litellm/tests/test_completion.py
@@ -1053,10 +1053,11 @@ def test_completion_chat_sagemaker():
     print("testing sagemaker")
     litellm.set_verbose=True
     response = completion(
-        model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-f",
+        model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat",
         messages=messages,
         max_tokens=100,
         stream=True,
+        n=2,
         hf_model_name="meta-llama/Llama-2-7b-chat-hf",
     )
     # Add any assertions here to check the response
diff --git a/litellm/utils.py b/litellm/utils.py
index 814735f88..9e93f6b64 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4065,6 +4065,14 @@ def exception_type(
                         llm_provider="sagemaker",
                         response=original_exception.response
                     )
+                elif "Input validation error: `best_of` must be > 0 and <= 2" in error_str:
+                    exception_mapping_worked = True
+                    raise BadRequestError(
+                        message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
+                        model=model,
+                        llm_provider="sagemaker",
+                        response=original_exception.response
+                    )
             elif custom_llm_provider == "vertex_ai":
                 if "Vertex AI API has not been used in project" in error_str or "Unable to find your project" in error_str:
                     exception_mapping_worked = True
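A rough sketch of the user-facing behavior these three changes add up to: `n` appears to be passed through to the TGI container's `best_of` parameter for SageMaker endpoints, and an out-of-range value now surfaces as an OpenAI-style `BadRequestError` rather than a raw boto3 exception. This assumes a deployed endpoint with the name used in the test; it is not code from this PR.

```python
import litellm
from litellm import completion

messages = [{"role": "user", "content": "Hi, tell me a joke."}]

try:
    # TGI only accepts best_of in (0, 2], so n=3 should trip the new
    # exception mapping added to utils.py above.
    completion(
        model="sagemaker/jumpstart-dft-meta-textgeneration-llama-2-7b-chat",
        messages=messages,
        max_tokens=100,
        n=3,
    )
except litellm.exceptions.BadRequestError as e:
    # "SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints"
    print(e)
```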