Merge pull request #2704 from BerriAI/litellm_jwt_auth_improvements_3

fix(handle_jwt.py): enable team-based jwt-auth access
This commit is contained in:
Krish Dholakia 2024-03-26 16:06:56 -07:00 committed by GitHub
commit 0ab708e6f1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 358 additions and 145 deletions

View file

@ -124,7 +124,7 @@ general_settings:
### Allowed LiteLLM scopes ### Allowed LiteLLM scopes
```python ```python
class LiteLLMProxyRoles(LiteLLMBase): class LiteLLM_JWTAuth(LiteLLMBase):
proxy_admin: str = "litellm_proxy_admin" proxy_admin: str = "litellm_proxy_admin"
proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth. proxy_user: str = "litellm_user" # 👈 Not implemented yet, for JWT-Auth.
``` ```

View file

@ -1,4 +1,5 @@
from pydantic import BaseModel, Extra, Field, root_validator, Json, validator from pydantic import BaseModel, Extra, Field, root_validator, Json, validator
from dataclasses import fields
import enum import enum
from typing import Optional, List, Union, Dict, Literal, Any from typing import Optional, List, Union, Dict, Literal, Any
from datetime import datetime from datetime import datetime
@ -37,9 +38,97 @@ class LiteLLMBase(BaseModel):
protected_namespaces = () protected_namespaces = ()
class LiteLLMProxyRoles(LiteLLMBase): class LiteLLMRoutes(enum.Enum):
proxy_admin: str = "litellm_proxy_admin" openai_routes: List = [ # chat completions
proxy_user: str = "litellm_user" "/openai/deployments/{model}/chat/completions",
"/chat/completions",
"/v1/chat/completions",
# completions
"/openai/deployments/{model}/completions",
"/completions",
"/v1/completions",
# embeddings
"/openai/deployments/{model}/embeddings",
"/embeddings",
"/v1/embeddings",
# image generation
"/images/generations",
"/v1/images/generations",
# audio transcription
"/audio/transcriptions",
"/v1/audio/transcriptions",
# moderations
"/moderations",
"/v1/moderations",
# models
"/models",
"/v1/models",
]
info_routes: List = ["/key/info", "/team/info", "/user/info", "/model/info"]
management_routes: List = [ # key
"/key/generate",
"/key/update",
"/key/delete",
"/key/info",
# user
"/user/new",
"/user/update",
"/user/delete",
"/user/info",
# team
"/team/new",
"/team/update",
"/team/delete",
"/team/info",
# model
"/model/new",
"/model/update",
"/model/delete",
"/model/info",
]
class LiteLLM_JWTAuth(LiteLLMBase):
    """
    A class to define the roles and permissions for a LiteLLM Proxy w/ JWT Auth.

    Attributes:
    - admin_jwt_scope: The JWT scope required for proxy admin roles.
    - admin_allowed_routes: list of allowed routes for proxy admin roles.
    - team_jwt_scope: The JWT scope required for proxy team roles.
    - team_id_jwt_field: The field in the JWT token that stores the team ID. Default - `client_id`.
    - team_allowed_routes: list of allowed routes for proxy team roles.
    - end_user_id_jwt_field: Default - `sub`. The field in the JWT token that stores the end-user ID. Turn this off by setting to `None`. Enables end-user cost tracking.
    - public_key_ttl: Default - `600`. Passed as the cache `ttl` when the JWT public keys are cached.

    Each entry of the `*_allowed_routes` lists is one of the route-group names
    `openai_routes`, `info_routes` or `management_routes`.

    See `auth_checks.py` for the specific routes

    Raises:
        ValueError: if a keyword argument is not one of the attributes above.
    """

    admin_jwt_scope: str = "litellm_proxy_admin"
    admin_allowed_routes: List[
        Literal["openai_routes", "info_routes", "management_routes"]
    ] = ["management_routes"]
    team_jwt_scope: str = "litellm_team"
    team_id_jwt_field: str = "client_id"
    team_allowed_routes: List[
        Literal["openai_routes", "info_routes", "management_routes"]
    ] = ["openai_routes", "info_routes"]
    end_user_id_jwt_field: Optional[str] = "sub"
    public_key_ttl: float = 600

    def __init__(self, **kwargs: Any) -> None:
        # get the attribute names for this Pydantic model
        allowed_keys = self.__annotations__.keys()

        # Reject unknown settings up front so a typo in the config fails loudly
        # instead of being silently ignored.
        invalid_keys = set(kwargs.keys()) - allowed_keys

        if invalid_keys:
            # Sort the offending keys: set iteration order is arbitrary, and a
            # stable message is easier to read, log and assert on.
            raise ValueError(
                f"Invalid arguments provided: {', '.join(sorted(invalid_keys))}. Allowed arguments are: {', '.join(allowed_keys)}."
            )

        super().__init__(**kwargs)
