feat(proxy_server.py): add assistants api endpoints to proxy server

Krrish Dholakia 2024-05-30 22:44:43 -07:00
parent 3167bee25a
commit e2b34165e7
6 changed files with 1741 additions and 51 deletions
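The changes wire the Assistants API through three layers: sync/async helper functions in the assistants module, passthrough methods on the Router, and FastAPI routes on the proxy. As a rough sketch of the intended client experience (the proxy URL and virtual key below are placeholder assumptions, not values from this commit), the new routes mirror the OpenAI paths so the OpenAI SDK can simply be pointed at the proxy:

```
# Sketch only: assumes a proxy running at http://0.0.0.0:4000 with a virtual key "sk-1234".
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

assistants = client.beta.assistants.list()        # GET /v1/assistants
thread = client.beta.threads.create()             # POST /v1/threads
client.beta.threads.messages.create(              # POST /v1/threads/{thread_id}/messages
    thread_id=thread.id, role="user", content="Hey, how's it going?"
)
run = client.beta.threads.runs.create(            # POST /v1/threads/{thread_id}/runs
    thread_id=thread.id, assistant_id=assistants.data[0].id
)
```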

View file

@@ -1,11 +1,12 @@
# What is this?
## Main file for assistants API logic
from typing import Iterable
-import os
from functools import partial
import os, asyncio, contextvars
import litellm
-from openai import OpenAI
from openai import OpenAI, AsyncOpenAI
from litellm import client
-from litellm.utils import supports_httpx_timeout
from litellm.utils import supports_httpx_timeout, exception_type, get_llm_provider
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
@@ -16,11 +17,49 @@ openai_assistants_api = OpenAIAssistantsAPI()
### ASSISTANTS ###
async def aget_assistants(
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["aget_assistants"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(get_assistants, custom_llm_provider, client, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def get_assistants(
custom_llm_provider: Literal["openai"],
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[Assistant]:
aget_assistants = kwargs.pop("aget_assistants", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -67,6 +106,7 @@ def get_assistants(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
aget_assistants=aget_assistants,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -87,6 +127,39 @@ def get_assistants(
### THREADS ###
async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Thread:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["acreate_thread"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(create_thread, custom_llm_provider, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def create_thread(
custom_llm_provider: Literal["openai"],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
@@ -117,6 +190,7 @@ def create_thread(
)
```
"""
acreate_thread = kwargs.get("acreate_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -165,6 +239,7 @@ def create_thread(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
acreate_thread=acreate_thread,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -182,6 +257,44 @@ def create_thread(
return response
async def aget_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["aget_thread"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(get_thread, custom_llm_provider, thread_id, client, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def get_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
@@ -189,6 +302,7 @@ def get_thread(
**kwargs,
) -> Thread:
"""Get the thread object, given a thread_id"""
aget_thread = kwargs.pop("aget_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -236,6 +350,7 @@ def get_thread(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
aget_thread=aget_thread,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -256,6 +371,59 @@ def get_thread(
### MESSAGES ###
async def a_add_message(
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["a_add_message"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(
add_message,
custom_llm_provider,
thread_id,
role,
content,
attachments,
metadata,
client,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def add_message(
custom_llm_provider: Literal["openai"],
thread_id: str,
@@ -267,6 +435,7 @@ def add_message(
**kwargs,
) -> OpenAIMessage:
### COMMON OBJECTS ###
a_add_message = kwargs.pop("a_add_message", None)
message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
@@ -318,6 +487,7 @@ def add_message(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
a_add_message=a_add_message,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -336,12 +506,58 @@ def add_message(
return response
async def aget_messages(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["aget_messages"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(
get_messages,
custom_llm_provider,
thread_id,
client,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def get_messages(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[OpenAIMessage]:
aget_messages = kwargs.pop("aget_messages", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -389,6 +605,7 @@ def get_messages(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
aget_messages=aget_messages,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -408,6 +625,63 @@ def get_messages(
### RUNS ###
async def arun_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Run:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["arun_thread"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(
run_thread,
custom_llm_provider,
thread_id,
assistant_id,
additional_instructions,
instructions,
metadata,
model,
stream,
tools,
client,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
# Call the synchronous function using run_in_executor
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def run_thread(
@@ -424,6 +698,7 @@ def run_thread(
**kwargs,
) -> Run:
"""Run a given thread + assistant."""
arun_thread = kwargs.pop("arun_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -478,6 +753,7 @@ def run_thread(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
arun_thread=arun_thread,
)
else:
raise litellm.exceptions.BadRequestError(
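Each sync entrypoint in this file now has an async twin (aget_assistants, acreate_thread, aget_thread, a_add_message, aget_messages, arun_thread) that sets a flag in kwargs and dispatches through run_in_executor. A minimal end-to-end sketch of that async surface, mirroring the test file further down; it assumes OPENAI_API_KEY is set and that the account already has at least one assistant:

```
# Sketch only: assumes OPENAI_API_KEY is exported and an assistant already exists.
import asyncio
import litellm

async def main():
    assistants = await litellm.aget_assistants(custom_llm_provider="openai")
    assistant_id = assistants.data[0].id

    thread = await litellm.acreate_thread(
        custom_llm_provider="openai",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],  # type: ignore
    )
    await litellm.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="What's the weather like today?",
    )
    run = await litellm.arun_thread(
        custom_llm_provider="openai", thread_id=thread.id, assistant_id=assistant_id
    )
    print(run.status)

asyncio.run(main())
```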

View file

@@ -6,7 +6,7 @@ from typing import (
Literal,
Iterable,
)
-from typing_extensions import override
from typing_extensions import override, overload
from pydantic import BaseModel
import types, time, json, traceback
import httpx
@@ -1846,8 +1846,71 @@ class OpenAIAssistantsAPI(BaseLLM):
return openai_client
def async_get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = AsyncOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
async def async_get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.list()
return response
# fmt: off
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
...
@overload
def get_assistants(
self,
api_key: Optional[str],
@@ -1856,7 +1919,31 @@ class OpenAIAssistantsAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
...
# fmt: on
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1872,6 +1959,41 @@ class OpenAIAssistantsAPI(BaseLLM):
### MESSAGES ###
async def a_add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
# fmt: off
@overload
def add_message(
self,
thread_id: str,
@@ -1881,9 +2003,51 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
-client: Optional[OpenAI] = None,
-) -> OpenAIMessage:
client: Optional[AsyncOpenAI],
a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
...
@overload
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
...
# fmt: on
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1905,6 +2069,32 @@ class OpenAIAssistantsAPI(BaseLLM):
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
async def async_get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
# fmt: off
@overload
def get_messages(
self,
thread_id: str,
@@ -1913,8 +2103,48 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
-client: Optional[OpenAI] = None,
client: Optional[AsyncOpenAI],
aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
...
# fmt: on
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1930,6 +2160,39 @@ class OpenAIAssistantsAPI(BaseLLM):
### THREADS ###
async def async_create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
self,
metadata: Optional[dict],
@@ -1938,9 +2201,41 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
-client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncOpenAI],
acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[OpenAI],
acreate_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
):
""" """
Here's an example: Here's an example:
``` ```
@ -1951,6 +2246,17 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message]) openai_api.create_thread(messages=[message])
``` ```
""" """
if acreate_thread is not None and acreate_thread == True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
messages=messages,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1970,6 +2276,46 @@ class OpenAIAssistantsAPI(BaseLLM):
return Thread(**message_thread.dict())
async def async_get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# fmt: off
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def get_thread(
self,
thread_id: str,
@@ -1979,7 +2325,33 @@ class OpenAIAssistantsAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1998,6 +2370,68 @@ class OpenAIAssistantsAPI(BaseLLM):
### RUNS ###
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Run:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
# fmt: off
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
arun_thread: Literal[True],
) -> Coroutine[None, None, Run]:
...
@overload
def run_thread(
self,
thread_id: str,
@@ -2014,7 +2448,47 @@ class OpenAIAssistantsAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
arun_thread: Optional[Literal[False]],
) -> Run:
...
# fmt: on
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
arun_thread=None,
):
if arun_thread is not None and arun_thread == True:
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
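The sync methods above each grow an async sibling plus a pair of @overload stubs (wrapped in # fmt: off / # fmt: on so the formatter leaves them alone): when the aget_*/a_*/arun_* flag is truthy the method returns the coroutine from the async variant for the caller to await, otherwise it runs synchronously, and the overloads keep the return type precise for type checkers. A stripped-down illustration of that dispatch pattern, using a hypothetical class rather than litellm code:

```
# Minimal illustration of the flag + @overload dispatch pattern (hypothetical class).
from typing import Coroutine, Literal, Optional, Union, overload

class Widget:
    async def _aget(self) -> str:
        return "async result"

    def _get(self) -> str:
        return "sync result"

    # fmt: off
    @overload
    def get(self, aget: Literal[True]) -> Coroutine[None, None, str]: ...
    @overload
    def get(self, aget: Optional[Literal[False]] = None) -> str: ...
    # fmt: on

    def get(self, aget: Optional[bool] = None) -> Union[str, Coroutine[None, None, str]]:
        if aget is True:
            # Hand the un-awaited coroutine back to the caller,
            # as get_assistants() does when aget_assistants=True.
            return self._aget()
        return self._get()
```

A sync caller uses Widget().get(); an async caller awaits Widget().get(aget=True).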

View file

@@ -5078,6 +5078,765 @@ async def audio_transcriptions(
)
######################################################################
# /v1/assistant Endpoints
######################################################################
@router.get(
"/v1/assistants",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
"/assistants",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def get_assistants(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body) if body else {}
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_assistants(
custom_llm_provider="openai", client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.post(
"/threads",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def create_threads(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "litellm_metadata" not in data:
data["litellm_metadata"] = {}
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["litellm_metadata"]["headers"] = _headers
data["litellm_metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["litellm_metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["litellm_metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["litellm_metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.acreate_thread(
custom_llm_provider="openai", client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/v1/threads/{thread_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
"/threads/{thread_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def get_thread(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.post(
"/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def add_messages(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "litellm_metadata" not in data:
data["litellm_metadata"] = {}
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["litellm_metadata"]["headers"] = _headers
data["litellm_metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["litellm_metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["litellm_metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["litellm_metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.a_add_message(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/v1/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
"/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def get_messages(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_messages(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads/{thread_id}/runs",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.post(
"/threads/{thread_id}/runs",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def run_thread(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "litellm_metadata" not in data:
data["litellm_metadata"] = {}
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["litellm_metadata"]["headers"] = _headers
data["litellm_metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["litellm_metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["litellm_metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["litellm_metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
######################################################################
# /v1/batches Endpoints
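The routes above only accept custom_llm_provider=="openai" for now and forward to the Router's assistants methods. A hedged sketch of exercising them over plain HTTP with httpx; the proxy address and virtual key are assumptions rather than values from this commit:

```
# Sketch only: assumes a proxy at http://0.0.0.0:4000 and a virtual key "sk-1234".
import httpx

base = "http://0.0.0.0:4000"
headers = {"Authorization": "Bearer sk-1234"}

# GET /v1/assistants (the handler reads the request body, so ship an empty JSON object)
assistants = httpx.request("GET", f"{base}/v1/assistants", headers=headers, json={}).json()

# POST /v1/threads
thread = httpx.post(f"{base}/v1/threads", headers=headers, json={}).json()

# POST /v1/threads/{thread_id}/messages
message = httpx.post(
    f"{base}/v1/threads/{thread['id']}/messages",
    headers=headers,
    json={"role": "user", "content": "Hey, how's it going?"},
).json()
```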

View file

@@ -53,6 +53,16 @@ from litellm.types.router import (
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.types.llms.openai import (
AsyncCursorPage,
Assistant,
Thread,
Attachment,
OpenAIMessage,
Run,
AssistantToolParam,
)
from typing import Iterable
class Router:
@@ -1646,6 +1656,108 @@ class Router:
self.fail_calls[model_name] += 1
raise e
#### ASSISTANTS API ####
async def aget_assistants(
self,
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def a_add_message(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
role=role,
content=content,
attachments=attachments,
metadata=metadata,
client=client,
**kwargs,
)
async def aget_messages(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def arun_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Run:
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
client=client,
**kwargs,
)
#### [END] ASSISTANTS API ####
async def async_function_with_fallbacks(self, *args, **kwargs):
"""
Try calling the function_with_retries
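The Router methods added here are thin passthroughs to the module-level litellm helpers, which is what lets the proxy serve assistants traffic through its existing llm_router. A minimal sketch, assuming OPENAI_API_KEY is set; the gpt-4o model_list entry is just an illustrative placeholder:

```
# Sketch only: a minimal Router; assistants calls pass straight through to litellm.* helpers.
import asyncio
from litellm import Router

router = Router(
    model_list=[{"model_name": "gpt-4o", "litellm_params": {"model": "openai/gpt-4o"}}]
)

async def main():
    assistants = await router.aget_assistants(custom_llm_provider="openai")
    thread = await router.acreate_thread(custom_llm_provider="openai")
    print(len(assistants.data), thread.id)

asyncio.run(main())
```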

View file

@@ -16,6 +16,7 @@ from litellm.llms.openai import (
MessageData,
Thread,
OpenAIMessage as Message,
AsyncCursorPage,
)
"""
@@ -26,26 +27,55 @@ V0 Scope:
"""
-def test_create_thread_litellm() -> Thread:
@pytest.mark.asyncio
async def test_async_get_assistants():
assistants = await litellm.aget_assistants(custom_llm_provider="openai")
assert isinstance(assistants, AsyncCursorPage)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_create_thread_litellm(sync_mode) -> Thread:
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
-new_thread = create_thread(
-custom_llm_provider="openai",
-messages=[message], # type: ignore
-)
if sync_mode:
new_thread = create_thread(
custom_llm_provider="openai",
messages=[message], # type: ignore
)
else:
new_thread = await litellm.acreate_thread(
custom_llm_provider="openai",
messages=[message], # type: ignore
)
assert isinstance(
new_thread, Thread
), f"type of thread={type(new_thread)}. Expected Thread-type"
return new_thread
-def test_get_thread_litellm():
-new_thread = test_create_thread_litellm()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_get_thread_litellm(sync_mode):
new_thread = test_create_thread_litellm(sync_mode)
-received_thread = get_thread(
-custom_llm_provider="openai",
-thread_id=new_thread.id,
-)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
if sync_mode:
received_thread = get_thread(
custom_llm_provider="openai",
thread_id=_new_thread.id,
)
else:
received_thread = await litellm.aget_thread(
custom_llm_provider="openai",
thread_id=_new_thread.id,
)
assert isinstance(
received_thread, Thread
@@ -53,50 +83,90 @@ def test_get_thread_litellm():
return new_thread
-def test_add_message_litellm():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_add_message_litellm(sync_mode):
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
-new_thread = test_create_thread_litellm()
new_thread = test_create_thread_litellm(sync_mode)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
-added_message = litellm.add_message(
-thread_id=new_thread.id, custom_llm_provider="openai", **message
-)
if sync_mode:
added_message = litellm.add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
else:
added_message = await litellm.a_add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
print(f"added message: {added_message}")
assert isinstance(added_message, Message)
-def test_run_thread_litellm():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_run_thread_litellm(sync_mode):
"""
- Get Assistants
- Create thread
- Create run w/ Assistants + Thread
"""
-assistants = litellm.get_assistants(custom_llm_provider="openai")
if sync_mode:
assistants = litellm.get_assistants(custom_llm_provider="openai")
else:
assistants = await litellm.aget_assistants(custom_llm_provider="openai")
## get the first assistant ###
assistant_id = assistants.data[0].id
-new_thread = test_create_thread_litellm()
-thread_id = new_thread.id
new_thread = test_create_thread_litellm(sync_mode=sync_mode)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
thread_id = _new_thread.id
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
-added_message = litellm.add_message(
-thread_id=new_thread.id, custom_llm_provider="openai", **message
-)
-run = litellm.run_thread(
-custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
-)
-if run.status == "completed":
-messages = litellm.get_messages(
-thread_id=new_thread.id, custom_llm_provider="openai"
-)
-assert isinstance(messages.data[0], Message)
-else:
-pytest.fail("An unexpected error occurred when running the thread")
if sync_mode:
added_message = litellm.add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
run = litellm.run_thread(
custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
)
if run.status == "completed":
messages = litellm.get_messages(
thread_id=_new_thread.id, custom_llm_provider="openai"
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
else:
added_message = await litellm.a_add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
run = await litellm.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider="openai"
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")

View file

@@ -17,11 +17,10 @@ from openai.types.beta.thread_create_params import (
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
-from openai.pagination import SyncCursorPage
from openai.pagination import SyncCursorPage, AsyncCursorPage
from os import PathLike
from openai.types import FileObject, Batch
from openai._legacy_response import HttpxBinaryResponseContent
from typing import TypedDict, List, Optional, Tuple, Mapping, IO
FileContent = Union[IO[bytes], bytes, PathLike]