diff --git a/litellm/tests/test_configs/test_config_no_auth.yaml b/litellm/tests/test_configs/test_config_no_auth.yaml index edf690173..2fd9ef203 100644 --- a/litellm/tests/test_configs/test_config_no_auth.yaml +++ b/litellm/tests/test_configs/test_config_no_auth.yaml @@ -76,3 +76,6 @@ model_list: - model_name: amazon-embeddings litellm_params: model: "bedrock/amazon.titan-embed-text-v1" +- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)" + litellm_params: + model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16" \ No newline at end of file diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 31e18b5ff..9de25c298 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -140,6 +140,25 @@ def test_bedrock_embedding(client_no_auth): except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") +def test_sagemaker_embedding(client_no_auth): + global headers + from litellm.proxy.proxy_server import user_custom_auth + + try: + test_data = { + "model": "GPT-J 6B - Sagemaker Text Embedding (Internal)", + "input": ["good morning from litellm"], + } + + response = client_no_auth.post("/v1/embeddings", json=test_data) + + assert response.status_code == 200 + result = response.json() + print(len(result["data"][0]["embedding"])) + assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + # Run the test # test_embedding()