feat(proxy_server.py): add assistants api endpoints to proxy server

Krrish Dholakia 2024-05-30 22:44:43 -07:00
parent 3167bee25a
commit e2b34165e7
6 changed files with 1741 additions and 51 deletions
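
The changes below wire the OpenAI Assistants API through both the litellm SDK and the proxy. As a rough sketch of the intended client-side flow once the proxy is running, any OpenAI-compatible client can be pointed at it; the base URL and virtual key below are placeholders, not values taken from this commit.

```
# Sketch of intended client-side usage against the new proxy routes.
# "http://0.0.0.0:4000" and "sk-1234" are placeholders for a running
# LiteLLM proxy and a valid virtual key; neither comes from this commit.
from openai import OpenAI

client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

assistants = client.beta.assistants.list()             # GET  /v1/assistants
thread = client.beta.threads.create()                  # POST /v1/threads
client.beta.threads.messages.create(                   # POST /v1/threads/{thread_id}/messages
    thread_id=thread.id, role="user", content="Hey, how's it going?"
)
run = client.beta.threads.runs.create(                 # POST /v1/threads/{thread_id}/runs
    thread_id=thread.id, assistant_id=assistants.data[0].id
)
print(run.status)
```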


@@ -1,11 +1,12 @@
# What is this?
## Main file for assistants API logic
from typing import Iterable
import os
from functools import partial
import os, asyncio, contextvars
import litellm
from openai import OpenAI
from openai import OpenAI, AsyncOpenAI
from litellm import client
from litellm.utils import supports_httpx_timeout
from litellm.utils import supports_httpx_timeout, exception_type, get_llm_provider
from ..llms.openai import OpenAIAssistantsAPI
from ..types.llms.openai import *
from ..types.router import *
@@ -16,11 +17,49 @@ openai_assistants_api = OpenAIAssistantsAPI()
### ASSISTANTS ###
async def aget_assistants(
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET ASSISTANTS ###
kwargs["aget_assistants"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(get_assistants, custom_llm_provider, client, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def get_assistants(
custom_llm_provider: Literal["openai"],
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[Assistant]:
aget_assistants = kwargs.pop("aget_assistants", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -67,6 +106,7 @@ def get_assistants(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
aget_assistants=aget_assistants,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -87,6 +127,39 @@ def get_assistants(
### THREADS ###
async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Thread:
loop = asyncio.get_event_loop()
### PASS ARGS TO CREATE THREAD ###
kwargs["acreate_thread"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(create_thread, custom_llm_provider, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def create_thread(
custom_llm_provider: Literal["openai"],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
@@ -117,6 +190,7 @@ def create_thread(
)
```
"""
acreate_thread = kwargs.get("acreate_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -165,6 +239,7 @@ def create_thread(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
acreate_thread=acreate_thread,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -182,6 +257,44 @@ def create_thread(
return response
async def aget_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET THREAD ###
kwargs["aget_thread"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(get_thread, custom_llm_provider, thread_id, client, **kwargs)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def get_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
@@ -189,6 +302,7 @@ def get_thread(
**kwargs,
) -> Thread:
"""Get the thread object, given a thread_id"""
aget_thread = kwargs.pop("aget_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -236,6 +350,7 @@ def get_thread(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
aget_thread=aget_thread,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -256,6 +371,59 @@ def get_thread(
### MESSAGES ###
async def a_add_message(
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
loop = asyncio.get_event_loop()
### PASS ARGS TO ADD MESSAGE ###
kwargs["a_add_message"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(
add_message,
custom_llm_provider,
thread_id,
role,
content,
attachments,
metadata,
client,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
# Sync call already ran in the executor; use its result directly
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def add_message(
custom_llm_provider: Literal["openai"],
thread_id: str,
@@ -267,6 +435,7 @@ def add_message(
**kwargs,
) -> OpenAIMessage:
### COMMON OBJECTS ###
a_add_message = kwargs.pop("a_add_message", None)
message_data = MessageData(
role=role, content=content, attachments=attachments, metadata=metadata
)
@@ -318,6 +487,7 @@ def add_message(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
a_add_message=a_add_message,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -336,12 +506,58 @@ def add_message(
return response
async def aget_messages(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
loop = asyncio.get_event_loop()
### PASS ARGS TO GET MESSAGES ###
kwargs["aget_messages"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(
get_messages,
custom_llm_provider,
thread_id,
client,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
# Sync call already ran in the executor; use its result directly
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def get_messages(
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[OpenAI] = None,
**kwargs,
) -> SyncCursorPage[OpenAIMessage]:
aget_messages = kwargs.pop("aget_messages", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -389,6 +605,7 @@ def get_messages(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
aget_messages=aget_messages,
)
else:
raise litellm.exceptions.BadRequestError(
@@ -408,6 +625,63 @@ def get_messages(
### RUNS ###
async def arun_thread(
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Run:
loop = asyncio.get_event_loop()
### PASS ARGS TO RUN THREAD ###
kwargs["arun_thread"] = True
try:
# Use a partial function to pass your keyword arguments
func = partial(
run_thread,
custom_llm_provider,
thread_id,
assistant_id,
additional_instructions,
instructions,
metadata,
model,
stream,
tools,
client,
**kwargs,
)
# Add the context to the function
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider( # type: ignore
model="", custom_llm_provider=custom_llm_provider
) # type: ignore
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
# Sync call already ran in the executor; use its result directly
response = init_response
return response # type: ignore
except Exception as e:
raise exception_type(
model="",
custom_llm_provider=custom_llm_provider,
original_exception=e,
completion_kwargs={},
extra_kwargs=kwargs,
)
def run_thread(
@@ -424,6 +698,7 @@ def run_thread(
**kwargs,
) -> Run:
"""Run a given thread + assistant."""
arun_thread = kwargs.pop("arun_thread", None)
optional_params = GenericLiteLLMParams(**kwargs)
### TIMEOUT LOGIC ###
@@ -478,6 +753,7 @@ def run_thread(
max_retries=optional_params.max_retries,
organization=organization,
client=client,
arun_thread=arun_thread,
)
else:
raise litellm.exceptions.BadRequestError(
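
The async wrappers above all follow the same recipe: flag the call (e.g. `kwargs["aget_assistants"] = True`), run the sync function in a thread-pool executor with the caller's contextvars attached, and await the coroutine the sync path hands back. A minimal sketch of driving these module-level helpers directly, assuming `OPENAI_API_KEY` is set and at least one assistant already exists:

```
# Sketch: exercise the new module-level async helpers end to end.
# Assumes OPENAI_API_KEY is exported and the account has an assistant.
import asyncio
import litellm

async def main() -> None:
    assistants = await litellm.aget_assistants(custom_llm_provider="openai")
    assistant_id = assistants.data[0].id

    thread = await litellm.acreate_thread(
        custom_llm_provider="openai",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    await litellm.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="Please summarise our chat so far.",
    )
    run = await litellm.arun_thread(
        custom_llm_provider="openai",
        thread_id=thread.id,
        assistant_id=assistant_id,
    )
    print(run.status)

asyncio.run(main())
```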


@@ -6,7 +6,7 @@ from typing import (
Literal,
Iterable,
)
from typing_extensions import override
from typing_extensions import override, overload
from pydantic import BaseModel
import types, time, json, traceback
import httpx
@@ -1846,8 +1846,71 @@ class OpenAIAssistantsAPI(BaseLLM):
return openai_client
def async_get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = AsyncOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
async def async_get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.list()
return response
# fmt: off
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
...
@overload
def get_assistants(
self,
api_key: Optional[str],
@@ -1856,7 +1919,31 @@ class OpenAIAssistantsAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
...
# fmt: on
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1872,6 +1959,41 @@ class OpenAIAssistantsAPI(BaseLLM):
### MESSAGES ###
async def a_add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
# fmt: off
@overload
def add_message(
self,
thread_id: str,
@@ -1881,9 +2003,51 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAIMessage:
client: Optional[AsyncOpenAI],
a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
...
@overload
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
...
# fmt: on
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1905,6 +2069,32 @@ class OpenAIAssistantsAPI(BaseLLM):
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
async def async_get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
# fmt: off
@overload
def get_messages(
self,
thread_id: str,
@@ -1913,8 +2103,48 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
client: Optional[AsyncOpenAI],
aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
...
# fmt: on
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1930,6 +2160,39 @@ class OpenAIAssistantsAPI(BaseLLM):
### THREADS ###
async def async_create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
self,
metadata: Optional[dict],
@@ -1938,9 +2201,41 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncOpenAI],
acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[OpenAI],
acreate_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client=None,
acreate_thread=None,
):
"""
Here's an example:
```
@@ -1951,6 +2246,17 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
messages=messages,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1970,6 +2276,46 @@ class OpenAIAssistantsAPI(BaseLLM):
return Thread(**message_thread.dict())
async def async_get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# fmt: off
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def get_thread(
self,
thread_id: str,
@@ -1979,7 +2325,33 @@ class OpenAIAssistantsAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@@ -1998,6 +2370,68 @@ class OpenAIAssistantsAPI(BaseLLM):
### RUNS ###
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Run:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
# fmt: off
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
arun_thread: Literal[True],
) -> Coroutine[None, None, Run]:
...
@overload
def run_thread(
self,
thread_id: str,
@@ -2014,7 +2448,47 @@ class OpenAIAssistantsAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
arun_thread: Optional[Literal[False]],
) -> Run:
...
# fmt: on
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client=None,
arun_thread=None,
):
if arun_thread is not None and arun_thread == True:
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
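
The provider class pairs every sync method with an `async_`/`a_` twin plus a set of `@overload` stubs (wrapped in `# fmt: off` / `# fmt: on`), so the same method returns either a concrete object or a coroutine depending on a boolean flag while type checkers still see precise return types. A simplified sketch of that dispatch pattern, with purely illustrative names rather than LiteLLM's own:

```
# Simplified sketch of the sync/async dispatch pattern used above.
# DemoAPI, fetch and afetch are illustrative names, not LiteLLM code.
import asyncio
from typing import Coroutine, Literal, Optional, Union, overload

class DemoAPI:
    async def afetch(self, item_id: str) -> dict:
        await asyncio.sleep(0)  # stand-in for an awaited client call
        return {"id": item_id, "mode": "async"}

    # fmt: off
    @overload
    def fetch(self, item_id: str, afetch: Literal[True]) -> Coroutine[None, None, dict]: ...
    @overload
    def fetch(self, item_id: str, afetch: Optional[Literal[False]] = None) -> dict: ...
    # fmt: on

    def fetch(
        self, item_id: str, afetch: Optional[bool] = None
    ) -> Union[dict, Coroutine[None, None, dict]]:
        # Async flag set -> hand the caller a coroutine to await.
        if afetch is True:
            return self.afetch(item_id)
        # Otherwise resolve synchronously.
        return {"id": item_id, "mode": "sync"}

print(DemoAPI().fetch("42"))                            # {'id': '42', 'mode': 'sync'}
print(asyncio.run(DemoAPI().fetch("42", afetch=True)))  # {'id': '42', 'mode': 'async'}
```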


@@ -5078,6 +5078,765 @@ async def audio_transcriptions(
)
######################################################################
# /v1/assistant Endpoints
######################################################################
@router.get(
"/v1/assistants",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
"/assistants",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def get_assistants(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_assistants(
custom_llm_provider="openai", client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.post(
"/threads",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def create_threads(
request: Request,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "litellm_metadata" not in data:
data["litellm_metadata"] = {}
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["litellm_metadata"]["headers"] = _headers
data["litellm_metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["litellm_metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["litellm_metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["litellm_metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.acreate_thread(
custom_llm_provider="openai", client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/v1/threads/{thread_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
"/threads/{thread_id}",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def get_thread(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.post(
"/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def add_messages(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "litellm_metadata" not in data:
data["litellm_metadata"] = {}
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["litellm_metadata"]["headers"] = _headers
data["litellm_metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["litellm_metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["litellm_metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["litellm_metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.a_add_message(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.get(
"/v1/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.get(
"/threads/{thread_id}/messages",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def get_messages(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["metadata"]["headers"] = _headers
data["metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.aget_messages(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
@router.post(
"/v1/threads/{thread_id}/runs",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
@router.post(
"/threads/{thread_id}/runs",
dependencies=[Depends(user_api_key_auth)],
tags=["assistants"],
)
async def run_thread(
request: Request,
thread_id: str,
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
global proxy_logging_obj
data: Dict = {}
try:
body = await request.body()
data = orjson.loads(body)
# Include original request and headers in the data
data["proxy_server_request"] = { # type: ignore
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data), # use copy instead of deepcopy
}
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
data["user"] = user_api_key_dict.user_id
if "litellm_metadata" not in data:
data["litellm_metadata"] = {}
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
_headers = dict(request.headers)
_headers.pop(
"authorization", None
) # do not store the original `sk-..` api key in the db
data["litellm_metadata"]["headers"] = _headers
data["litellm_metadata"]["user_api_key_alias"] = getattr(
user_api_key_dict, "key_alias", None
)
data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
data["litellm_metadata"]["user_api_key_team_id"] = getattr(
user_api_key_dict, "team_id", None
)
data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
"global_max_parallel_requests", None
)
data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
user_api_key_dict, "team_alias", None
)
data["litellm_metadata"]["endpoint"] = str(request.url)
### TEAM-SPECIFIC PARAMS ###
if user_api_key_dict.team_id is not None:
team_config = await proxy_config.load_team_config(
team_id=user_api_key_dict.team_id
)
if len(team_config) == 0:
pass
else:
team_id = team_config.pop("team_id", None)
data["litellm_metadata"]["team_id"] = team_id
data = {
**team_config,
**data,
} # add the team-specific configs to the completion call
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for the assistants API
if llm_router is None:
raise HTTPException(
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
)
response = await llm_router.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, client=None, **data
)
### ALERTING ###
data["litellm_status"] = "success" # used for alerting
### RESPONSE HEADERS ###
hidden_params = getattr(response, "_hidden_params", {}) or {}
model_id = hidden_params.get("model_id", None) or ""
cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
fastapi_response.headers.update(
get_custom_headers(
user_api_key_dict=user_api_key_dict,
model_id=model_id,
cache_key=cache_key,
api_base=api_base,
version=version,
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
)
)
return response
except Exception as e:
data["litellm_status"] = "fail" # used for alerting
await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
)
traceback.print_exc()
if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "message", str(e.detail)),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
)
else:
error_traceback = traceback.format_exc()
error_msg = f"{str(e)}"
raise ProxyException(
message=getattr(e, "message", error_msg),
type=getattr(e, "type", "None"),
param=getattr(e, "param", "None"),
code=getattr(e, "status_code", 500),
)
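
Each of the six endpoints above repeats the same enrichment steps: parse the body, record the original request, stamp key and team metadata, merge team-specific config, dispatch through `llm_router`, and copy hidden params into the response headers. A hedged sketch of how the metadata stamping alone could be factored into one helper; the helper and its name are illustrative only and are not part of this commit:

```
# Hypothetical helper: the assistants endpoints above all stamp the same
# user/team fields onto the request payload before routing. The real
# handlers inline this; names here are illustrative only.
from typing import Any, Dict

def stamp_request_metadata(
    data: Dict[str, Any],
    user_api_key_dict: Any,
    headers: Dict[str, str],
    endpoint: str,
    metadata_field: str = "metadata",  # thread/message routes use "litellm_metadata"
) -> Dict[str, Any]:
    meta = data.setdefault(metadata_field, {})
    meta["user_api_key"] = user_api_key_dict.api_key
    meta["user_api_key_metadata"] = user_api_key_dict.metadata
    # Drop the bearer token before the headers are stored anywhere.
    meta["headers"] = {k: v for k, v in headers.items() if k.lower() != "authorization"}
    meta["user_api_key_alias"] = getattr(user_api_key_dict, "key_alias", None)
    meta["user_api_key_user_id"] = user_api_key_dict.user_id
    meta["user_api_key_team_id"] = getattr(user_api_key_dict, "team_id", None)
    meta["user_api_key_team_alias"] = getattr(user_api_key_dict, "team_alias", None)
    meta["endpoint"] = endpoint
    return data
```
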
######################################################################
# /v1/batches Endpoints


@@ -53,6 +53,16 @@ from litellm.types.router import (
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.types.llms.openai import (
AsyncCursorPage,
Assistant,
Thread,
Attachment,
OpenAIMessage,
Run,
AssistantToolParam,
)
from typing import Iterable
class Router:
@@ -1646,6 +1656,108 @@ class Router:
self.fail_calls[model_name] += 1
raise e
#### ASSISTANTS API ####
async def aget_assistants(
self,
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[Assistant]:
return await litellm.aget_assistants(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def acreate_thread(
self,
custom_llm_provider: Literal["openai"],
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
return await litellm.acreate_thread(
custom_llm_provider=custom_llm_provider, client=client, **kwargs
)
async def aget_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Thread:
return await litellm.aget_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def a_add_message(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
role: Literal["user", "assistant"],
content: str,
attachments: Optional[List[Attachment]] = None,
metadata: Optional[dict] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> OpenAIMessage:
return await litellm.a_add_message(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
role=role,
content=content,
attachments=attachments,
metadata=metadata,
client=client,
**kwargs,
)
async def aget_messages(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> AsyncCursorPage[OpenAIMessage]:
return await litellm.aget_messages(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
client=client,
**kwargs,
)
async def arun_thread(
self,
custom_llm_provider: Literal["openai"],
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str] = None,
instructions: Optional[str] = None,
metadata: Optional[dict] = None,
model: Optional[str] = None,
stream: Optional[bool] = None,
tools: Optional[Iterable[AssistantToolParam]] = None,
client: Optional[AsyncOpenAI] = None,
**kwargs,
) -> Run:
return await litellm.arun_thread(
custom_llm_provider=custom_llm_provider,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
client=client,
**kwargs,
)
#### [END] ASSISTANTS API ####
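
The Router methods above are thin passthroughs to the module-level `litellm.a*` helpers, so assistants calls can share a Router instance with completion traffic. A small usage sketch, assuming a placeholder `model_list` entry and `OPENAI_API_KEY` in the environment:

```
# Sketch: the Router methods above simply forward to the litellm.a* helpers.
# The model_list entry is a placeholder; OPENAI_API_KEY is assumed to be set
# for the assistants calls themselves.
import asyncio
import litellm

router = litellm.Router(
    model_list=[
        {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo"}}
    ]
)

async def main() -> None:
    thread = await router.acreate_thread(custom_llm_provider="openai")
    await router.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="Hey, how's it going?",
    )
    messages = await router.aget_messages(
        custom_llm_provider="openai", thread_id=thread.id
    )
    print(len(messages.data))

asyncio.run(main())
```
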
async def async_function_with_fallbacks(self, *args, **kwargs):
"""
Try calling the function_with_retries


@@ -16,6 +16,7 @@ from litellm.llms.openai import (
MessageData,
Thread,
OpenAIMessage as Message,
AsyncCursorPage,
)
"""
@@ -26,25 +27,54 @@ V0 Scope:
"""
def test_create_thread_litellm() -> Thread:
@pytest.mark.asyncio
async def test_async_get_assistants():
assistants = await litellm.aget_assistants(custom_llm_provider="openai")
assert isinstance(assistants, AsyncCursorPage)
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_create_thread_litellm(sync_mode) -> Thread:
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
if sync_mode:
new_thread = create_thread(
custom_llm_provider="openai",
messages=[message], # type: ignore
)
else:
new_thread = await litellm.acreate_thread(
custom_llm_provider="openai",
messages=[message], # type: ignore
)
assert isinstance(
new_thread, Thread
), f"type of thread={type(new_thread)}. Expected Thread-type"
return new_thread
def test_get_thread_litellm():
new_thread = test_create_thread_litellm()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_get_thread_litellm(sync_mode):
new_thread = test_create_thread_litellm(sync_mode)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
if sync_mode:
received_thread = get_thread(
custom_llm_provider="openai",
thread_id=new_thread.id,
thread_id=_new_thread.id,
)
else:
received_thread = await litellm.aget_thread(
custom_llm_provider="openai",
thread_id=_new_thread.id,
)
assert isinstance(
@@ -53,14 +83,25 @@ def test_get_thread_litellm():
return new_thread
def test_add_message_litellm():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_add_message_litellm(sync_mode):
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
new_thread = test_create_thread_litellm()
new_thread = test_create_thread_litellm(sync_mode)
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
if sync_mode:
added_message = litellm.add_message(
thread_id=new_thread.id, custom_llm_provider="openai", **message
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
else:
added_message = await litellm.a_add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
print(f"added message: {added_message}")
@@ -68,25 +109,37 @@ def test_add_message_litellm():
assert isinstance(added_message, Message)
def test_run_thread_litellm():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_run_thread_litellm(sync_mode):
"""
- Get Assistants
- Create thread
- Create run w/ Assistants + Thread
"""
if sync_mode:
assistants = litellm.get_assistants(custom_llm_provider="openai")
else:
assistants = await litellm.aget_assistants(custom_llm_provider="openai")
## get the first assistant ###
assistant_id = assistants.data[0].id
new_thread = test_create_thread_litellm()
new_thread = test_create_thread_litellm(sync_mode=sync_mode)
thread_id = new_thread.id
if asyncio.iscoroutine(new_thread):
_new_thread = await new_thread
else:
_new_thread = new_thread
thread_id = _new_thread.id
# add message to thread
message: MessageData = {"role": "user", "content": "Hey, how's it going?"} # type: ignore
if sync_mode:
added_message = litellm.add_message(
thread_id=new_thread.id, custom_llm_provider="openai", **message
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
run = litellm.run_thread(
@@ -95,7 +148,24 @@ def test_run_thread_litellm():
if run.status == "completed":
messages = litellm.get_messages(
thread_id=new_thread.id, custom_llm_provider="openai"
thread_id=_new_thread.id, custom_llm_provider="openai"
)
assert isinstance(messages.data[0], Message)
else:
pytest.fail("An unexpected error occurred when running the thread")
else:
added_message = await litellm.a_add_message(
thread_id=_new_thread.id, custom_llm_provider="openai", **message
)
run = await litellm.arun_thread(
custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
)
if run.status == "completed":
messages = await litellm.aget_messages(
thread_id=_new_thread.id, custom_llm_provider="openai"
)
assert isinstance(messages.data[0], Message)
else:


@@ -17,11 +17,10 @@ from openai.types.beta.thread_create_params import (
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
from openai.pagination import SyncCursorPage
from openai.pagination import SyncCursorPage, AsyncCursorPage
from os import PathLike
from openai.types import FileObject, Batch
from openai._legacy_response import HttpxBinaryResponseContent
from typing import TypedDict, List, Optional, Tuple, Mapping, IO
FileContent = Union[IO[bytes], bytes, PathLike]