LiteLLM Minor Fixes & Improvements (10/08/2024) (#6119)

* refactor(cost_calculator.py): move error line to debug - https://github.com/BerriAI/litellm/issues/5683#issuecomment-2398599498

* fix(migrate-hidden-params-to-read-from-standard-logging-payload): Fixes https://github.com/BerriAI/litellm/issues/5546#issuecomment-2399994026
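For illustration, a minimal sketch (assuming a custom success callback; the callback name and print are placeholders) of where the migrated hidden params now surface:

```python
import litellm

# Hypothetical callback: reads hidden params off the standard logging
# payload instead of scraping them from ad-hoc kwargs.
def log_success(kwargs, response_obj, start_time, end_time):
    payload = kwargs.get("standard_logging_object") or {}
    print(payload.get("hidden_params"))  # e.g. {"additional_headers": {...}}

litellm.success_callback = [log_success]
```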

* fix(types/utils.py): mark weight as a litellm param

Fixes https://github.com/BerriAI/litellm/issues/5781
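For context, a rough sketch of where `weight` is used — as a routing knob on a deployment's `litellm_params` (deployment entries here are illustrative), rather than a param forwarded to the provider:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            # picked roughly twice as often under simple-shuffle routing
            "litellm_params": {"model": "gpt-3.5-turbo", "weight": 2},
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "weight": 1},
        },
    ]
)
```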

* feat(internal_user_endpoints.py): fix `/user/info` + show the user's max budget as the Default Team max budget

Fixes https://github.com/BerriAI/litellm/issues/6117

* feat: support returning team member budget in `/user/info`

Sets the user's max budget within the team as the max budget shown on the UI.

Closes https://github.com/BerriAI/litellm/issues/6117
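A rough sketch of reading that budget back (assumes a locally running proxy, an admin key `sk-1234`, and a placeholder user id), mirroring the new e2e assertion added in this PR:

```python
import requests

resp = requests.get(
    "http://0.0.0.0:4000/user/info",
    params={"user_id": "my-user-id"},  # placeholder user id
    headers={"Authorization": "Bearer sk-1234"},
).json()

# the team member's budget is now surfaced per team membership
print(resp["teams"][0]["team_memberships"][0]["litellm_budget_table"]["max_budget"])
```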

* Bug fix for optional parameter passing to Replicate (#6067)

Signed-off-by: Mandana Vaziri <mvaziri@us.ibm.com>

* fix(o1_transformation.py): handle o1 temperature=0

o1 doesn't support temperature=0; allow the admin to drop this param
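A minimal sketch of the admin-side workaround this enables (uses the existing `drop_params` setting; the model name is illustrative):

```python
import litellm

litellm.drop_params = True  # drop unsupported params instead of raising

resp = litellm.completion(
    model="o1-preview",  # illustrative o1 model name
    messages=[{"role": "user", "content": "hi"}],
    temperature=0,  # o1 only supports temperature=1; this gets dropped
)
```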

* test: fix test

---------

Signed-off-by: Mandana Vaziri <mvaziri@us.ibm.com>
Co-authored-by: Mandana Vaziri <mvaziri@us.ibm.com>
Krish Dholakia 2024-10-08 21:57:03 -07:00 committed by GitHub
parent ac6fb0cbef
commit 9695c1af10
21 changed files with 260 additions and 86 deletions

litellm/cost_calculator.py

@@ -623,7 +623,7 @@ def completion_cost(
     try:
         _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
     except Exception as e:
-        verbose_logger.error(
+        verbose_logger.debug(
             "litellm.cost_calculator.py::completion_cost() - Error inferring custom_llm_provider - {}".format(
                 str(e)
             )

litellm/integrations/langfuse.py

@@ -4,6 +4,7 @@ import copy
 import inspect
 import os
 import traceback
+from typing import Optional

 from packaging.version import Version
 from pydantic import BaseModel
@@ -12,6 +13,7 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.redact_messages import redact_user_api_key_info
 from litellm.secret_managers.main import str_to_bool
+from litellm.types.utils import StandardLoggingPayload


 class LangFuseLogger:
@@ -502,7 +504,15 @@ class LangFuseLogger:
         cost = kwargs.get("response_cost", None)
         print_verbose(f"trace: {cost}")

+        standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
+            "standard_logging_object", None
+        )
         clean_metadata["litellm_response_cost"] = cost
+        if standard_logging_object is not None:
+            clean_metadata["hidden_params"] = standard_logging_object[
+                "hidden_params"
+            ]

         if (
             litellm.langfuse_default_tags is not None

litellm/llms/OpenAI/o1_transformation.py

@@ -85,7 +85,7 @@ class OpenAIO1Config(OpenAIGPTConfig):
         if "temperature" in non_default_params:
             temperature_value: Optional[float] = non_default_params.pop("temperature")
             if temperature_value is not None:
-                if temperature_value == 0 or temperature_value == 1:
+                if temperature_value == 1:
                     optional_params["temperature"] = temperature_value
                 else:
                     ## UNSUPPORTED TOOL CHOICE VALUE

litellm/proxy/proxy_config.yaml

@@ -1,5 +1,10 @@
 model_list:
-  - model_name: claude-3-5-sonnet-20240620
+  - model_name: gpt-4o-mini
     litellm_params:
-      model: anthropic/claude-3-5-sonnet-20240620
-      api_key: os.environ/ANTHROPIC_API_KEY
+      model: azure/my-gpt-4o-mini
+      api_key: os.environ/AZURE_API_KEY
+      api_base: os.environ/AZURE_API_BASE
+
+litellm_settings:
+  success_callback: ["langfuse"]
+  max_internal_user_budget: 10

litellm/proxy/_types.py

@@ -1919,6 +1919,10 @@ class TeamInfoResponseObject(TypedDict):
     team_memberships: List[LiteLLM_TeamMembership]


+class TeamListResponseObject(LiteLLM_TeamTable):
+    team_memberships: List[LiteLLM_TeamMembership]
+
+
 class CurrentItemRateLimit(TypedDict):
     current_requests: int
     current_tpm: int

(Prisma schema diff check helper)

@@ -79,7 +79,6 @@ def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
                 print(command)  # noqa: T201
             return True, sql_commands
         else:
-            print("No changes required.")  # noqa: T201
             return False, []
     except subprocess.CalledProcessError as e:
         error_message = f"Failed to generate migration diff. Error: {e.stderr}"

litellm/proxy/management_endpoints/internal_user_endpoints.py

@@ -277,8 +277,9 @@ async def ui_get_available_role(


 def get_team_from_list(
-    team_list: Optional[List[LiteLLM_TeamTable]], team_id: str
-) -> Optional[LiteLLM_TeamTable]:
+    team_list: Optional[Union[List[LiteLLM_TeamTable], List[TeamListResponseObject]]],
+    team_id: str,
+) -> Optional[Union[LiteLLM_TeamTable, LiteLLM_TeamMembership]]:
     if team_list is None:
         return None

@@ -292,7 +293,7 @@ def get_team_from_list(
     "/user/info",
     tags=["Internal User management"],
     dependencies=[Depends(user_api_key_auth)],
-    response_model=UserInfoResponse,
+    # response_model=UserInfoResponse,
 )
 @management_endpoint_wrapper
 async def user_info(
@@ -334,8 +335,17 @@ async def user_info(
         team_list = []
         team_id_list = []
         # get all teams user belongs to
-        teams_1 = await prisma_client.get_data(
-            user_id=user_id, table_name="team", query_type="find_all"
-        )
+        # teams_1 = await prisma_client.get_data(
+        #     user_id=user_id, table_name="team", query_type="find_all"
+        # )
+        from litellm.proxy.management_endpoints.team_endpoints import list_team
+
+        teams_1 = await list_team(
+            http_request=Request(
+                scope={"type": "http", "path": "/user/info"},
+            ),
+            user_id=user_id,
+            user_api_key_dict=user_api_key_dict,
+        )

         if teams_1 is not None and isinstance(teams_1, list):
@@ -355,6 +365,7 @@ async def user_info(
                 if team.team_id not in team_id_list:
                     team_list.append(team)
                     team_id_list.append(team.team_id)
+
         elif (
             user_api_key_dict.user_id is not None and user_id is None
         ):  # the key querying the endpoint is the one asking for it's teams
@@ -436,8 +447,11 @@ async def user_info(
                 key["team_alias"] = "None"
                 returned_keys.append(key)

+    _user_info = (
+        user_info.model_dump() if isinstance(user_info, BaseModel) else user_info
+    )
     response_data = UserInfoResponse(
-        user_id=user_id, user_info=user_info, keys=returned_keys, teams=team_list
+        user_id=user_id, user_info=_user_info, keys=returned_keys, teams=team_list
     )
     return response_data

litellm/proxy/management_endpoints/team_endpoints.py

@@ -31,6 +31,7 @@ from litellm.proxy._types import (
     TeamAddMemberResponse,
     TeamBase,
     TeamInfoResponseObject,
+    TeamListResponseObject,
     TeamMemberAddRequest,
     TeamMemberDeleteRequest,
     TeamMemberUpdateRequest,
@@ -44,6 +45,7 @@ from litellm.proxy.management_helpers.utils import (
     add_new_member,
     management_endpoint_wrapper,
 )
+from litellm.proxy.utils import PrismaClient

 router = APIRouter()
@@ -58,6 +60,27 @@ def _is_user_team_admin(
     return False


+async def get_all_team_memberships(
+    prisma_client: PrismaClient, team_id: List[str], user_id: Optional[str] = None
+) -> List[LiteLLM_TeamMembership]:
+    """Get all team memberships for a given user"""
+    ## GET ALL MEMBERSHIPS ##
+    team_memberships = await prisma_client.db.litellm_teammembership.find_many(
+        where=(
+            {"user_id": user_id, "team_id": {"in": team_id}}
+            if user_id is not None
+            else {"team_id": {"in": team_id}}
+        ),
+        include={"litellm_budget_table": True},
+    )
+
+    returned_tm: List[LiteLLM_TeamMembership] = []
+    for tm in team_memberships:
+        returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
+
+    return returned_tm
+
+
 #### TEAM MANAGEMENT ####

 @router.post(
     "/team/new",
@@ -1077,15 +1100,10 @@ async def team_info(
             key.pop("token", None)

     ## GET ALL MEMBERSHIPS ##
-    team_memberships = await prisma_client.db.litellm_teammembership.find_many(
-        where={"team_id": team_id},
-        include={"litellm_budget_table": True},
+    returned_tm = await get_all_team_memberships(
+        prisma_client, [team_id], user_id=None
     )

-    returned_tm: List[LiteLLM_TeamMembership] = []
-    for tm in team_memberships:
-        returned_tm.append(LiteLLM_TeamMembership(**tm.model_dump()))
-
     if isinstance(team_info, dict):
         _team_info = LiteLLM_TeamTable(**team_info)
     elif isinstance(team_info, BaseModel):
@@ -1188,11 +1206,12 @@ async def unblock_team(
 @management_endpoint_wrapper
 async def list_team(
     http_request: Request,
+    user_id: Optional[str] = fastapi.Query(
+        default=None, description="Only return teams which this 'user_id' belongs to"
+    ),
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):
     """
     [Admin-only] List all available teams

     ```
     curl --location --request GET 'http://0.0.0.0:4000/team/list' \
     --header 'Authorization: Bearer sk-1234'
@@ -1208,11 +1227,12 @@ async def list_team(
     if (
         user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
         and user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY
+        and user_api_key_dict.user_id != user_id
     ):
         raise HTTPException(
             status_code=401,
             detail={
-                "error": "Admin-only endpoint. Your user role={}".format(
+                "error": "Only admin users can query all teams/other teams. Your user role={}".format(
                     user_api_key_dict.user_role
                 )
             },
@@ -1226,4 +1246,37 @@ async def list_team(
     response = await prisma_client.db.litellm_teamtable.find_many()

-    return response
+    filtered_response = []
+    if user_id:
+        for team in response:
+            if team.members_with_roles:
+                for member in team.members_with_roles:
+                    if (
+                        "user_id" in member
+                        and member["user_id"] is not None
+                        and member["user_id"] == user_id
+                    ):
+                        filtered_response.append(team)
+
+    else:
+        filtered_response = response
+
+    _team_ids = [team.team_id for team in filtered_response]
+    returned_tm = await get_all_team_memberships(
+        prisma_client, _team_ids, user_id=user_id
+    )
+
+    returned_responses: List[TeamListResponseObject] = []
+    for team in filtered_response:
+        _team_memberships: List[LiteLLM_TeamMembership] = []
+        for tm in returned_tm:
+            if tm.team_id == team.team_id:
+                _team_memberships.append(tm)
+        returned_responses.append(
+            TeamListResponseObject(
+                **team.model_dump(),
+                team_memberships=_team_memberships,
+            )
+        )
+
+    return returned_responses

litellm/proxy/proxy_server.py

@@ -6029,8 +6029,7 @@ async def end_user_info(
         )

     user_info = await prisma_client.db.litellm_endusertable.find_first(
-        where={"user_id": end_user_id},
-        include={"litellm_budget_table": True}
+        where={"user_id": end_user_id}, include={"litellm_budget_table": True}
     )

     if user_info is None:
@@ -6235,7 +6234,7 @@ async def delete_end_user(
     include_in_schema=False,
     dependencies=[Depends(user_api_key_auth)],
 )
-async def list_team(
+async def list_end_user(
     http_request: Request,
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
 ):

litellm/proxy/utils.py

@@ -1657,6 +1657,7 @@ class PrismaClient:
                     where={
                         "members": {"has": user_id},
                     },
+                    include={"litellm_budget_table": True},
                 )
             elif query_type == "find_all" and team_id_list is not None:
                 response = await self.db.litellm_teamtable.find_many(

litellm/types/utils.py

@@ -1237,6 +1237,7 @@ all_litellm_params = [
     "client_secret",
     "user_continue_message",
     "configurable_clientside_auth_params",
+    "weight",
 ]

litellm/utils.py

@@ -3197,7 +3197,7 @@ def get_optional_params(
         if stream:
             optional_params["stream"] = stream
-            return optional_params
+            #return optional_params
         if max_tokens is not None:
             if "vicuna" in model or "flan" in model:
                 optional_params["max_length"] = max_tokens
@@ -7244,34 +7244,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e

-    def handle_bedrock_stream(self, chunk):
-        return {
-            "text": chunk["text"],
-            "is_finished": chunk["is_finished"],
-            "finish_reason": chunk["finish_reason"],
-        }
-
-    def handle_sagemaker_stream(self, chunk):
-        if "data: [DONE]" in chunk:
-            text = ""
-            is_finished = True
-            finish_reason = "stop"
-            return {
-                "text": text,
-                "is_finished": is_finished,
-                "finish_reason": finish_reason,
-            }
-        elif isinstance(chunk, dict):
-            if chunk["is_finished"] is True:
-                finish_reason = "stop"
-            else:
-                finish_reason = ""
-            return {
-                "text": chunk["text"],
-                "is_finished": chunk["is_finished"],
-                "finish_reason": finish_reason,
-            }
-
     def handle_watsonx_stream(self, chunk):
         try:
             if isinstance(chunk, dict):
@@ -7419,6 +7391,10 @@ class CustomStreamWrapper:
             model_response._hidden_params = hidden_params
         model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
         model_response._hidden_params["created_at"] = time.time()
+        model_response._hidden_params = {
+            **model_response._hidden_params,
+            **self._hidden_params,
+        }

         if (
             len(model_response.choices) > 0

(o1 optional-params test)

@@ -633,7 +633,7 @@ def test_azure_o1_model_params():

 @pytest.mark.parametrize(
     "temperature, expected_error",
-    [(0.2, True), (1, False)],
+    [(0.2, True), (1, False), (0, True)],
 )
 @pytest.mark.parametrize("provider", ["openai", "azure"])
 def test_o1_model_temperature_params(provider, temperature, expected_error):

(custom callback logging test)

@@ -1403,6 +1403,37 @@ def test_logging_standard_payload_failure_call():
         ]["standard_logging_object"]


+@pytest.mark.parametrize("stream", [True, False])
+def test_logging_standard_payload_llm_headers(stream):
+    from litellm.types.utils import StandardLoggingPayload
+
+    # sync completion
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+
+    with patch.object(
+        customHandler, "log_success_event", new=MagicMock()
+    ) as mock_client:
+        resp = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+            stream=stream,
+        )
+
+        if stream:
+            for chunk in resp:
+                continue
+
+        time.sleep(2)
+        mock_client.assert_called_once()
+
+        standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
+            "kwargs"
+        ]["standard_logging_object"]
+
+        print(standard_logging_object["hidden_params"]["additional_headers"])
+
+
 def test_logging_key_masking_gemini():
     customHandler = CompletionCustomHandler()
     litellm.callbacks = [customHandler]

test_key_generate_prisma.py

@@ -240,24 +240,24 @@ def test_generate_and_call_with_valid_key(prisma_client, api_route):
         await litellm.proxy.proxy_server.prisma_client.connect()
         from litellm.proxy.proxy_server import user_api_key_cache

-        request = NewUserRequest(user_role=LitellmUserRoles.INTERNAL_USER)
-        key = await new_user(
-            request,
-            user_api_key_dict=UserAPIKeyAuth(
-                user_role=LitellmUserRoles.PROXY_ADMIN,
-                api_key="sk-1234",
-                user_id="1234",
-            ),
-        )
+        user_api_key_dict = UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-1234",
+            user_id="1234",
+        )
+        request = NewUserRequest(user_role=LitellmUserRoles.INTERNAL_USER)
+        key = await new_user(request, user_api_key_dict=user_api_key_dict)
         print(key)
         user_id = key.user_id

         # check /user/info to verify user_role was set correctly
-        new_user_info = await user_info(user_id=user_id)
+        new_user_info = await user_info(
+            user_id=user_id, user_api_key_dict=user_api_key_dict
+        )
         new_user_info = new_user_info.user_info
         print("new_user_info=", new_user_info)
-        assert new_user_info.user_role == LitellmUserRoles.INTERNAL_USER
-        assert new_user_info.user_id == user_id
+        assert new_user_info["user_role"] == LitellmUserRoles.INTERNAL_USER
+        assert new_user_info["user_id"] == user_id

         generated_key = key.key
         bearer_token = "Bearer " + generated_key

(proxy e2e test: users & team budgets)

@@ -5,6 +5,35 @@ import asyncio
 import aiohttp
 import time, uuid
 from openai import AsyncOpenAI
+from typing import Optional
+
+
+async def get_user_info(session, get_user, call_user, view_all: Optional[bool] = None):
+    """
+    Make sure only models user has access to are returned
+    """
+    if view_all is True:
+        url = "http://0.0.0.0:4000/user/info"
+    else:
+        url = f"http://0.0.0.0:4000/user/info?user_id={get_user}"
+    headers = {
+        "Authorization": f"Bearer {call_user}",
+        "Content-Type": "application/json",
+    }
+
+    async with session.get(url, headers=headers) as response:
+        status = response.status
+        response_text = await response.text()
+        print(response_text)
+        print()
+
+        if status != 200:
+            if call_user != get_user:
+                return status
+            else:
+                print(f"call_user: {call_user}; get_user: {get_user}")
+                raise Exception(f"Request did not return a 200 status code: {status}")
+        return await response.json()


 async def new_user(
@@ -630,3 +659,13 @@ async def test_users_in_team_budget():
         print("got exception, this is expected")
         print(e)
         assert "Budget has been exceeded" in str(e)
+
+    ## Check user info
+    user_info = await get_user_info(session, get_user, call_user="sk-1234")
+    assert (
+        user_info["teams"][0]["team_memberships"][0]["litellm_budget_table"][
+            "max_budget"
+        ]
+        == 0.0000001
+    )

tests/test_users.py

@@ -88,9 +88,15 @@ async def test_user_info():
     key_gen = await new_user(session, 0, user_id=get_user)
     key = key_gen["key"]
     ## as admin ##
-    await get_user_info(session=session, get_user=get_user, call_user="sk-1234")
+    resp = await get_user_info(
+        session=session, get_user=get_user, call_user="sk-1234"
+    )
+    assert isinstance(resp["user_info"], dict)
+    assert len(resp["user_info"]) > 0
     ## as user themself ##
-    await get_user_info(session=session, get_user=get_user, call_user=key)
+    resp = await get_user_info(session=session, get_user=get_user, call_user=key)
+    assert isinstance(resp["user_info"], dict)
+    assert len(resp["user_info"]) > 0
     # as random user #
     key_gen = await new_user(session=session, i=0)
     random_key = key_gen["key"]

ui/litellm-dashboard/src/components/dashboard_default_team.tsx

@@ -1,18 +1,20 @@
 import React, { useState, useEffect } from "react";
 import { Select, SelectItem, Text, Title } from "@tremor/react";
-import { ProxySettings } from "./user_dashboard";
+import { ProxySettings, UserInfo } from "./user_dashboard";

 interface DashboardTeamProps {
   teams: Object[] | null;
   setSelectedTeam: React.Dispatch<React.SetStateAction<any | null>>;
   userRole: string | null;
   proxySettings: ProxySettings | null;
+  userInfo: UserInfo | null;
 }

 type TeamInterface = {
   models: any[];
   team_id: null;
-  team_alias: String
+  team_alias: String;
+  max_budget: number | null;
 }

 const DashboardTeam: React.FC<DashboardTeamProps> = ({
@@ -20,11 +22,14 @@ const DashboardTeam: React.FC<DashboardTeamProps> = ({
   setSelectedTeam,
   userRole,
   proxySettings,
+  userInfo,
 }) => {
+  console.log(`userInfo: ${JSON.stringify(userInfo)}`)
   const defaultTeam: TeamInterface = {
-    models: [],
+    models: userInfo?.models || [],
     team_id: null,
-    team_alias: "Default Team"
+    team_alias: "Default Team",
+    max_budget: userInfo?.max_budget || null,
   }

ui/litellm-dashboard/src/components/user_dashboard.tsx

@@ -23,11 +23,6 @@ if (isLocal != true) {
 console.log("isLocal:", isLocal);
 const proxyBaseUrl = isLocal ? "http://localhost:4000" : null;

-type UserSpendData = {
-  spend: number;
-  max_budget?: number | null;
-};
-
 export interface ProxySettings {
   PROXY_BASE_URL: string | null;
   PROXY_LOGOUT_URL: string | null;
@@ -35,6 +30,13 @@ export interface ProxySettings {
   SSO_ENABLED: boolean;
 }

+export type UserInfo = {
+  models: string[];
+  max_budget?: number | null;
+  spend: number;
+}
+
 function getCookie(name: string) {
   console.log("COOKIES", document.cookie)
   const cookieValue = document.cookie
@@ -74,7 +76,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
   setKeys,
   premiumUser,
 }) => {
-  const [userSpendData, setUserSpendData] = useState<UserSpendData | null>(
+  const [userSpendData, setUserSpendData] = useState<UserInfo | null>(
     null
   );
@@ -186,14 +188,9 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
             response
           )}; team values: ${Object.entries(response.teams)}`
         );
-        // if (userRole == "Admin") {
-        //   const globalSpend = await getTotalSpendCall(accessToken);
-        //   setUserSpendData(globalSpend);
-        //   console.log("globalSpend:", globalSpend);
-        // } else {
-        //   );
-        // }
         setUserSpendData(response["user_info"]);
+        console.log(`userSpendData: ${JSON.stringify(userSpendData)}`)
         setKeys(response["keys"]); // Assuming this is the correct path to your data
         setTeams(response["teams"]);
         const teamsArray = [...response["teams"]];
@@ -352,6 +349,7 @@ const UserDashboard: React.FC<UserDashboardProps> = ({
             setSelectedTeam={setSelectedTeam}
             userRole={userRole}
             proxySettings={proxySettings}
+            userInfo={userSpendData}
           />
         </Col>
       </Grid>

ui/litellm-dashboard/src/components/view_user_spend.tsx

@@ -52,9 +52,37 @@ const ViewUserSpend: React.FC<ViewUserSpendProps> = ({ userID, userRole, accessT
       if (selectedTeam.team_alias === "Default Team") {
         setMaxBudget(userMaxBudget);
       } else {
+        let setMaxBudgetFlag = false;
+        if (selectedTeam.team_memberships) {
+          /**
+           * What 'team_memberships' looks like:
+           *  "team_memberships": [
+           *    {
+           *      "user_id": "2c315de3-e7ce-4269-b73e-b039a06187b1",
+           *      "team_id": "test-team_515e6f42-ded2-4f0d-8919-0a1f43c5a45f",
+           *      "budget_id": "0880769f-716a-4149-ab19-7f7651ad4db5",
+           *      "litellm_budget_table": {
+                     "soft_budget": null,
+                     "max_budget": 20.0,
+                     "max_parallel_requests": null,
+                     "tpm_limit": null,
+                     "rpm_limit": null,
+                     "model_max_budget": null,
+                     "budget_duration": null
+                   }
+           */
+          for (const member of selectedTeam.team_memberships) {
+            if (member.user_id === userID && "max_budget" in member.litellm_budget_table && member.litellm_budget_table.max_budget !== null) {
+              setMaxBudget(member.litellm_budget_table.max_budget);
+              setMaxBudgetFlag = true;
+            }
+          }
+        }
+        if (!setMaxBudgetFlag) {
           setMaxBudget(selectedTeam.max_budget);
+        }
       }
     }
   }, [selectedTeam, userMaxBudget]);

   const [userModels, setUserModels] = useState([]);
   useEffect(() => {

ui/litellm-dashboard/src/components/view_user_team.tsx

@@ -73,6 +73,11 @@ const ViewUserTeam: React.FC<ViewUserTeamProps> = ({ userID, userRole, selectedT
     <>
       <div className="mb-5">
         <p className="text-3xl text-tremor-content-strong dark:text-dark-tremor-content-strong font-semibold">{selectedTeam?.team_alias}</p>
+        {
+          selectedTeam?.team_id && (
+            <p className="text-xs text-gray-400 dark:text-gray-400 font-semibold">Team ID: {selectedTeam?.team_id}</p>
+          )
+        }
       </div>
     </>
   )