security - Prevent SQL injection in /team/update query (#5513)

* fix(team_endpoints.py): replace `.get_data()` usage with the typed Prisma interface

Prevents SQL injection in the `/team/update` query

Fixes https://huntr.com/bounties/a4f6d357-5b44-4e00-9cac-f1cc351211d2
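A minimal sketch of why the switch matters, not LiteLLM's actual helper code: the raw-SQL variant below is a hypothetical anti-pattern shown only for contrast, the table name is illustrative, and `db` stands for a connected Prisma Client Python instance. The typed `find_unique` call passes `team_id` as data rather than as SQL text, which is what the `/team/update` change does.

from prisma import Prisma  # Prisma Client Python


async def get_team_unsafe(db: Prisma, team_id: str):
    # Hypothetical anti-pattern, for contrast only: user-controlled team_id is
    # spliced into the SQL text, so a crafted value can change the query's
    # meaning (SQL injection). Table name is illustrative.
    return await db.query_raw(
        f"SELECT * FROM \"LiteLLM_TeamTable\" WHERE team_id = '{team_id}'"
    )


async def get_team_safe(db: Prisma, team_id: str):
    # Typed interface: team_id travels as a bound value, never as SQL text,
    # mirroring the find_unique(where={"team_id": ...}) call in this commit.
    return await db.litellm_teamtable.find_unique(where={"team_id": team_id})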

* fix(vertex_ai_non_gemini.py): handle the message being a Pydantic model
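A minimal sketch of the normalization the Vertex AI change performs inline in `_gemini_convert_messages_with_history`; the helper name below is illustrative, not part of the commit.

from typing import Union

from pydantic import BaseModel


def to_message_dict(msg: Union[BaseModel, dict]) -> dict:
    # Messages may arrive either as plain dicts or as Pydantic models; Pydantic
    # v2 models expose model_dump(), while plain dicts pass through unchanged.
    if isinstance(msg, BaseModel):
        return msg.model_dump()
    return msg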
Krish Dholakia 2024-09-04 16:03:02 -07:00 committed by GitHub
parent 7558e49d78
commit 0595d03116
3 changed files with 18 additions and 12 deletions

vertex_ai_non_gemini.py

@@ -177,7 +177,11 @@ def _gemini_convert_messages_with_history(
         assistant_content = []
         ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
         while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
-            assistant_msg = ChatCompletionAssistantMessage(**messages[msg_i])  # type: ignore
+            if isinstance(messages[msg_i], BaseModel):
+                msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump()  # type: ignore
+            else:
+                msg_dict = messages[msg_i]  # type: ignore
+            assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
             if assistant_msg.get("content", None) is not None and isinstance(
                 assistant_msg["content"], list
             ):

team_endpoints.py

@@ -353,9 +353,10 @@ async def update_team(
         raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
     verbose_proxy_logger.debug("/team/update - %s", data)
 
-    existing_team_row = await prisma_client.get_data(
-        team_id=data.team_id, table_name="team", query_type="find_unique"
+    existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
+        where={"team_id": data.team_id}
     )
+
     if existing_team_row is None:
         raise HTTPException(
             status_code=404,


@@ -45,14 +45,14 @@ def get_current_weather(location, unit="fahrenheit"):
 @pytest.mark.parametrize(
     "model",
     [
-        "gpt-3.5-turbo-1106",
-        "mistral/mistral-large-latest",
-        "claude-3-haiku-20240307",
+        # "gpt-3.5-turbo-1106",
+        # "mistral/mistral-large-latest",
+        # "claude-3-haiku-20240307",
         "gemini/gemini-1.5-pro",
-        "anthropic.claude-3-sonnet-20240229-v1:0",
+        # "anthropic.claude-3-sonnet-20240229-v1:0",
     ],
 )
-def test_parallel_function_call(model):
+def test_aaparallel_function_call(model):
     try:
         litellm.set_verbose = True
         # Step 1: send the conversation and available functions to the model
@@ -102,6 +102,7 @@ def test_parallel_function_call(model):
         )  # this has to call the function for SF, Tokyo and paris
 
         # Step 2: check if the model wanted to call a function
+        print(f"tool_calls: {tool_calls}")
         if tool_calls:
             # Step 3: call the function
             # Note: the JSON response may not always be valid; be sure to handle errors
@@ -142,10 +143,10 @@ def test_parallel_function_call(model):
                 drop_params=True,
             )  # get a new response from the model where it can see the function response
             print("second response\n", second_response)
-    except litellm.InternalServerError:
-        pass
-    except litellm.RateLimitError:
-        pass
+    except litellm.InternalServerError as e:
+        print(e)
+    except litellm.RateLimitError as e:
+        print(e)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")