fix(proxy_server.py): don't pass in user param if not sent

This commit is contained in:
Krrish Dholakia 2023-12-14 14:17:33 -08:00
parent d78b6be8fb
commit 2a4c1a1803
5 changed files with 48 additions and 9 deletions

View file

@ -970,7 +970,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
)
# users can pass in 'user' param to /chat/completions. Don't override it
if data.get("user", None) is None:
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
# if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id
@ -1063,7 +1063,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
"body": copy.copy(data) # use copy instead of deepcopy
}
data["user"] = data.get("user", user_api_key_dict.user_id)
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args

View file

@ -73,3 +73,6 @@ model_list:
description: this is a test openai model
id: 9b1ef341-322c-410a-8992-903987fef439
model_name: test_openai_models
- model_name: amazon-embeddings
litellm_params:
model: "bedrock/amazon.titan-embed-text-v1"

View file

@ -310,6 +310,26 @@ async def test_async_custom_handler_embedding_optional_param():
# asyncio.run(test_async_custom_handler_embedding_optional_param())
@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param_bedrock():
    """
    Tests if the openai optional params for embedding - user + encoding_format,
    are logged
    but makes sure these are not sent to the non-openai/azure endpoint (raises errors).
    """
    # NOTE(review): the function name and docstring say "bedrock", but the call
    # below targets "azure/azure-embedding-model" — confirm whether the intended
    # model is the bedrock embedding model added alongside this commit
    # ("bedrock/amazon.titan-embed-text-v1"). The assertions below (that `user`
    # was forwarded to the API) only make sense for the openai/azure path;
    # per the docstring, the bedrock path should NOT receive these params.
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(
        model="azure/azure-embedding-model",
        input = ["hello world"],
        user = "John"
    )
    # success callback fires asynchronously; give it a moment to populate
    # the handler's captured fields before asserting on them
    await asyncio.sleep(1) # success callback is async
    # the callback must have seen the `user` param we passed in...
    assert customHandler_optional_params.user == "John"
    # ...and that same value must be what was actually sent to the API
    assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
def test_redis_cache_completion_stream():
from litellm import Cache
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set

View file

@ -164,7 +164,7 @@ def test_bedrock_embedding_titan():
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_bedrock_embedding_titan()
test_bedrock_embedding_titan()
def test_bedrock_embedding_cohere():
try:

View file

@ -111,12 +111,26 @@ def test_embedding(client_no_auth):
"model": "azure/azure-embedding-model",
"input": ["good morning from litellm"],
}
# print("testing proxy server with Azure embedding")
# print(user_custom_auth)
# print(id(user_custom_auth))
# user_custom_auth = None
# print("valu of user_custom_auth", user_custom_auth)
# litellm.proxy.proxy_server.user_custom_auth = None
response = client_no_auth.post("/v1/embeddings", json=test_data)
assert response.status_code == 200
result = response.json()
print(len(result["data"][0]["embedding"]))
assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
except Exception as e:
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
def test_bedrock_embedding(client_no_auth):
global headers
from litellm.proxy.proxy_server import user_custom_auth
try:
test_data = {
"model": "amazon-embeddings",
"input": ["good morning from litellm"],
}
response = client_no_auth.post("/v1/embeddings", json=test_data)
assert response.status_code == 200