From 2a4c1a1803c943c4df4ddd6a1172afa179de5ea7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 14 Dec 2023 14:17:33 -0800 Subject: [PATCH] fix(proxy_server.py): don't pass in user param if not sent --- litellm/proxy/proxy_server.py | 6 +++-- .../test_configs/test_config_no_auth.yaml | 3 +++ litellm/tests/test_custom_logger.py | 20 ++++++++++++++ litellm/tests/test_embedding.py | 2 +- litellm/tests/test_proxy_server.py | 26 ++++++++++++++----- 5 files changed, 48 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 2819f9d2a..177edb43f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -970,7 +970,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap ) # users can pass in 'user' param to /chat/completions. Don't override it - if data.get("user", None) is None: + if data.get("user", None) is None and user_api_key_dict.user_id is not None: # if users are using user_api_key_auth, set `user` in `data` data["user"] = user_api_key_dict.user_id @@ -1063,7 +1063,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen "body": copy.copy(data) # use copy instead of deepcopy } - data["user"] = data.get("user", user_api_key_dict.user_id) + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + data["user"] = user_api_key_dict.user_id + data["model"] = ( general_settings.get("embedding_model", None) # server default or user_model # model name passed via cli args diff --git a/litellm/tests/test_configs/test_config_no_auth.yaml b/litellm/tests/test_configs/test_config_no_auth.yaml index 1dd01d619..edf690173 100644 --- a/litellm/tests/test_configs/test_config_no_auth.yaml +++ b/litellm/tests/test_configs/test_config_no_auth.yaml @@ -73,3 +73,6 @@ model_list: description: this is a test openai model id: 9b1ef341-322c-410a-8992-903987fef439 model_name: test_openai_models +- model_name: amazon-embeddings + litellm_params: + model: "bedrock/amazon.titan-embed-text-v1" diff --git a/litellm/tests/test_custom_logger.py b/litellm/tests/test_custom_logger.py index 26dcdf7d4..fc5a63619 100644 --- a/litellm/tests/test_custom_logger.py +++ b/litellm/tests/test_custom_logger.py @@ -310,6 +310,26 @@ async def test_async_custom_handler_embedding_optional_param(): # asyncio.run(test_async_custom_handler_embedding_optional_param()) +@pytest.mark.asyncio +async def test_async_custom_handler_embedding_optional_param_bedrock(): + """ + Tests if the openai optional params for embedding - user + encoding_format, + are logged + + but makes sure these are not sent to the non-openai/azure endpoint (raises errors). + """ + customHandler_optional_params = MyCustomHandler() + litellm.callbacks = [customHandler_optional_params] + response = await litellm.aembedding( + model="azure/azure-embedding-model", + input = ["hello world"], + user = "John" + ) + await asyncio.sleep(1) # success callback is async + assert customHandler_optional_params.user == "John" + assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"] + + def test_redis_cache_completion_stream(): from litellm import Cache # Important Test - This tests if we can add to streaming cache, when custom callbacks are set diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 71e59819f..9a2a5951a 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -164,7 +164,7 @@ def test_bedrock_embedding_titan(): assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats" except Exception as e: pytest.fail(f"Error occurred: {e}") -# test_bedrock_embedding_titan() +test_bedrock_embedding_titan() def test_bedrock_embedding_cohere(): try: diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 25e637c5c..31e18b5ff 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -111,12 +111,26 @@ def test_embedding(client_no_auth): "model": "azure/azure-embedding-model", "input": ["good morning from litellm"], } - # print("testing proxy server with Azure embedding") - # print(user_custom_auth) - # print(id(user_custom_auth)) - # user_custom_auth = None - # print("valu of user_custom_auth", user_custom_auth) - # litellm.proxy.proxy_server.user_custom_auth = None + + response = client_no_auth.post("/v1/embeddings", json=test_data) + + assert response.status_code == 200 + result = response.json() + print(len(result["data"][0]["embedding"])) + assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so + except Exception as e: + pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}") + +def test_bedrock_embedding(client_no_auth): + global headers + from litellm.proxy.proxy_server import user_custom_auth + + try: + test_data = { + "model": "amazon-embeddings", + "input": ["good morning from litellm"], + } + response = client_no_auth.post("/v1/embeddings", json=test_data) assert response.status_code == 200