Merge pull request #4645 from BerriAI/litellm_add_assistants_delete_endpoint

[Feat-Proxy] Add DELETE /assistants
This commit is contained in:
Ishaan Jaff 2024-07-10 11:45:37 -07:00 committed by GitHub
commit e4dbd5abd4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 325 additions and 3 deletions

View file

@@ -4,11 +4,12 @@ import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Dict, Iterable, List, Literal, Optional, Union
from typing import Any, Coroutine, Dict, Iterable, List, Literal, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.beta.assistant import Assistant
from openai.types.beta.assistant_deleted import AssistantDeleted
import litellm
from litellm import client
@@ -339,6 +340,139 @@ def create_assistants(
return response
async def adelete_assistant(
custom_llm_provider: Literal["openai", "azure"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AssistantDeleted:
loop = asyncio.get_event_loop()
### PASS ARGS TO DELETE ASSISTANT ###
kwargs["async_delete_assistants"] = True
try:
kwargs["client"] = client
# Use a partial function to pass your keyword arguments
func = partial(delete_assistant, custom_llm_provider, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def delete_assistant(
custom_llm_provider: Literal["openai", "azure"],
assistant_id: str,
client: Optional[Any] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
**kwargs,
) -> AssistantDeleted:
optional_params = GenericLiteLLMParams(
api_key=api_key, api_base=api_base, api_version=api_version, **kwargs
)
async_delete_assistants: Optional[bool] = kwargs.pop(
"async_delete_assistants", None
)
if async_delete_assistants is not None and not isinstance(
async_delete_assistants, bool
):
raise ValueError(
"Invalid value passed in for async_delete_assistants. Only bool or None allowed"
)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout to 10 minutes by default
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and not supports_httpx_timeout(custom_llm_provider)
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
elif timeout is not None and not isinstance(timeout, httpx.Timeout):
timeout = float(timeout) # type: ignore
elif timeout is None:
timeout = 600.0
response: Optional[AssistantDeleted] = None
if custom_llm_provider == "openai":
api_base = (
optional_params.api_base
or litellm.api_base
or os.getenv("OPENAI_API_BASE")
or "https://api.openai.com/v1"
)
organization = (
optional_params.organization
or litellm.organization
or os.getenv("OPENAI_ORGANIZATION", None)
or None
)
# set API KEY
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.openai_key
or os.getenv("OPENAI_API_KEY")
)
response = openai_assistants_api.delete_assistant(
api_base=api_base,
api_key=api_key,
timeout=timeout,
max_retries=optional_params.max_retries,
organization=organization,
assistant_id=assistant_id,
client=client,
async_delete_assistants=async_delete_assistants,
)
else:
raise litellm.exceptions.BadRequestError(
message="LiteLLM doesn't support {} for 'delete_assistant'. Only 'openai' is supported.".format(
custom_llm_provider
),
model="n/a",
llm_provider=custom_llm_provider,
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(
method="delete_assistant", url="https://github.com/BerriAI/litellm"
),
),
)
if response is None:
raise litellm.exceptions.InternalServerError(
message="No response returned from 'delete_assistant'",
model="n/a",
llm_provider=custom_llm_provider,
)
return response
### THREADS ###
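
For reference, a minimal usage sketch of the two new helpers (not part of the diff). It assumes OPENAI_API_KEY is set in the environment and uses a placeholder assistant ID:

import asyncio

import litellm

# Sync path: delete_assistant() returns an AssistantDeleted object.
deleted = litellm.delete_assistant(
    custom_llm_provider="openai",
    assistant_id="asst_abc123",  # placeholder ID
)
print(deleted.id, deleted.deleted)

# Async path: adelete_assistant() runs the same call via run_in_executor.
async def main():
    deleted = await litellm.adelete_assistant(
        custom_llm_provider="openai",
        assistant_id="asst_abc123",  # placeholder ID
    )
    print(deleted.deleted)

asyncio.run(main())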

View file

@@ -17,6 +17,7 @@ from typing import (
import httpx
import openai
from openai import AsyncOpenAI, OpenAI
from openai.types.beta.assistant_deleted import AssistantDeleted
from pydantic import BaseModel
from typing_extensions import overload, override
@@ -2440,6 +2441,63 @@ class OpenAIAssistantsAPI(BaseLLM):
response = openai_client.beta.assistants.create(**create_assistant_data)
return response
# Delete Assistant
async def async_delete_assistant(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
assistant_id: str,
) -> AssistantDeleted:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.delete(assistant_id=assistant_id)
return response
def delete_assistant(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
assistant_id: str,
client=None,
async_delete_assistants=None,
):
if async_delete_assistants is True:
return self.async_delete_assistant(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
assistant_id=assistant_id,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = openai_client.beta.assistants.delete(assistant_id=assistant_id)
return response
### MESSAGES ###
async def a_add_message(
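
A direct usage sketch for the handler above (not part of the diff), assuming OpenAIAssistantsAPI is importable from litellm.llms.openai as in this file; the key and assistant ID are placeholders:

import os

from litellm.llms.openai import OpenAIAssistantsAPI

api = OpenAIAssistantsAPI()
# async_delete_assistants=None -> the synchronous branch runs and a sync
# OpenAI client is constructed from the supplied credentials.
deleted = api.delete_assistant(
    api_key=os.getenv("OPENAI_API_KEY"),
    api_base="https://api.openai.com/v1",
    timeout=600.0,
    max_retries=2,
    organization=None,
    assistant_id="asst_abc123",  # placeholder ID
    client=None,
    async_delete_assistants=None,
)
print(deleted.deleted)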

View file

@@ -207,6 +207,8 @@ class LiteLLMRoutes(enum.Enum):
# assistants-related routes
"/assistants",
"/v1/assistants",
"/v1/assistants/{assistant_id}",
"/assistants/{assistant_id}",
"/threads",
"/v1/threads",
"/threads/{thread_id}",

View file

@@ -4059,6 +4059,101 @@ async def create_assistant(
)
@router.delete(
"/v1/assistants/{assistant_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.delete(
"/assistants/{assistant_id:path}",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def delete_assistant(
request: Request,
assistant_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Delete an assistant.
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/deleteAssistant
"""
global proxy_logging_obj
data: Dict = {}
try:
# Include original request and headers in the data
data = await add_litellm_data_to_request(
data=data,
request=request,
general_settings=general_settings,
user_api_key_dict=user_api_key_dict,
version=version,
proxy_config=proxy_config,
)
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for adelete_assistant
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.adelete_assistant(assistant_id=assistant_id, **data)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
litellm_call_id=data.get("litellm_call_id", ""), status="success"
)
)
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.delete_assistant(): Exception occured - {}".format(
str(e)
)
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads",
dependencies=[Depends(user_api_key_auth)],
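
With a proxy running, the new route can be exercised from any HTTP client; a sketch using httpx, where the proxy URL, virtual key, and assistant ID are all placeholders:

import httpx

resp = httpx.delete(
    "http://localhost:4000/v1/assistants/asst_abc123",  # placeholder proxy URL + ID
    headers={"Authorization": "Bearer sk-1234"},  # placeholder virtual key
)
resp.raise_for_status()
# Expected body shape: {"id": "...", "deleted": true, "object": "assistant.deleted"}
print(resp.json())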

View file

@@ -43,6 +43,7 @@ from typing_extensions import overload
import litellm
from litellm._logging import verbose_router_logger
from litellm.assistants.main import AssistantDeleted
from litellm.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -1989,6 +1990,25 @@ class Router:
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def adelete_assistant(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AssistantDeleted:
if custom_llm_provider is None:
if self.assistants_config is not None:
custom_llm_provider = self.assistants_config["custom_llm_provider"]
kwargs.update(self.assistants_config["litellm_params"])
else:
raise Exception(
"'custom_llm_provider' must be set. Either via:\n `Router(assistants_config={'custom_llm_provider': ..})` \nor\n `router.arun_thread(custom_llm_provider=..)`"
)
return await litellm.adelete_assistant(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_assistants(
self,
custom_llm_provider: Optional[Literal["openai", "azure"]] = None,
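
A router-level sketch (not part of the diff), assuming Router accepts an assistants_config shaped the way the lookup above expects, with custom_llm_provider and litellm_params keys:

import asyncio
import os

from litellm import Router

router = Router(
    model_list=[],
    assistants_config={
        "custom_llm_provider": "openai",
        "litellm_params": {"api_key": os.getenv("OPENAI_API_KEY")},
    },
)

async def main():
    # custom_llm_provider is resolved from assistants_config here.
    deleted = await router.adelete_assistant(assistant_id="asst_abc123")  # placeholder ID
    print(deleted.deleted)

asyncio.run(main())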

View file

@@ -62,10 +62,10 @@ async def test_get_assistants(provider, sync_mode):
@pytest.mark.parametrize("provider", ["openai"])
@pytest.mark.parametrize(
"sync_mode",
[False],
[True, False],
)
@pytest.mark.asyncio
async def test_create_assistants(provider, sync_mode):
async def test_create_delete_assistants(provider, sync_mode):
data = {
"custom_llm_provider": provider,
}
@@ -85,6 +85,13 @@ async def test_create_assistants(provider, sync_mode):
== "You are a personal math tutor. When asked a question, write and run Python code to answer the question."
)
assert assistant.id is not None
# delete the created assistant
response = litellm.delete_assistant(
custom_llm_provider="openai", assistant_id=assistant.id
)
print("Response deleting assistant", response)
assert response.id == assistant.id
else:
assistant = await litellm.acreate_assistants(
custom_llm_provider="openai",
@@ -101,6 +108,12 @@
)
assert assistant.id is not None
response = await litellm.adelete_assistant(
custom_llm_provider="openai", assistant_id=assistant.id
)
print("Response deleting assistant", response)
assert response.id == assistant.id
@pytest.mark.parametrize("provider", ["openai", "azure"])
@pytest.mark.parametrize("sync_mode", [True, False])