class LiteLLMPromptInjectionParams(LiteLLMBase): class LiteLLMPromptInjectionParams(LiteLLMBase):

View file

@ -8,15 +8,23 @@ Run checks for:
2. If user is in budget 2. If user is in budget
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
""" """
from litellm.proxy._types import LiteLLM_UserTable, LiteLLM_EndUserTable from litellm.proxy._types import (
LiteLLM_UserTable,
LiteLLM_EndUserTable,
LiteLLM_JWTAuth,
LiteLLM_TeamTable,
LiteLLMRoutes,
)
from typing import Optional, Literal from typing import Optional, Literal
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
from litellm.caching import DualCache from litellm.caching import DualCache
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value
def common_checks( def common_checks(
request_body: dict, request_body: dict,
user_object: LiteLLM_UserTable, team_object: LiteLLM_TeamTable,
end_user_object: Optional[LiteLLM_EndUserTable], end_user_object: Optional[LiteLLM_EndUserTable],
) -> bool: ) -> bool:
""" """
@ -30,19 +38,20 @@ def common_checks(
# 1. If user can call model # 1. If user can call model
if ( if (
_model is not None _model is not None
and len(user_object.models) > 0 and len(team_object.models) > 0
and _model not in user_object.models and _model not in team_object.models
): ):
raise Exception( raise Exception(
f"User={user_object.user_id} not allowed to call model={_model}. Allowed user models = {user_object.models}" f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
) )
# 2. If user is in budget # 2. If team is in budget
if ( if (
user_object.max_budget is not None team_object.max_budget is not None
and user_object.spend > user_object.max_budget and team_object.spend is not None
and team_object.spend > team_object.max_budget
): ):
raise Exception( raise Exception(
f"User={user_object.user_id} over budget. Spend={user_object.spend}, Budget={user_object.max_budget}" f"Team={team_object.team_id} over budget. Spend={team_object.spend}, Budget={team_object.max_budget}"
) )
# 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget # 3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
if end_user_object is not None and end_user_object.litellm_budget_table is not None: if end_user_object is not None and end_user_object.litellm_budget_table is not None:
@ -54,52 +63,79 @@ def common_checks(
return True return True
def _allowed_routes_check(user_route: str, allowed_routes: list) -> bool:
    """
    Return True if `user_route` is permitted by any entry of `allowed_routes`.

    An entry permits the route when it is either:
    - the name of a `LiteLLMRoutes` group (`openai_routes`, `info_routes`,
      `management_routes`) whose value list contains `user_route`, or
    - exactly equal to `user_route`.
    """
    route_groups = (
        LiteLLMRoutes.openai_routes,
        LiteLLMRoutes.info_routes,
        LiteLLMRoutes.management_routes,
    )

    for entry in allowed_routes:
        # entry names a route group that contains the requested route
        if any(
            entry == group.name and user_route in group.value
            for group in route_groups
        ):
            return True
        # entry is the literal route itself
        if entry == user_route:
            return True

    return False
def allowed_routes_check( def allowed_routes_check(
user_role: Literal["proxy_admin", "app_owner"], user_role: Literal["proxy_admin", "team"],
route: str, user_route: str,
allowed_routes: Optional[list] = None, litellm_proxy_roles: LiteLLM_JWTAuth,
) -> bool: ) -> bool:
""" """
Check if user -> not admin - allowed to access these routes Check if user -> not admin - allowed to access these routes
""" """
openai_routes = [
# chat completions
"/openai/deployments/{model}/chat/completions",
"/chat/completions",
"/v1/chat/completions",
# completions
# embeddings
"/openai/deployments/{model}/embeddings",
"/embeddings",
"/v1/embeddings",
# image generation
"/images/generations",
"/v1/images/generations",
# audio transcription
"/audio/transcriptions",
"/v1/audio/transcriptions",
# moderations
"/moderations",
"/v1/moderations",
# models
"/models",
"/v1/models",
]
info_routes = ["/key/info", "/team/info", "/user/info", "/model/info"]
default_routes = openai_routes + info_routes
if user_role == "proxy_admin": if user_role == "proxy_admin":
return True if litellm_proxy_roles.admin_allowed_routes is None:
elif user_role == "app_owner": is_allowed = _allowed_routes_check(
if allowed_routes is None: user_route=user_route, allowed_routes=["management_routes"]
if route in default_routes: # check default routes )
return True return is_allowed
elif route in allowed_routes: elif litellm_proxy_roles.admin_allowed_routes is not None:
return True is_allowed = _allowed_routes_check(
else: user_route=user_route,
return False allowed_routes=litellm_proxy_roles.admin_allowed_routes,
)
return is_allowed
elif user_role == "team":
if litellm_proxy_roles.team_allowed_routes is None:
"""
By default allow a team to call openai + info routes
"""
is_allowed = _allowed_routes_check(
user_route=user_route, allowed_routes=["openai_routes", "info_routes"]
)
return is_allowed
elif litellm_proxy_roles.team_allowed_routes is not None:
is_allowed = _allowed_routes_check(
user_route=user_route,
allowed_routes=litellm_proxy_roles.team_allowed_routes,
)
return is_allowed
return False return False
def get_actual_routes(allowed_routes: list) -> list:
    """
    Expand route-group names into the concrete routes they stand for.

    Each entry that is the name of a `LiteLLMRoutes` member is replaced by that
    member's list of routes; any other entry is kept unchanged. Order is
    preserved.
    """
    expanded: list = []
    for entry in allowed_routes:
        try:
            group_routes = LiteLLMRoutes[entry].value
        except KeyError:
            # not a known group name -> treat it as a literal route
            expanded.append(entry)
        else:
            expanded.extend(group_routes)
    return expanded
async def get_end_user_object( async def get_end_user_object(
end_user_id: Optional[str], end_user_id: Optional[str],
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
@ -135,3 +171,75 @@ async def get_end_user_object(
return LiteLLM_EndUserTable(**response.dict()) return LiteLLM_EndUserTable(**response.dict())
except Exception as e: # if end-user not in db except Exception as e: # if end-user not in db
return None return None
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
    """
    - Check if user id in proxy User Table
    - if valid, return LiteLLM_UserTable object with defined limits
    - if not, then raise an error

    NOTE: although this is a module-level function, its first parameter is
    `self`; it must be an object exposing `prisma_client` and
    `user_api_key_cache` attributes.

    Raises:
        Exception: if `self.prisma_client` is None, or if the user could not be
            loaded from the db (any lookup failure is reported as
            "User doesn't exist in db").
    """
    if self.prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # `async_get_cache` must be awaited: without the await this is a coroutine
    # object, which matches none of the checks below, so the cache would never
    # be used.
    cached_user_obj = await self.user_api_key_cache.async_get_cache(key=user_id)
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_UserTable(**cached_user_obj)
        elif isinstance(cached_user_obj, LiteLLM_UserTable):
            return cached_user_obj
        # any other cached type falls through to the db lookup

    # else, check db
    try:
        response = await self.prisma_client.db.litellm_usertable.find_unique(
            where={"user_id": user_id}
        )

        if response is None:
            raise Exception

        return LiteLLM_UserTable(**response.dict())
    except Exception:
        raise Exception(
            f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
        )
async def get_team_object(
    team_id: str,
    prisma_client: Optional[PrismaClient],
    user_api_key_cache: DualCache,
) -> LiteLLM_TeamTable:
    """
    - Check if team id in proxy Team Table
    - if valid, return LiteLLM_TeamTable object with defined limits
    - if not, then raise an error

    Raises:
        Exception: if `prisma_client` is None, or if the team could not be
            loaded from the db (any lookup failure is reported as
            "Team doesn't exist in db").
    """
    if prisma_client is None:
        raise Exception(
            "No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
        )

    # check if in cache
    # `async_get_cache` must be awaited: without the await this is a coroutine
    # object, which matches none of the checks below, so the cache would never
    # be used.
    cached_team_obj = await user_api_key_cache.async_get_cache(key=team_id)
    if cached_team_obj is not None:
        if isinstance(cached_team_obj, dict):
            return LiteLLM_TeamTable(**cached_team_obj)
        elif isinstance(cached_team_obj, LiteLLM_TeamTable):
            return cached_team_obj
        # any other cached type falls through to the db lookup

    # else, check db
    try:
        response = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )

        if response is None:
            raise Exception

        return LiteLLM_TeamTable(**response.dict())
    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )

View file

@ -12,7 +12,7 @@ import json
import os import os
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LiteLLMProxyRoles, LiteLLM_UserTable from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLM_UserTable
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
from typing import Optional from typing import Optional
@ -70,68 +70,43 @@ class JWTHandler:
self, self,
prisma_client: Optional[PrismaClient], prisma_client: Optional[PrismaClient],
user_api_key_cache: DualCache, user_api_key_cache: DualCache,
litellm_proxy_roles: LiteLLMProxyRoles, litellm_jwtauth: LiteLLM_JWTAuth,
) -> None: ) -> None:
self.prisma_client = prisma_client self.prisma_client = prisma_client
self.user_api_key_cache = user_api_key_cache self.user_api_key_cache = user_api_key_cache
self.litellm_proxy_roles = litellm_proxy_roles self.litellm_jwtauth = litellm_jwtauth
def is_jwt(self, token: str): def is_jwt(self, token: str):
parts = token.split(".") parts = token.split(".")
return len(parts) == 3 return len(parts) == 3
def is_admin(self, scopes: list) -> bool: def is_admin(self, scopes: list) -> bool:
if self.litellm_proxy_roles.proxy_admin in scopes: if self.litellm_jwtauth.admin_jwt_scope in scopes:
return True return True
return False return False
def get_user_id(self, token: dict, default_value: str) -> str: def is_team(self, scopes: list) -> bool:
if self.litellm_jwtauth.team_jwt_scope in scopes:
return True
return False
def get_end_user_id(self, token: dict, default_value: Optional[str]) -> str:
try: try:
user_id = token["sub"] if self.litellm_jwtauth.end_user_id_jwt_field is not None:
user_id = token[self.litellm_jwtauth.end_user_id_jwt_field]
else:
user_id = None
except KeyError: except KeyError:
user_id = default_value user_id = default_value
return user_id return user_id
def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]: def get_team_id(self, token: dict, default_value: Optional[str]) -> Optional[str]:
try: try:
team_id = token["client_id"] team_id = token[self.litellm_jwtauth.team_id_jwt_field]
except KeyError: except KeyError:
team_id = default_value team_id = default_value
return team_id return team_id
async def get_user_object(self, user_id: str) -> LiteLLM_UserTable:
"""
- Check if user id in proxy User Table
- if valid, return LiteLLM_UserTable object with defined limits
- if not, then raise an error
"""
if self.prisma_client is None:
raise Exception(
"No DB Connected. See - https://docs.litellm.ai/docs/proxy/virtual_keys"
)
# check if in cache
cached_user_obj = self.user_api_key_cache.async_get_cache(key=user_id)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_UserTable(**cached_user_obj)
elif isinstance(cached_user_obj, LiteLLM_UserTable):
return cached_user_obj
# else, check db
try:
response = await self.prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
if response is None:
raise Exception
return LiteLLM_UserTable(**response.dict())
except Exception as e:
raise Exception(
f"User doesn't exist in db. User={user_id}. Create user via `/user/new` call."
)
def get_scopes(self, token: dict) -> list: def get_scopes(self, token: dict) -> list:
try: try:
if isinstance(token["scope"], str): if isinstance(token["scope"], str):
@ -162,7 +137,9 @@ class JWTHandler:
keys = response.json()["keys"] keys = response.json()["keys"]
await self.user_api_key_cache.async_set_cache( await self.user_api_key_cache.async_set_cache(
key="litellm_jwt_auth_keys", value=keys, ttl=600 # cache for 10 mins key="litellm_jwt_auth_keys",
value=keys,
ttl=self.litellm_jwtauth.public_key_ttl, # cache for 10 mins
) )
else: else:
keys = cached_keys keys = cached_keys

View file

@ -113,7 +113,10 @@ from litellm.proxy.hooks.prompt_injection_detection import (
from litellm.proxy.auth.auth_checks import ( from litellm.proxy.auth.auth_checks import (
common_checks, common_checks,
get_end_user_object, get_end_user_object,
get_team_object,
get_user_object,
allowed_routes_check, allowed_routes_check,
get_actual_routes,
) )
try: try:
@ -359,71 +362,99 @@ async def user_api_key_auth(
scopes = jwt_handler.get_scopes(token=valid_token) scopes = jwt_handler.get_scopes(token=valid_token)
# check if admin # check if admin
is_admin = jwt_handler.is_admin(scopes=scopes) is_admin = jwt_handler.is_admin(scopes=scopes)
# get user id # if admin return
user_id = jwt_handler.get_user_id( if is_admin:
token=valid_token, default_value=litellm_proxy_admin_name # check allowed admin routes
is_allowed = allowed_routes_check(
user_role="proxy_admin",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed:
return UserAPIKeyAuth()
else:
allowed_routes = (
jwt_handler.litellm_jwtauth.admin_allowed_routes
)
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Admin not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
)
# check if team in scopes
is_team = jwt_handler.is_team(scopes=scopes)
if is_team == False:
raise Exception(
f"Missing both Admin and Team scopes from token. Either is required. Admin Scope={jwt_handler.litellm_jwtauth.admin_jwt_scope}, Team Scope={jwt_handler.litellm_jwtauth.team_jwt_scope}"
)
# get team id
team_id = jwt_handler.get_team_id(token=valid_token, default_value=None)
if team_id is None:
raise Exception(
f"No team id passed in. Field checked in jwt token - '{jwt_handler.litellm_jwtauth.team_id_jwt_field}'"
)
# check allowed team routes
is_allowed = allowed_routes_check(
user_role="team",
user_route=route,
litellm_proxy_roles=jwt_handler.litellm_jwtauth,
)
if is_allowed == False:
allowed_routes = jwt_handler.litellm_jwtauth.team_allowed_routes
actual_routes = get_actual_routes(allowed_routes=allowed_routes)
raise Exception(
f"Team not allowed to access this route. Route={route}, Allowed Routes={actual_routes}"
) )
end_user_object = None # check if team in db
team_object = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# common checks
# allow request
# get the request body # get the request body
request_data = await _read_request_body(request=request) request_data = await _read_request_body(request=request)
# get user obj from cache/db -> run for admin too. Ensures, admin client id in db.
user_object = await jwt_handler.get_user_object(user_id=user_id) end_user_object = None
if ( end_user_id = jwt_handler.get_end_user_id(
request_data.get("user", None) token=valid_token, default_value=None
and request_data["user"] != user_object.user_id )
): if end_user_id is not None:
# get the end-user object # get the end-user object
end_user_object = await get_end_user_object( end_user_object = await get_end_user_object(
end_user_id=request_data["user"], end_user_id=end_user_id,
prisma_client=prisma_client, prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
) )
# save the end-user object to cache # save the end-user object to cache
await user_api_key_cache.async_set_cache( await user_api_key_cache.async_set_cache(
key=request_data["user"], value=end_user_object key=end_user_id, value=end_user_object
) )
# run through common checks # run through common checks
_ = common_checks( _ = common_checks(
request_body=request_data, request_body=request_data,
user_object=user_object, team_object=team_object,
end_user_object=end_user_object, end_user_object=end_user_object,
) )
# save user object in cache # save user object in cache
await user_api_key_cache.async_set_cache( await user_api_key_cache.async_set_cache(
key=user_object.user_id, value=user_object key=team_object.team_id, value=team_object
) )
# if admin return
if is_admin:
return UserAPIKeyAuth(
api_key=api_key,
user_role="proxy_admin",
user_id=user_id,
)
else:
is_allowed = allowed_routes_check(
user_role="app_owner",
route=route,
allowed_routes=general_settings.get("allowed_routes", None),
)
if is_allowed:
# return UserAPIKeyAuth object # return UserAPIKeyAuth object
return UserAPIKeyAuth( return UserAPIKeyAuth(
api_key=None, api_key=None,
user_id=user_object.user_id, team_id=team_object.team_id,
tpm_limit=user_object.tpm_limit, tpm_limit=team_object.tpm_limit,
rpm_limit=user_object.rpm_limit, rpm_limit=team_object.rpm_limit,
models=user_object.models, models=team_object.models,
user_role="app_owner", user_role="app_owner",
) )
else:
raise HTTPException(
status_code=401,
detail={
"error": f"User={user_object.user_id} not allowed to access this route={route}."
},
)
#### ELSE #### #### ELSE ####
if master_key is None: if master_key is None:
if isinstance(api_key, str): if isinstance(api_key, str):
@ -2698,12 +2729,14 @@ async def startup_event():
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
## JWT AUTH ## ## JWT AUTH ##
if general_settings.get("litellm_jwtauth", None) is not None:
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
else:
litellm_jwtauth = LiteLLM_JWTAuth()
jwt_handler.update_environment( jwt_handler.update_environment(
prisma_client=prisma_client, prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache, user_api_key_cache=user_api_key_cache,
litellm_proxy_roles=LiteLLMProxyRoles( litellm_jwtauth=litellm_jwtauth,
**general_settings.get("litellm_proxy_roles", {})
),
) )
if use_background_health_checks: if use_background_health_checks:

View file

@ -12,7 +12,7 @@ sys.path.insert(
0, os.path.abspath("../..") 0, os.path.abspath("../..")
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
from litellm.proxy._types import LiteLLMProxyRoles from litellm.proxy._types import LiteLLM_JWTAuth
from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.caching import DualCache from litellm.caching import DualCache
from datetime import datetime, timedelta from datetime import datetime, timedelta
@ -28,17 +28,17 @@ public_key = {
def test_load_config_with_custom_role_names(): def test_load_config_with_custom_role_names():
config = { config = {
"general_settings": { "general_settings": {
"litellm_proxy_roles": {"proxy_admin": "litellm-proxy-admin"} "litellm_proxy_roles": {"admin_jwt_scope": "litellm-proxy-admin"}
} }
} }
proxy_roles = LiteLLMProxyRoles( proxy_roles = LiteLLM_JWTAuth(
**config.get("general_settings", {}).get("litellm_proxy_roles", {}) **config.get("general_settings", {}).get("litellm_proxy_roles", {})
) )
print(f"proxy_roles: {proxy_roles}") print(f"proxy_roles: {proxy_roles}")
assert proxy_roles.proxy_admin == "litellm-proxy-admin" assert proxy_roles.admin_jwt_scope == "litellm-proxy-admin"
# test_load_config_with_custom_role_names() # test_load_config_with_custom_role_names()

View file

@ -500,6 +500,9 @@ class ModelResponse(OpenAIObject):
if choices is not None and isinstance(choices, list): if choices is not None and isinstance(choices, list):
new_choices = [] new_choices = []
for choice in choices: for choice in choices:
if isinstance(choice, StreamingChoices):
_new_choice = choice
elif isinstance(choice, dict):
_new_choice = StreamingChoices(**choice) _new_choice = StreamingChoices(**choice)
new_choices.append(_new_choice) new_choices.append(_new_choice)
choices = new_choices choices = new_choices
@ -513,6 +516,9 @@ class ModelResponse(OpenAIObject):
if choices is not None and isinstance(choices, list): if choices is not None and isinstance(choices, list):
new_choices = [] new_choices = []
for choice in choices: for choice in choices:
if isinstance(choice, Choices):
_new_choice = choice
elif isinstance(choice, dict):
_new_choice = Choices(**choice) _new_choice = Choices(**choice)
new_choices.append(_new_choice) new_choices.append(_new_choice)
choices = new_choices choices = new_choices
@ -7231,7 +7237,7 @@ def exception_type(
exception_mapping_worked = True exception_mapping_worked = True
raise APIError( raise APIError(
status_code=original_exception.status_code, status_code=original_exception.status_code,
message=f"AnthropicException - {original_exception.message}", message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.",
llm_provider="anthropic", llm_provider="anthropic",
model=model, model=model,
request=original_exception.request, request=original_exception.request,