mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge pull request #3663 from msabramo/msabramo/allow-non-admins-to-use-openai-routes
Allow non-admins to use `/engines/{model}/chat/completions`
This commit is contained in:
commit
ea976d8c30
3 changed files with 58 additions and 3 deletions
|
@ -52,8 +52,18 @@ class LiteLLM_UpperboundKeyGenerateParams(LiteLLMBase):
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMRoutes(enum.Enum):
|
class LiteLLMRoutes(enum.Enum):
|
||||||
|
openai_route_names: List = [
|
||||||
|
"chat_completion",
|
||||||
|
"completion",
|
||||||
|
"embeddings",
|
||||||
|
"image_generation",
|
||||||
|
"audio_transcriptions",
|
||||||
|
"moderations",
|
||||||
|
"model_list", # OpenAI /v1/models route
|
||||||
|
]
|
||||||
openai_routes: List = [
|
openai_routes: List = [
|
||||||
# chat completions
|
# chat completions
|
||||||
|
"/engines/{model}/chat/completions",
|
||||||
"/openai/deployments/{model}/chat/completions",
|
"/openai/deployments/{model}/chat/completions",
|
||||||
"/chat/completions",
|
"/chat/completions",
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
|
@ -1076,6 +1076,8 @@ async def user_api_key_auth(
|
||||||
if not _is_user_proxy_admin(user_id_information): # if non-admin
|
if not _is_user_proxy_admin(user_id_information): # if non-admin
|
||||||
if route in LiteLLMRoutes.openai_routes.value:
|
if route in LiteLLMRoutes.openai_routes.value:
|
||||||
pass
|
pass
|
||||||
|
elif request['route'].name in LiteLLMRoutes.openai_route_names.value:
|
||||||
|
pass
|
||||||
elif (
|
elif (
|
||||||
route in LiteLLMRoutes.info_routes.value
|
route in LiteLLMRoutes.info_routes.value
|
||||||
): # check if user allowed to call an info route
|
): # check if user allowed to call an info route
|
||||||
|
|
|
@ -23,6 +23,7 @@ import sys, os
|
||||||
import traceback
|
import traceback
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from fastapi.routing import APIRoute
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
@ -51,6 +52,13 @@ from litellm.proxy.proxy_server import (
|
||||||
user_info,
|
user_info,
|
||||||
info_key_fn,
|
info_key_fn,
|
||||||
new_team,
|
new_team,
|
||||||
|
chat_completion,
|
||||||
|
completion,
|
||||||
|
embeddings,
|
||||||
|
image_generation,
|
||||||
|
audio_transcriptions,
|
||||||
|
moderations,
|
||||||
|
model_list,
|
||||||
)
|
)
|
||||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
@ -146,7 +154,38 @@ async def test_new_user_response(prisma_client):
|
||||||
pytest.fail(f"Got exception {e}")
|
pytest.fail(f"Got exception {e}")
|
||||||
|
|
||||||
|
|
||||||
def test_generate_and_call_with_valid_key(prisma_client):
|
@pytest.mark.parametrize(
|
||||||
|
"api_route", [
|
||||||
|
# chat_completion
|
||||||
|
APIRoute(path="/engines/{model}/chat/completions", endpoint=chat_completion),
|
||||||
|
APIRoute(path="/openai/deployments/{model}/chat/completions", endpoint=chat_completion),
|
||||||
|
APIRoute(path="/chat/completions", endpoint=chat_completion),
|
||||||
|
APIRoute(path="/v1/chat/completions", endpoint=chat_completion),
|
||||||
|
# completion
|
||||||
|
APIRoute(path="/completions", endpoint=completion),
|
||||||
|
APIRoute(path="/v1/completions", endpoint=completion),
|
||||||
|
APIRoute(path="/engines/{model}/completions", endpoint=completion),
|
||||||
|
APIRoute(path="/openai/deployments/{model}/completions", endpoint=completion),
|
||||||
|
# embeddings
|
||||||
|
APIRoute(path="/v1/embeddings", endpoint=embeddings),
|
||||||
|
APIRoute(path="/embeddings", endpoint=embeddings),
|
||||||
|
APIRoute(path="/openai/deployments/{model}/embeddings", endpoint=embeddings),
|
||||||
|
# image generation
|
||||||
|
APIRoute(path="/v1/images/generations", endpoint=image_generation),
|
||||||
|
APIRoute(path="/images/generations", endpoint=image_generation),
|
||||||
|
# audio transcriptions
|
||||||
|
APIRoute(path="/v1/audio/transcriptions", endpoint=audio_transcriptions),
|
||||||
|
APIRoute(path="/audio/transcriptions", endpoint=audio_transcriptions),
|
||||||
|
# moderations
|
||||||
|
APIRoute(path="/v1/moderations", endpoint=moderations),
|
||||||
|
APIRoute(path="/moderations", endpoint=moderations),
|
||||||
|
# model_list
|
||||||
|
APIRoute(path= "/v1/models", endpoint=model_list),
|
||||||
|
APIRoute(path= "/models", endpoint=model_list),
|
||||||
|
],
|
||||||
|
ids=lambda route: str(dict(route=route.endpoint.__name__, path=route.path)),
|
||||||
|
)
|
||||||
|
def test_generate_and_call_with_valid_key(prisma_client, api_route):
|
||||||
# 1. Generate a Key, and use it to make a call
|
# 1. Generate a Key, and use it to make a call
|
||||||
|
|
||||||
print("prisma client=", prisma_client)
|
print("prisma client=", prisma_client)
|
||||||
|
@ -181,8 +220,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
|
||||||
)
|
)
|
||||||
print("token from prisma", value_from_prisma)
|
print("token from prisma", value_from_prisma)
|
||||||
|
|
||||||
request = Request(scope={"type": "http"})
|
request = Request({
|
||||||
request._url = URL(url="/chat/completions")
|
"type": "http",
|
||||||
|
"route": api_route,
|
||||||
|
"path": api_route.path,
|
||||||
|
"headers": [("Authorization", bearer_token)]
|
||||||
|
})
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue