Merge pull request #4627 from BerriAI/litellm_fix_thread_auth

[Fix] Authentication on /thread endpoints on Proxy
Ishaan Jaff 2024-07-09 12:19:19 -07:00 committed by GitHub
commit 6bce7e73a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 493 additions and 25 deletions

View file

@@ -1,19 +1,27 @@
# What is this?
## Main file for assistants API logic
import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Dict, Iterable, List, Literal, Optional, Union

import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.beta.assistant import Assistant

import litellm
from litellm import client
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
    exception_type,
    get_llm_provider,
    get_secret,
    supports_httpx_timeout,
)

from ..llms.azure import AzureAssistantsAPI
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
from .utils import get_optional_params_add_message
@@ -178,6 +186,159 @@ def get_assistants(
    return response


async def acreate_assistants(
    custom_llm_provider: Literal["openai", "azure"],
    client: Optional[AsyncOpenAI] = None,
    **kwargs,
) -> Assistant:
    loop = asyncio.get_event_loop()
    ### PASS ARGS TO GET ASSISTANTS ###
    kwargs["async_create_assistants"] = True
    try:
        model = kwargs.pop("model", None)
        kwargs["client"] = client
        # Use a partial function to pass your keyword arguments
        func = partial(create_assistants, custom_llm_provider, model, **kwargs)

        # Add the context to the function
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)

        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
            model=model, custom_llm_provider=custom_llm_provider
        )  # type: ignore

        # Await normally
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response
        return response  # type: ignore
    except Exception as e:
        raise exception_type(
            model=model,
            custom_llm_provider=custom_llm_provider,
            original_exception=e,
            completion_kwargs={},
            extra_kwargs=kwargs,
        )


def create_assistants(
    custom_llm_provider: Literal["openai", "azure"],
    model: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
    instructions: Optional[str] = None,
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_resources: Optional[Dict[str, Any]] = None,
    metadata: Optional[Dict[str, str]] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    response_format: Optional[Union[str, Dict[str, str]]] = None,
    client: Optional[Any] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    **kwargs,
) -> Assistant:
    async_create_assistants: Optional[bool] = kwargs.pop(
        "async_create_assistants", None
    )
    if async_create_assistants is not None and not isinstance(
        async_create_assistants, bool
    ):
        raise ValueError(
            "Invalid value passed in for async_create_assistants. Only bool or None allowed"
        )
    optional_params = GenericLiteLLMParams(
        api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
    )

    ### TIMEOUT LOGIC ###
    timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
    # set timeout for 10 minutes by default

    if (
        timeout is not None
        and isinstance(timeout, httpx.Timeout)
        and supports_httpx_timeout(custom_llm_provider) == False
    ):
        read_timeout = timeout.read or 600
        timeout = read_timeout  # default 10 min timeout
    elif timeout is not None and not isinstance(timeout, httpx.Timeout):
        timeout = float(timeout)  # type: ignore
    elif timeout is None:
        timeout = 600.0

    response: Optional[Assistant] = None
    if custom_llm_provider == "openai":
        api_base = (
            optional_params.api_base  # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            or litellm.api_base
            or os.getenv("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        organization = (
            optional_params.organization
            or litellm.organization
            or os.getenv("OPENAI_ORGANIZATION", None)
            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
        )
        # set API KEY
        api_key = (
            optional_params.api_key
            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
            or litellm.openai_key
            or os.getenv("OPENAI_API_KEY")
        )

        create_assistant_data = {
            "model": model,
            "name": name,
            "description": description,
            "instructions": instructions,
            "tools": tools,
            "tool_resources": tool_resources,
            "metadata": metadata,
            "temperature": temperature,
            "top_p": top_p,
            "response_format": response_format,
        }

        response = openai_assistants_api.create_assistants(
            api_base=api_base,
            api_key=api_key,
            timeout=timeout,
            max_retries=optional_params.max_retries,
            organization=organization,
            create_assistant_data=create_assistant_data,
            client=client,
            async_create_assistants=async_create_assistants,  # type: ignore
        )  # type: ignore
    else:
        raise litellm.exceptions.BadRequestError(
            message="LiteLLM doesn't support {} for 'create_assistants'. Only 'openai' is supported.".format(
                custom_llm_provider
            ),
            model="n/a",
            llm_provider=custom_llm_provider,
            response=httpx.Response(
                status_code=400,
                content="Unsupported provider",
                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
            ),
        )
    if response is None:
        raise litellm.exceptions.InternalServerError(
            message="No response returned from 'create_assistants'",
            model=model,
            llm_provider=custom_llm_provider,
        )
    return response


### THREADS ###

View file

@@ -2383,6 +2383,63 @@ class OpenAIAssistantsAPI(BaseLLM):
        return response

    # Create Assistant
    async def async_create_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        create_assistant_data: dict,
    ) -> Assistant:
        openai_client = self.async_get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = await openai_client.beta.assistants.create(**create_assistant_data)

        return response

    def create_assistants(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        create_assistant_data: dict,
        client=None,
        async_create_assistants=None,
    ):
        if async_create_assistants is not None and async_create_assistants == True:
            return self.async_create_assistants(
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
                client=client,
                create_assistant_data=create_assistant_data,
            )
        openai_client = self.get_openai_client(
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            organization=organization,
            client=client,
        )

        response = openai_client.beta.assistants.create(**create_assistant_data)

        return response

    ### MESSAGES ###

    async def a_add_message(

View file

@@ -175,10 +175,12 @@ class LiteLLMRoutes(enum.Enum):
        "/chat/completions",
        "/v1/chat/completions",
        # completions
        "/engines/{model}/completions",
        "/openai/deployments/{model}/completions",
        "/completions",
        "/v1/completions",
        # embeddings
        "/engines/{model}/embeddings",
        "/openai/deployments/{model}/embeddings",
        "/embeddings",
        "/v1/embeddings",

View file

@@ -24,6 +24,7 @@ from litellm.proxy._types import (
    LitellmUserRoles,
    UserAPIKeyAuth,
)
from litellm.proxy.auth.auth_utils import is_openai_route
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@@ -105,7 +106,7 @@ def common_checks(
        general_settings.get("enforce_user_param", None) is not None
        and general_settings["enforce_user_param"] == True
    ):
        if is_openai_route(route=route) and "user" not in request_body:
            raise Exception(
                f"'user' param not passed in. 'enforce_user_param'={general_settings['enforce_user_param']}"
            )
@@ -121,7 +122,7 @@ def common_checks(
                + CommonProxyErrors.not_premium_user.value
            )

        if is_openai_route(route=route):
            # loop through each enforced param
            # example enforced_params ['user', 'metadata', 'metadata.generation_name']
            for enforced_param in general_settings["enforced_params"]:
@@ -149,7 +150,7 @@ def common_checks(
        and global_proxy_spend is not None
        # only run global budget checks for OpenAI routes
        # Reason - the Admin UI should continue working if the proxy crosses it's global budget
        and is_openai_route(route=route)
        and route != "/v1/models"
        and route != "/models"
    ):

View file

@@ -1,4 +1,7 @@
import re

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *


def route_in_additonal_public_routes(current_route: str):
@@ -41,3 +44,31 @@ def route_in_additonal_public_routes(current_route: str):
    except Exception as e:
        verbose_proxy_logger.error(f"route_in_additonal_public_routes: {str(e)}")
        return False


def is_openai_route(route: str) -> bool:
    """
    Helper to check if the provided route is an OpenAI route

    Returns:
        - True: if route is an OpenAI route
        - False: if route is not an OpenAI route
    """
    if route in LiteLLMRoutes.openai_routes.value:
        return True

    # fuzzy match routes like "/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"
    # Check for routes with placeholders
    for openai_route in LiteLLMRoutes.openai_routes.value:
        # Replace placeholders with regex pattern
        # placeholders are written as "/threads/{thread_id}"
        if "{" in openai_route:
            pattern = re.sub(r"\{[^}]+\}", r"[^/]+", openai_route)
            # Anchor the pattern to match the entire string
            pattern = f"^{pattern}$"
            if re.match(pattern, route):
                return True

    return False
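
The placeholder-to-regex idea can be illustrated with a standalone sketch (not part of the commit; the route list and helper name below are hypothetical stand-ins for LiteLLMRoutes.openai_routes.value and is_openai_route):

import re

OPENAI_ROUTES = [
    "/v1/chat/completions",
    "/v1/threads/{thread_id}",
    "/v1/threads/{thread_id}/messages",
]

def matches_openai_route(route: str) -> bool:
    # Exact match first, then fuzzy-match any template containing "{...}" placeholders
    if route in OPENAI_ROUTES:
        return True
    for template in OPENAI_ROUTES:
        if "{" in template:
            # "/v1/threads/{thread_id}" -> "^/v1/threads/[^/]+$"
            pattern = "^" + re.sub(r"\{[^}]+\}", r"[^/]+", template) + "$"
            if re.match(pattern, route):
                return True
    return False

print(matches_openai_route("/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ"))  # True
print(matches_openai_route("/v1/threads/thread_abc/invalid"))               # False

This is why the proxy tests below expect concrete thread IDs to be treated as OpenAI routes while extra path segments are rejected.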

View file

@@ -56,7 +56,10 @@ from litellm.proxy.auth.auth_checks import (
    get_user_object,
    log_to_opentelemetry,
)
from litellm.proxy.auth.auth_utils import (
    is_openai_route,
    route_in_additonal_public_routes,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
@@ -933,9 +936,9 @@ async def user_api_key_auth(
                _user_role = _get_user_role(user_id_information=user_id_information)

                if not _is_user_proxy_admin(user_id_information):  # if non-admin
                    if is_openai_route(route=route):
                        pass
                    elif is_openai_route(route=request["route"].name):
                        pass
                    elif (
                        route in LiteLLMRoutes.info_routes.value
@@ -988,7 +991,7 @@ async def user_api_key_auth(
                        pass
                    elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
                        if is_openai_route(route=route):
                            raise HTTPException(
                                status_code=status.HTTP_403_FORBIDDEN,
                                detail=f"user not allowed to access this OpenAI routes, role= {_user_role}",

View file

@@ -33,11 +33,11 @@ def _get_metadata_variable_name(request: Request) -> str:
    """
    Helper to return what the "metadata" field should be called in the request data

    For all /thread or /assistant endpoints we need to call this "litellm_metadata"

    For ALL other endpoints we call this "metadata"
    """
    if "thread" in request.url.path or "assistant" in request.url.path:
        return "litellm_metadata"
    else:
        return "metadata"

View file

@@ -35,7 +35,6 @@ general_settings:
  LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY"

litellm_settings:
  callbacks: ["otel"]
  guardrails:
    - prompt_injection:
        callbacks: [lakera_prompt_injection, hide_secrets]
@@ -43,6 +42,9 @@ litellm_settings:
    - hide_secrets:
        callbacks: [hide_secrets]
        default_on: true

assistant_settings:
  custom_llm_provider: openai
  litellm_params:
    api_key: os.environ/OPENAI_API_KEY

View file

@@ -3960,6 +3960,101 @@ async def get_assistants(
            )


@router.post(
    "/v1/assistants",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
@router.post(
    "/assistants",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
async def create_assistant(
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create assistant

    API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
    """
    global proxy_logging_obj
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data = await add_litellm_data_to_request(
            data=data,
            request=request,
            general_settings=general_settings,
            user_api_key_dict=user_api_key_dict,
            version=version,
            proxy_config=proxy_config,
        )

        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
        if llm_router is None:
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.acreate_assistants(**data)

        ### ALERTING ###
        asyncio.create_task(
            proxy_logging_obj.update_request_status(
                litellm_call_id=data.get("litellm_call_id", ""), status="success"
            )
        )

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            )
        )

        return response
    except Exception as e:
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.create_assistant(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )
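
With the assistant_settings shown earlier in the proxy config, the new /v1/assistants endpoint above should be callable with the stock OpenAI SDK pointed at the proxy. A sketch; the base_url and api_key below are placeholders for your own proxy deployment and virtual key:

from openai import OpenAI

# Point the OpenAI client at the LiteLLM proxy instead of api.openai.com
client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")  # placeholder URL and key

assistant = client.beta.assistants.create(
    model="gpt-4-turbo",
    name="Math Tutor",
    instructions="You are a personal math tutor.",
    tools=[{"type": "code_interpreter"}],
)
print(assistant.id)

The request passes through user_api_key_auth like any other OpenAI route, which is the point of the auth fix in this PR.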
@router.post(
    "/v1/threads",
    dependencies=[Depends(user_api_key_auth)],
View file

@@ -1970,6 +1970,25 @@ class Router:
    #### ASSISTANTS API ####

    async def acreate_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Assistant:
        if custom_llm_provider is None:
            if self.assistants_config is not None:
                custom_llm_provider = self.assistants_config["custom_llm_provider"]
                kwargs.update(self.assistants_config["litellm_params"])
            else:
                raise Exception(
                    "'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
                )

        return await litellm.acreate_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )
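
A minimal usage sketch of the new Router method, assuming an assistants_config shaped like the one added to the proxy config above (the empty model_list and the asyncio wrapper are illustrative, not from the commit):

import asyncio
import os

from litellm import Router

router = Router(
    model_list=[],  # assistants calls resolve via assistants_config, not the model_list
    assistants_config={
        "custom_llm_provider": "openai",
        "litellm_params": {"api_key": os.getenv("OPENAI_API_KEY")},
    },
)

async def main():
    assistant = await router.acreate_assistants(
        model="gpt-4-turbo",
        name="Math Tutor",
        instructions="You are a personal math tutor.",
        tools=[{"type": "code_interpreter"}],
    )
    print(assistant.id)

asyncio.run(main())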
    async def aget_assistants(
        self,
        custom_llm_provider: Optional[Literal["openai", "azure"]] = None,

View file

@@ -1,27 +1,34 @@
# What is this?
## Unit Tests for OpenAI Assistants API
import json
import os
import sys
import traceback

from dotenv import load_dotenv

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging

import pytest
from openai.types.beta.assistant import Assistant
from typing_extensions import override

import litellm
from litellm import create_thread, get_thread
from litellm.llms.openai import (
    AssistantEventHandler,
    AsyncAssistantEventHandler,
    AsyncCursorPage,
    MessageData,
    OpenAIAssistantsAPI,
)
from litellm.llms.openai import OpenAIMessage as Message
from litellm.llms.openai import SyncCursorPage, Thread
"""
V0 Scope:
@@ -52,6 +59,49 @@ async def test_get_assistants(provider, sync_mode):
        assert isinstance(assistants, AsyncCursorPage)


@pytest.mark.parametrize("provider", ["openai"])
@pytest.mark.parametrize(
    "sync_mode",
    [False],
)
@pytest.mark.asyncio
async def test_create_assistants(provider, sync_mode):
    data = {
        "custom_llm_provider": provider,
    }
    if sync_mode == True:
        assistant = litellm.create_assistants(
            custom_llm_provider="openai",
            model="gpt-4-turbo",
            instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
            name="Math Tutor",
            tools=[{"type": "code_interpreter"}],
        )
        print("New assistants", assistant)
        assert isinstance(assistant, Assistant)
        assert (
            assistant.instructions
            == "You are a personal math tutor. When asked a question, write and run Python code to answer the question."
        )
        assert assistant.id is not None
    else:
        assistant = await litellm.acreate_assistants(
            custom_llm_provider="openai",
            model="gpt-4-turbo",
            instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.",
            name="Math Tutor",
            tools=[{"type": "code_interpreter"}],
        )
        print("New assistants", assistant)
        assert isinstance(assistant, Assistant)
        assert (
            assistant.instructions
            == "You are a personal math tutor. When asked a question, write and run Python code to answer the question."
        )
        assert assistant.id is not None


@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio

View file

@@ -213,6 +213,10 @@ async def test_new_user_response(prisma_client):
        # model_list
        APIRoute(path="/v1/models", endpoint=model_list),
        APIRoute(path="/models", endpoint=model_list),
        # threads
        APIRoute(
            path="/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ", endpoint=model_list
        ),
    ],
    ids=lambda route: str(dict(route=route.endpoint.__name__, path=route.path)),
)

View file

@@ -19,6 +19,7 @@ import pytest
import litellm
from litellm.proxy._types import LiteLLMRoutes
from litellm.proxy.auth.auth_utils import is_openai_route
from litellm.proxy.proxy_server import router

# Configure logging
@@ -50,3 +51,45 @@ def test_routes_on_litellm_proxy():

    for route in LiteLLMRoutes.openai_routes.value:
        assert route in _all_routes


@pytest.mark.parametrize(
    "route,expected",
    [
        # Test exact matches
        ("/chat/completions", True),
        ("/v1/chat/completions", True),
        ("/embeddings", True),
        ("/v1/models", True),
        ("/utils/token_counter", True),
        # Test routes with placeholders
        ("/engines/gpt-4/chat/completions", True),
        ("/openai/deployments/gpt-3.5-turbo/chat/completions", True),
        ("/threads/thread_49EIN5QF32s4mH20M7GFKdlZ", True),
        ("/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ", True),
        ("/threads/thread_49EIN5QF32s4mH20M7GFKdlZ/messages", True),
        ("/v1/threads/thread_49EIN5QF32s4mH20M7GFKdlZ/runs", True),
        ("/v1/batches123456", True),
        # Test non-OpenAI routes
        ("/some/random/route", False),
        ("/v2/chat/completions", False),
        ("/threads/invalid/format", False),
        ("/v1/non_existent_endpoint", False),
    ],
)
def test_is_openai_route(route: str, expected: bool):
    assert is_openai_route(route) == expected


# Test case for routes that are similar but should return False
@pytest.mark.parametrize(
    "route",
    [
        "/v1/threads/thread_id/invalid",
        "/threads/thread_id/invalid",
        "/v1/batches/123/invalid",
        "/engines/model/invalid/completions",
    ],
)
def test_is_openai_route_similar_but_false(route: str):
    assert is_openai_route(route) == False