forked from phoenix/litellm-mirror
fix(proxy_server.py): don't pass in user param if not sent
This commit is contained in:
parent
d78b6be8fb
commit
2a4c1a1803
5 changed files with 48 additions and 9 deletions
|
@ -970,7 +970,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
)
|
||||
|
||||
# users can pass in 'user' param to /chat/completions. Don't override it
|
||||
if data.get("user", None) is None:
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
# if users are using user_api_key_auth, set `user` in `data`
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
|
@ -1063,7 +1063,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
"body": copy.copy(data) # use copy instead of deepcopy
|
||||
}
|
||||
|
||||
data["user"] = data.get("user", user_api_key_dict.user_id)
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
data["model"] = (
|
||||
general_settings.get("embedding_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
|
|
@ -73,3 +73,6 @@ model_list:
|
|||
description: this is a test openai model
|
||||
id: 9b1ef341-322c-410a-8992-903987fef439
|
||||
model_name: test_openai_models
|
||||
- model_name: amazon-embeddings
|
||||
litellm_params:
|
||||
model: "bedrock/amazon.titan-embed-text-v1"
|
||||
|
|
|
@ -310,6 +310,26 @@ async def test_async_custom_handler_embedding_optional_param():
|
|||
|
||||
# asyncio.run(test_async_custom_handler_embedding_optional_param())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_embedding_optional_param_bedrock():
|
||||
"""
|
||||
Tests if the openai optional params for embedding - user + encoding_format,
|
||||
are logged
|
||||
|
||||
but makes sure these are not sent to the non-openai/azure endpoint (raises errors).
|
||||
"""
|
||||
customHandler_optional_params = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_optional_params]
|
||||
response = await litellm.aembedding(
|
||||
model="azure/azure-embedding-model",
|
||||
input = ["hello world"],
|
||||
user = "John"
|
||||
)
|
||||
await asyncio.sleep(1) # success callback is async
|
||||
assert customHandler_optional_params.user == "John"
|
||||
assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
|
||||
|
||||
|
||||
def test_redis_cache_completion_stream():
|
||||
from litellm import Cache
|
||||
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set
|
||||
|
|
|
@ -164,7 +164,7 @@ def test_bedrock_embedding_titan():
|
|||
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_bedrock_embedding_titan()
|
||||
test_bedrock_embedding_titan()
|
||||
|
||||
def test_bedrock_embedding_cohere():
|
||||
try:
|
||||
|
|
|
@ -111,12 +111,26 @@ def test_embedding(client_no_auth):
|
|||
"model": "azure/azure-embedding-model",
|
||||
"input": ["good morning from litellm"],
|
||||
}
|
||||
# print("testing proxy server with Azure embedding")
|
||||
# print(user_custom_auth)
|
||||
# print(id(user_custom_auth))
|
||||
# user_custom_auth = None
|
||||
# print("valu of user_custom_auth", user_custom_auth)
|
||||
# litellm.proxy.proxy_server.user_custom_auth = None
|
||||
|
||||
response = client_no_auth.post("/v1/embeddings", json=test_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
print(len(result["data"][0]["embedding"]))
|
||||
assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
|
||||
except Exception as e:
|
||||
pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
|
||||
|
||||
def test_bedrock_embedding(client_no_auth):
|
||||
global headers
|
||||
from litellm.proxy.proxy_server import user_custom_auth
|
||||
|
||||
try:
|
||||
test_data = {
|
||||
"model": "amazon-embeddings",
|
||||
"input": ["good morning from litellm"],
|
||||
}
|
||||
|
||||
response = client_no_auth.post("/v1/embeddings", json=test_data)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue