forked from phoenix/litellm-mirror
feat(proxy_server.py): add assistants api endpoints to proxy server
parent 3167bee25a
commit e2b34165e7
6 changed files with 1741 additions and 51 deletions
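
The new proxy routes added in this commit mirror the OpenAI Assistants API paths: GET /v1/assistants (and /assistants), POST /v1/threads, GET /v1/threads/{thread_id}, plus POST and GET on /v1/threads/{thread_id}/messages, all guarded by user_api_key_auth. As a rough smoke test against a running proxy, a minimal sketch (the base URL http://localhost:4000 and key sk-1234 are placeholders, not part of this commit; the assistants handler parses the request body, so an empty JSON object is sent even on the GET):

import requests

BASE = "http://localhost:4000"                 # assumed proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # assumed virtual key

# GET /v1/assistants - list assistants through the proxy
assistants = requests.get(f"{BASE}/v1/assistants", headers=HEADERS, json={}).json()
print(assistants)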
@@ -1,11 +1,12 @@
 # What is this?
 ## Main file for assistants API logic
 from typing import Iterable
-import os
+from functools import partial
+import os, asyncio, contextvars
 import litellm
-from openai import OpenAI
+from openai import OpenAI, AsyncOpenAI
 from litellm import client
-from litellm.utils import supports_httpx_timeout
+from litellm.utils import supports_httpx_timeout, exception_type, get_llm_provider
 from ..llms.openai import OpenAIAssistantsAPI
 from ..types.llms.openai import *
 from ..types.router import *
@@ -16,11 +17,49 @@ openai_assistants_api = OpenAIAssistantsAPI()
 
 ### ASSISTANTS ###
 
 
+async def aget_assistants(
+    custom_llm_provider: Literal["openai"],
+    client: Optional[AsyncOpenAI] = None,
+    **kwargs,
+) -> AsyncCursorPage[Assistant]:
+    loop = asyncio.get_event_loop()
+    ### PASS ARGS TO GET ASSISTANTS ###
+    kwargs["aget_assistants"] = True
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(get_assistants, custom_llm_provider, client, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
+            model="", custom_llm_provider=custom_llm_provider
+        )  # type: ignore
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response
+        return response  # type: ignore
+    except Exception as e:
+        raise exception_type(
+            model="",
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs={},
+            extra_kwargs=kwargs,
+        )
+
+
 def get_assistants(
     custom_llm_provider: Literal["openai"],
     client: Optional[OpenAI] = None,
     **kwargs,
 ) -> SyncCursorPage[Assistant]:
+    aget_assistants = kwargs.pop("aget_assistants", None)
     optional_params = GenericLiteLLMParams(**kwargs)
 
     ### TIMEOUT LOGIC ###
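
A note on the pattern used by aget_assistants above (and repeated by the other a* wrappers in this file): the sync entrypoint is run on the default executor with the caller's contextvars, and the flag placed in kwargs makes the sync function hand back a coroutine, which is then awaited. A stripped-down sketch of the same idea, with illustrative names that are not part of the diff:

import asyncio, contextvars
from functools import partial

async def _do_work_async(**kwargs):
    return "async result"

def do_work(**kwargs):
    # when the async flag is set, return a coroutine instead of a result
    if kwargs.pop("awork", None) is True:
        return _do_work_async(**kwargs)
    return "sync result"

async def awork(**kwargs):
    kwargs["awork"] = True
    loop = asyncio.get_event_loop()
    ctx = contextvars.copy_context()  # keep the caller's context in the worker thread
    func_with_context = partial(ctx.run, partial(do_work, **kwargs))
    init_response = await loop.run_in_executor(None, func_with_context)
    if asyncio.iscoroutine(init_response):
        return await init_response  # async path
    return init_response  # sync path

print(asyncio.run(awork()))  # -> "async result"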
@@ -67,6 +106,7 @@ def get_assistants(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
+            aget_assistants=aget_assistants,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -87,6 +127,39 @@ def get_assistants(
 
 ### THREADS ###
 
 
+async def acreate_thread(custom_llm_provider: Literal["openai"], **kwargs) -> Thread:
+    loop = asyncio.get_event_loop()
+    ### PASS ARGS TO GET ASSISTANTS ###
+    kwargs["acreate_thread"] = True
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(create_thread, custom_llm_provider, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
+            model="", custom_llm_provider=custom_llm_provider
+        )  # type: ignore
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response
+        return response  # type: ignore
+    except Exception as e:
+        raise exception_type(
+            model="",
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs={},
+            extra_kwargs=kwargs,
+        )
+
+
 def create_thread(
     custom_llm_provider: Literal["openai"],
     messages: Optional[Iterable[OpenAICreateThreadParamsMessage]] = None,
@@ -117,6 +190,7 @@ def create_thread(
     )
     ```
     """
+    acreate_thread = kwargs.get("acreate_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
 
     ### TIMEOUT LOGIC ###
@@ -165,6 +239,7 @@ def create_thread(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
+            acreate_thread=acreate_thread,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -182,6 +257,44 @@ def create_thread(
 
     return response
 
 
+async def aget_thread(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    client: Optional[AsyncOpenAI] = None,
+    **kwargs,
+) -> Thread:
+    loop = asyncio.get_event_loop()
+    ### PASS ARGS TO GET ASSISTANTS ###
+    kwargs["aget_thread"] = True
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(get_thread, custom_llm_provider, thread_id, client, **kwargs)
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
+            model="", custom_llm_provider=custom_llm_provider
+        )  # type: ignore
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            response = init_response
+        return response  # type: ignore
+    except Exception as e:
+        raise exception_type(
+            model="",
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs={},
+            extra_kwargs=kwargs,
+        )
+
+
 def get_thread(
     custom_llm_provider: Literal["openai"],
     thread_id: str,
@@ -189,6 +302,7 @@ def get_thread(
     **kwargs,
 ) -> Thread:
     """Get the thread object, given a thread_id"""
+    aget_thread = kwargs.pop("aget_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
 
     ### TIMEOUT LOGIC ###
@@ -236,6 +350,7 @@ def get_thread(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
+            aget_thread=aget_thread,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -256,6 +371,59 @@ def get_thread(
 
 ### MESSAGES ###
 
 
+async def a_add_message(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    role: Literal["user", "assistant"],
+    content: str,
+    attachments: Optional[List[Attachment]] = None,
+    metadata: Optional[dict] = None,
+    client: Optional[AsyncOpenAI] = None,
+    **kwargs,
+) -> OpenAIMessage:
+    loop = asyncio.get_event_loop()
+    ### PASS ARGS TO GET ASSISTANTS ###
+    kwargs["a_add_message"] = True
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            add_message,
+            custom_llm_provider,
+            thread_id,
+            role,
+            content,
+            attachments,
+            metadata,
+            client,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
+            model="", custom_llm_provider=custom_llm_provider
+        )  # type: ignore
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = init_response
+        return response  # type: ignore
+    except Exception as e:
+        raise exception_type(
+            model="",
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs={},
+            extra_kwargs=kwargs,
+        )
+
+
 def add_message(
     custom_llm_provider: Literal["openai"],
     thread_id: str,
@@ -267,6 +435,7 @@ def add_message(
     **kwargs,
 ) -> OpenAIMessage:
     ### COMMON OBJECTS ###
+    a_add_message = kwargs.pop("a_add_message", None)
     message_data = MessageData(
         role=role, content=content, attachments=attachments, metadata=metadata
     )
@@ -318,6 +487,7 @@ def add_message(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
+            a_add_message=a_add_message,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -336,12 +506,58 @@ def add_message(
 
     return response
 
 
+async def aget_messages(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    client: Optional[AsyncOpenAI] = None,
+    **kwargs,
+) -> AsyncCursorPage[OpenAIMessage]:
+    loop = asyncio.get_event_loop()
+    ### PASS ARGS TO GET ASSISTANTS ###
+    kwargs["aget_messages"] = True
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            get_messages,
+            custom_llm_provider,
+            thread_id,
+            client,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
+            model="", custom_llm_provider=custom_llm_provider
+        )  # type: ignore
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = init_response
+        return response  # type: ignore
+    except Exception as e:
+        raise exception_type(
+            model="",
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs={},
+            extra_kwargs=kwargs,
+        )
+
+
 def get_messages(
     custom_llm_provider: Literal["openai"],
     thread_id: str,
     client: Optional[OpenAI] = None,
     **kwargs,
 ) -> SyncCursorPage[OpenAIMessage]:
+    aget_messages = kwargs.pop("aget_messages", None)
     optional_params = GenericLiteLLMParams(**kwargs)
 
     ### TIMEOUT LOGIC ###
@@ -389,6 +605,7 @@ def get_messages(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
+            aget_messages=aget_messages,
         )
     else:
         raise litellm.exceptions.BadRequestError(
@@ -408,6 +625,63 @@ def get_messages(
 
 
 ### RUNS ###
+async def arun_thread(
+    custom_llm_provider: Literal["openai"],
+    thread_id: str,
+    assistant_id: str,
+    additional_instructions: Optional[str] = None,
+    instructions: Optional[str] = None,
+    metadata: Optional[dict] = None,
+    model: Optional[str] = None,
+    stream: Optional[bool] = None,
+    tools: Optional[Iterable[AssistantToolParam]] = None,
+    client: Optional[AsyncOpenAI] = None,
+    **kwargs,
+) -> Run:
+    loop = asyncio.get_event_loop()
+    ### PASS ARGS TO GET ASSISTANTS ###
+    kwargs["arun_thread"] = True
+    try:
+        # Use a partial function to pass your keyword arguments
+        func = partial(
+            run_thread,
+            custom_llm_provider,
+            thread_id,
+            assistant_id,
+            additional_instructions,
+            instructions,
+            metadata,
+            model,
+            stream,
+            tools,
+            client,
+            **kwargs,
+        )
+
+        # Add the context to the function
+        ctx = contextvars.copy_context()
+        func_with_context = partial(ctx.run, func)
+
+        _, custom_llm_provider, _, _ = get_llm_provider(  # type: ignore
+            model="", custom_llm_provider=custom_llm_provider
+        )  # type: ignore
+
+        # Await normally
+        init_response = await loop.run_in_executor(None, func_with_context)
+        if asyncio.iscoroutine(init_response):
+            response = await init_response
+        else:
+            # Call the synchronous function using run_in_executor
+            response = init_response
+        return response  # type: ignore
+    except Exception as e:
+        raise exception_type(
+            model="",
+            custom_llm_provider=custom_llm_provider,
+            original_exception=e,
+            completion_kwargs={},
+            extra_kwargs=kwargs,
+        )
+
+
 def run_thread(
@@ -424,6 +698,7 @@ def run_thread(
     **kwargs,
 ) -> Run:
     """Run a given thread + assistant."""
+    arun_thread = kwargs.pop("arun_thread", None)
     optional_params = GenericLiteLLMParams(**kwargs)
 
     ### TIMEOUT LOGIC ###
@@ -478,6 +753,7 @@ def run_thread(
             max_retries=optional_params.max_retries,
             organization=organization,
             client=client,
+            arun_thread=arun_thread,
         )
     else:
         raise litellm.exceptions.BadRequestError(
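
Taken together, the functions in this file give a sync/async pair for each Assistants operation: get_assistants/aget_assistants, create_thread/acreate_thread, get_thread/aget_thread, add_message/a_add_message, get_messages/aget_messages and run_thread/arun_thread. A rough usage sketch of the async variants, assuming they are re-exported at the litellm package level and that OPENAI_API_KEY is set in the environment:

import asyncio
import litellm

async def main():
    assistants = await litellm.aget_assistants(custom_llm_provider="openai")
    assistant_id = assistants.data[0].id  # assumes at least one assistant exists

    thread = await litellm.acreate_thread(
        custom_llm_provider="openai",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    await litellm.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="What's the weather like today?",
    )
    run = await litellm.arun_thread(
        custom_llm_provider="openai", thread_id=thread.id, assistant_id=assistant_id
    )
    if run.status == "completed":
        messages = await litellm.aget_messages(
            custom_llm_provider="openai", thread_id=thread.id
        )
        print(messages)

asyncio.run(main())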
@@ -6,7 +6,7 @@ from typing import (
     Literal,
     Iterable,
 )
-from typing_extensions import override
+from typing_extensions import override, overload
 from pydantic import BaseModel
 import types, time, json, traceback
 import httpx
@@ -1846,8 +1846,85 @@ class OpenAIAssistantsAPI(BaseLLM):
 
         return openai_client
 
+    def async_get_openai_client(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI] = None,
+    ) -> AsyncOpenAI:
+        received_args = locals()
+        if client is None:
+            data = {}
+            for k, v in received_args.items():
+                if k == "self" or k == "client":
+                    pass
+                elif k == "api_base" and v is not None:
+                    data["base_url"] = v
+                elif v is not None:
+                    data[k] = v
+            openai_client = AsyncOpenAI(**data)  # type: ignore
+        else:
+            openai_client = client
+
+        return openai_client
+
     ### ASSISTANTS ###
 
+    async def async_get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+    ) -> AsyncCursorPage[Assistant]:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = await openai_client.beta.assistants.list()
+
+        return response
+
+    # fmt: off
+
+    @overload
+    def get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        aget_assistants: Literal[True],
+    ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
+        ...
+
+    @overload
+    def get_assistants(
+        self,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        aget_assistants: Optional[Literal[False]],
+    ) -> SyncCursorPage[Assistant]:
+        ...
+
+    # fmt: on
+
     def get_assistants(
         self,
         api_key: Optional[str],
@@ -1855,8 +1932,18 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI],
-    ) -> SyncCursorPage[Assistant]:
+        client=None,
+        aget_assistants=None,
+    ):
+        if aget_assistants is not None and aget_assistants == True:
+            return self.async_get_assistants(
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
         openai_client = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
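
The # fmt: off / @overload blocks above exist purely for type checkers: a call with aget_assistants=True advertises a Coroutine return type, while the plain call keeps the sync cursor page; the actual dispatch happens at runtime in the undecorated method. A generic sketch of the same typing trick, with illustrative names that are not part of the diff:

import asyncio
from typing import Coroutine, Literal, Optional, Union, overload

@overload
def fetch(async_mode: Literal[True]) -> Coroutine[None, None, str]: ...
@overload
def fetch(async_mode: Optional[Literal[False]] = None) -> str: ...

def fetch(async_mode: Optional[bool] = None) -> Union[str, Coroutine[None, None, str]]:
    async def _afetch() -> str:
        return "async value"
    if async_mode:
        return _afetch()  # caller is expected to await this
    return "sync value"

print(fetch())                   # -> "sync value"
print(asyncio.run(fetch(True)))  # -> "async value"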
@@ -1872,6 +1959,72 @@ class OpenAIAssistantsAPI(BaseLLM):
 
     ### MESSAGES ###
 
+    async def a_add_message(
+        self,
+        thread_id: str,
+        message_data: MessageData,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI] = None,
+    ) -> OpenAIMessage:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create(  # type: ignore
+            thread_id, **message_data  # type: ignore
+        )
+
+        response_obj: Optional[OpenAIMessage] = None
+        if getattr(thread_message, "status", None) is None:
+            thread_message.status = "completed"
+            response_obj = OpenAIMessage(**thread_message.dict())
+        else:
+            response_obj = OpenAIMessage(**thread_message.dict())
+        return response_obj
+
+    # fmt: off
+
+    @overload
+    def add_message(
+        self,
+        thread_id: str,
+        message_data: MessageData,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        a_add_message: Literal[True],
+    ) -> Coroutine[None, None, OpenAIMessage]:
+        ...
+
+    @overload
+    def add_message(
+        self,
+        thread_id: str,
+        message_data: MessageData,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        a_add_message: Optional[Literal[False]],
+    ) -> OpenAIMessage:
+        ...
+
+    # fmt: on
+
     def add_message(
         self,
         thread_id: str,
@@ -1881,9 +2034,20 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI] = None,
-    ) -> OpenAIMessage:
+        client=None,
+        a_add_message: Optional[bool] = None,
+    ):
+        if a_add_message is not None and a_add_message == True:
+            return self.a_add_message(
+                thread_id=thread_id,
+                message_data=message_data,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
         openai_client = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1905,6 +2069,61 @@ class OpenAIAssistantsAPI(BaseLLM):
             response_obj = OpenAIMessage(**thread_message.dict())
         return response_obj
 
+    async def async_get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI] = None,
+    ) -> AsyncCursorPage[OpenAIMessage]:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
+
+        return response
+
+    # fmt: off
+
+    @overload
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        aget_messages: Literal[True],
+    ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
+        ...
+
+    @overload
+    def get_messages(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        aget_messages: Optional[Literal[False]],
+    ) -> SyncCursorPage[OpenAIMessage]:
+        ...
+
+    # fmt: on
+
     def get_messages(
         self,
         thread_id: str,
@@ -1913,8 +2132,19 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI] = None,
-    ) -> SyncCursorPage[OpenAIMessage]:
+        client=None,
+        aget_messages=None,
+    ):
+        if aget_messages is not None and aget_messages == True:
+            return self.async_get_messages(
+                thread_id=thread_id,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
         openai_client = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1930,6 +2160,70 @@ class OpenAIAssistantsAPI(BaseLLM):
 
     ### THREADS ###
 
+    async def async_create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+    ) -> Thread:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        data = {}
+        if messages is not None:
+            data["messages"] = messages  # type: ignore
+        if metadata is not None:
+            data["metadata"] = metadata  # type: ignore
+
+        message_thread = await openai_client.beta.threads.create(**data)  # type: ignore
+
+        return Thread(**message_thread.dict())
+
+    # fmt: off
+
+    @overload
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        client: Optional[AsyncOpenAI],
+        acreate_thread: Literal[True],
+    ) -> Coroutine[None, None, Thread]:
+        ...
+
+    @overload
+    def create_thread(
+        self,
+        metadata: Optional[dict],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
+        client: Optional[OpenAI],
+        acreate_thread: Optional[Literal[False]],
+    ) -> Thread:
+        ...
+
+    # fmt: on
+
     def create_thread(
         self,
         metadata: Optional[dict],
@@ -1938,9 +2232,10 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI],
         messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
-    ) -> Thread:
+        client=None,
+        acreate_thread=None,
+    ):
         """
         Here's an example:
         ```
@@ -1951,6 +2246,17 @@ class OpenAIAssistantsAPI(BaseLLM):
         openai_api.create_thread(messages=[message])
         ```
         """
+        if acreate_thread is not None and acreate_thread == True:
+            return self.async_create_thread(
+                metadata=metadata,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+                messages=messages,
+            )
         openai_client = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1970,6 +2276,61 @@ class OpenAIAssistantsAPI(BaseLLM):
 
         return Thread(**message_thread.dict())
 
+    async def async_get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+    ) -> Thread:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
+
+        return Thread(**response.dict())
+
+    # fmt: off
+
+    @overload
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        aget_thread: Literal[True],
+    ) -> Coroutine[None, None, Thread]:
+        ...
+
+    @overload
+    def get_thread(
+        self,
+        thread_id: str,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        aget_thread: Optional[Literal[False]],
+    ) -> Thread:
+        ...
+
+    # fmt: on
+
     def get_thread(
         self,
         thread_id: str,
@@ -1978,8 +2339,19 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI],
-    ) -> Thread:
+        client=None,
+        aget_thread=None,
+    ):
+        if aget_thread is not None and aget_thread == True:
+            return self.async_get_thread(
+                thread_id=thread_id,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
         openai_client = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
@@ -1998,6 +2370,90 @@ class OpenAIAssistantsAPI(BaseLLM):
 
     ### RUNS ###
 
+    async def arun_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+    ) -> Run:
+        openai_client = self.async_get_openai_client(
+            api_key=api_key,
+            api_base=api_base,
+            timeout=timeout,
+            max_retries=max_retries,
+            organization=organization,
+            client=client,
+        )
+
+        response = await openai_client.beta.threads.runs.create_and_poll(  # type: ignore
+            thread_id=thread_id,
+            assistant_id=assistant_id,
+            additional_instructions=additional_instructions,
+            instructions=instructions,
+            metadata=metadata,
+            model=model,
+            tools=tools,
+        )
+
+        return response
+
+    # fmt: off
+
+    @overload
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[AsyncOpenAI],
+        arun_thread: Literal[True],
+    ) -> Coroutine[None, None, Run]:
+        ...
+
+    @overload
+    def run_thread(
+        self,
+        thread_id: str,
+        assistant_id: str,
+        additional_instructions: Optional[str],
+        instructions: Optional[str],
+        metadata: Optional[object],
+        model: Optional[str],
+        stream: Optional[bool],
+        tools: Optional[Iterable[AssistantToolParam]],
+        api_key: Optional[str],
+        api_base: Optional[str],
+        timeout: Union[float, httpx.Timeout],
+        max_retries: Optional[int],
+        organization: Optional[str],
+        client: Optional[OpenAI],
+        arun_thread: Optional[Literal[False]],
+    ) -> Run:
+        ...
+
+    # fmt: on
+
     def run_thread(
         self,
         thread_id: str,
@@ -2013,8 +2469,26 @@ class OpenAIAssistantsAPI(BaseLLM):
         timeout: Union[float, httpx.Timeout],
         max_retries: Optional[int],
         organization: Optional[str],
-        client: Optional[OpenAI],
-    ) -> Run:
+        client=None,
+        arun_thread=None,
+    ):
+        if arun_thread is not None and arun_thread == True:
+            return self.arun_thread(
+                thread_id=thread_id,
+                assistant_id=assistant_id,
+                additional_instructions=additional_instructions,
+                instructions=instructions,
+                metadata=metadata,
+                model=model,
+                stream=stream,
+                tools=tools,
+                api_key=api_key,
+                api_base=api_base,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+                client=client,
+            )
         openai_client = self.get_openai_client(
             api_key=api_key,
             api_base=api_base,
|
@ -5078,6 +5078,765 @@ async def audio_transcriptions(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
|
||||||
|
# /v1/assistant Endpoints
|
||||||
|
|
||||||
|
|
||||||
|
######################################################################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/v1/assistants",
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
tags=["assistants"],
|
||||||
|
)
|
||||||
|
@router.get(
|
||||||
|
"/assistants",
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
tags=["assistants"],
|
||||||
|
)
|
||||||
|
async def get_assistants(
|
||||||
|
request: Request,
|
||||||
|
fastapi_response: Response,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
global proxy_logging_obj
|
||||||
|
data: Dict = {}
|
||||||
|
try:
|
||||||
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||||
|
body = await request.body()
|
||||||
|
data = orjson.loads(body)
|
||||||
|
|
||||||
|
# Include original request and headers in the data
|
||||||
|
data["proxy_server_request"] = { # type: ignore
|
||||||
|
"url": str(request.url),
|
||||||
|
"method": request.method,
|
||||||
|
"headers": dict(request.headers),
|
||||||
|
"body": copy.copy(data), # use copy instead of deepcopy
|
||||||
|
}
|
||||||
|
|
||||||
|
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||||
|
data["user"] = user_api_key_dict.user_id
|
||||||
|
|
||||||
|
if "metadata" not in data:
|
||||||
|
data["metadata"] = {}
|
||||||
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||||
|
data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||||
|
_headers = dict(request.headers)
|
||||||
|
_headers.pop(
|
||||||
|
"authorization", None
|
||||||
|
) # do not store the original `sk-..` api key in the db
|
||||||
|
data["metadata"]["headers"] = _headers
|
||||||
|
data["metadata"]["user_api_key_alias"] = getattr(
|
||||||
|
user_api_key_dict, "key_alias", None
|
||||||
|
)
|
||||||
|
data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
|
||||||
|
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||||
|
user_api_key_dict, "team_id", None
|
||||||
|
)
|
||||||
|
data["metadata"]["global_max_parallel_requests"] = general_settings.get(
|
||||||
|
"global_max_parallel_requests", None
|
||||||
|
)
|
||||||
|
data["metadata"]["user_api_key_team_alias"] = getattr(
|
||||||
|
user_api_key_dict, "team_alias", None
|
||||||
|
)
|
||||||
|
data["metadata"]["endpoint"] = str(request.url)
|
||||||
|
|
||||||
|
### TEAM-SPECIFIC PARAMS ###
|
||||||
|
if user_api_key_dict.team_id is not None:
|
||||||
|
team_config = await proxy_config.load_team_config(
|
||||||
|
team_id=user_api_key_dict.team_id
|
||||||
|
)
|
||||||
|
if len(team_config) == 0:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
team_id = team_config.pop("team_id", None)
|
||||||
|
data["metadata"]["team_id"] = team_id
|
||||||
|
data = {
|
||||||
|
**team_config,
|
||||||
|
**data,
|
||||||
|
} # add the team-specific configs to the completion call
|
||||||
|
|
||||||
|
# for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
|
||||||
|
if llm_router is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
|
||||||
|
)
|
||||||
|
response = await llm_router.aget_assistants(
|
||||||
|
custom_llm_provider="openai", client=None, **data
|
||||||
|
)
|
||||||
|
|
||||||
|
### ALERTING ###
|
||||||
|
data["litellm_status"] = "success" # used for alerting
|
||||||
|
|
||||||
|
### RESPONSE HEADERS ###
|
||||||
|
hidden_params = getattr(response, "_hidden_params", {}) or {}
|
||||||
|
model_id = hidden_params.get("model_id", None) or ""
|
||||||
|
cache_key = hidden_params.get("cache_key", None) or ""
|
||||||
|
api_base = hidden_params.get("api_base", None) or ""
|
||||||
|
|
||||||
|
fastapi_response.headers.update(
|
||||||
|
get_custom_headers(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
model_id=model_id,
|
||||||
|
cache_key=cache_key,
|
||||||
|
api_base=api_base,
|
||||||
|
version=version,
|
||||||
|
model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
data["litellm_status"] = "fail" # used for alerting
|
||||||
|
await proxy_logging_obj.post_call_failure_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
|
||||||
|
)
|
||||||
|
traceback.print_exc()
|
||||||
|
if isinstance(e, HTTPException):
|
||||||
|
raise ProxyException(
|
||||||
|
message=getattr(e, "message", str(e.detail)),
|
||||||
|
type=getattr(e, "type", "None"),
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
error_msg = f"{str(e)}"
|
||||||
|
raise ProxyException(
|
||||||
|
message=getattr(e, "message", error_msg),
|
||||||
|
type=getattr(e, "type", "None"),
|
||||||
|
param=getattr(e, "param", "None"),
|
||||||
|
code=getattr(e, "status_code", 500),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
+@router.post(
+    "/v1/threads",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["assistants"],
+)
+@router.post(
+    "/threads",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["assistants"],
+)
+async def create_threads(
+    request: Request,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global proxy_logging_obj
+    data: Dict = {}
+    try:
+        # Use orjson to parse JSON data, orjson speeds up requests significantly
+        body = await request.body()
+        data = orjson.loads(body)
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {  # type: ignore
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data),  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        if "litellm_metadata" not in data:
+            data["litellm_metadata"] = {}
+        data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["litellm_metadata"]["headers"] = _headers
+        data["litellm_metadata"]["user_api_key_alias"] = getattr(
+            user_api_key_dict, "key_alias", None
+        )
+        data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["litellm_metadata"]["user_api_key_team_id"] = getattr(
+            user_api_key_dict, "team_id", None
+        )
+        data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+        data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
+        data["litellm_metadata"]["endpoint"] = str(request.url)
+
+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["litellm_metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
+        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+        if llm_router is None:
+            raise HTTPException(
+                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+        response = await llm_router.acreate_thread(
+            custom_llm_provider="openai", client=None, **data
+        )
+
+        ### ALERTING ###
+        data["litellm_status"] = "success"  # used for alerting
+
+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers.update(
+            get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
+
+        return response
+    except Exception as e:
+        data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+        )
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e.detail)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
+
+
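
For a quick end-to-end check of the thread endpoints above, a small sketch using requests (the proxy address http://localhost:4000 and virtual key sk-1234 are placeholders, not part of this commit):

import requests

BASE = "http://localhost:4000"                 # assumed proxy address
HEADERS = {"Authorization": "Bearer sk-1234"}  # assumed virtual key

# POST /v1/threads - create a thread with an initial message
thread = requests.post(
    f"{BASE}/v1/threads",
    headers=HEADERS,
    json={"messages": [{"role": "user", "content": "Hey, how's it going?"}]},
).json()

# POST /v1/threads/{thread_id}/messages - append another message
message = requests.post(
    f"{BASE}/v1/threads/{thread['id']}/messages",
    headers=HEADERS,
    json={"role": "user", "content": "What's the weather like today?"},
).json()
print(thread["id"], message.get("id"))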
+@router.get(
+    "/v1/threads/{thread_id}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["assistants"],
+)
+@router.get(
+    "/threads/{thread_id}",
+    dependencies=[Depends(user_api_key_auth)],
+    tags=["assistants"],
+)
+async def get_thread(
+    request: Request,
+    thread_id: str,
+    fastapi_response: Response,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    global proxy_logging_obj
+    data: Dict = {}
+    try:
+
+        # Include original request and headers in the data
+        data["proxy_server_request"] = {  # type: ignore
+            "url": str(request.url),
+            "method": request.method,
+            "headers": dict(request.headers),
+            "body": copy.copy(data),  # use copy instead of deepcopy
+        }
+
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
+        if "metadata" not in data:
+            data["metadata"] = {}
+        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+        _headers = dict(request.headers)
+        _headers.pop(
+            "authorization", None
+        )  # do not store the original `sk-..` api key in the db
+        data["metadata"]["headers"] = _headers
+        data["metadata"]["user_api_key_alias"] = getattr(
+            user_api_key_dict, "key_alias", None
+        )
+        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+        data["metadata"]["user_api_key_team_id"] = getattr(
+            user_api_key_dict, "team_id", None
+        )
+        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
+            "global_max_parallel_requests", None
+        )
+        data["metadata"]["user_api_key_team_alias"] = getattr(
+            user_api_key_dict, "team_alias", None
+        )
+        data["metadata"]["endpoint"] = str(request.url)
+
+        ### TEAM-SPECIFIC PARAMS ###
+        if user_api_key_dict.team_id is not None:
+            team_config = await proxy_config.load_team_config(
+                team_id=user_api_key_dict.team_id
+            )
+            if len(team_config) == 0:
+                pass
+            else:
+                team_id = team_config.pop("team_id", None)
+                data["metadata"]["team_id"] = team_id
+                data = {
+                    **team_config,
+                    **data,
+                }  # add the team-specific configs to the completion call
+
+        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+        if llm_router is None:
+            raise HTTPException(
+                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+        response = await llm_router.aget_thread(
+            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
+        )
+
+        ### ALERTING ###
+        data["litellm_status"] = "success"  # used for alerting
+
+        ### RESPONSE HEADERS ###
+        hidden_params = getattr(response, "_hidden_params", {}) or {}
+        model_id = hidden_params.get("model_id", None) or ""
+        cache_key = hidden_params.get("cache_key", None) or ""
+        api_base = hidden_params.get("api_base", None) or ""
+
+        fastapi_response.headers.update(
+            get_custom_headers(
+                user_api_key_dict=user_api_key_dict,
+                model_id=model_id,
+                cache_key=cache_key,
+                api_base=api_base,
+                version=version,
+                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+            )
+        )
+
+        return response
+    except Exception as e:
+        data["litellm_status"] = "fail"  # used for alerting
+        await proxy_logging_obj.post_call_failure_hook(
+            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+        )
+        traceback.print_exc()
+        if isinstance(e, HTTPException):
+            raise ProxyException(
+                message=getattr(e, "message", str(e.detail)),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+            )
+        else:
+            error_traceback = traceback.format_exc()
+            error_msg = f"{str(e)}"
+            raise ProxyException(
+                message=getattr(e, "message", error_msg),
+                type=getattr(e, "type", "None"),
+                param=getattr(e, "param", "None"),
+                code=getattr(e, "status_code", 500),
+            )
+
+
@router.post(
    "/v1/threads/{thread_id}/messages",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
@router.post(
    "/threads/{thread_id}/messages",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
async def add_messages(
    request: Request,
    thread_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    global proxy_logging_obj
    data: Dict = {}
    try:
        # Use orjson to parse JSON data, orjson speeds up requests significantly
        body = await request.body()
        data = orjson.loads(body)

        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        if "litellm_metadata" not in data:
            data["litellm_metadata"] = {}
        data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
        )  # do not store the original `sk-..` api key in the db
        data["litellm_metadata"]["headers"] = _headers
        data["litellm_metadata"]["user_api_key_alias"] = getattr(
            user_api_key_dict, "key_alias", None
        )
        data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
        data["litellm_metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
        data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )
        data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
            user_api_key_dict, "team_alias", None
        )
        data["litellm_metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
        if user_api_key_dict.team_id is not None:
            team_config = await proxy_config.load_team_config(
                team_id=user_api_key_dict.team_id
            )
            if len(team_config) == 0:
                pass
            else:
                team_id = team_config.pop("team_id", None)
                data["litellm_metadata"]["team_id"] = team_id
                data = {
                    **team_config,
                    **data,
                }  # add the team-specific configs to the completion call

        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
        if llm_router is None:
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.a_add_message(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            )
        )

        return response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )

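Example (illustrative, not part of this commit): once the proxy is running, the message route above can be exercised with the standard OpenAI SDK pointed at the proxy. The base URL, virtual key, and thread id below are placeholders.

# Illustrative sketch only — base_url, api_key, and thread_id are assumed values.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

# POST /v1/threads/{thread_id}/messages is served by add_messages() above
message = client.beta.threads.messages.create(
    thread_id="thread_abc123",
    role="user",
    content="Hey, how's it going?",
)
print(message.id)
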
@router.get(
    "/v1/threads/{thread_id}/messages",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
@router.get(
    "/threads/{thread_id}/messages",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
async def get_messages(
    request: Request,
    thread_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    global proxy_logging_obj
    data: Dict = {}
    try:
        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        if "metadata" not in data:
            data["metadata"] = {}
        data["metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
        )  # do not store the original `sk-..` api key in the db
        data["metadata"]["headers"] = _headers
        data["metadata"]["user_api_key_alias"] = getattr(
            user_api_key_dict, "key_alias", None
        )
        data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
        data["metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
        data["metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )
        data["metadata"]["user_api_key_team_alias"] = getattr(
            user_api_key_dict, "team_alias", None
        )
        data["metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
        if user_api_key_dict.team_id is not None:
            team_config = await proxy_config.load_team_config(
                team_id=user_api_key_dict.team_id
            )
            if len(team_config) == 0:
                pass
            else:
                team_id = team_config.pop("team_id", None)
                data["metadata"]["team_id"] = team_id
                data = {
                    **team_config,
                    **data,
                }  # add the team-specific configs to the completion call

        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
        if llm_router is None:
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.aget_messages(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            )
        )

        return response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )

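Example (illustrative, not part of this commit): the list route above maps onto the SDK's messages.list call. Same assumed proxy URL, key, and thread id as before.

# Illustrative sketch only — the proxy URL, key, and thread_id are assumed values.
from openai import OpenAI

client = OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")

# GET /v1/threads/{thread_id}/messages is served by get_messages() above
messages = client.beta.threads.messages.list(thread_id="thread_abc123")
for m in messages.data:
    print(m.role, m.id)
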
@router.get(
    "/v1/threads/{thread_id}/runs",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
@router.get(
    "/threads/{thread_id}/runs",
    dependencies=[Depends(user_api_key_auth)],
    tags=["assistants"],
)
async def run_thread(
    request: Request,
    thread_id: str,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    global proxy_logging_obj
    data: Dict = {}
    try:
        body = await request.body()
        data = orjson.loads(body)
        # Include original request and headers in the data
        data["proxy_server_request"] = {  # type: ignore
            "url": str(request.url),
            "method": request.method,
            "headers": dict(request.headers),
            "body": copy.copy(data),  # use copy instead of deepcopy
        }

        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
            data["user"] = user_api_key_dict.user_id

        if "litellm_metadata" not in data:
            data["litellm_metadata"] = {}
        data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key
        data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
        _headers = dict(request.headers)
        _headers.pop(
            "authorization", None
        )  # do not store the original `sk-..` api key in the db
        data["litellm_metadata"]["headers"] = _headers
        data["litellm_metadata"]["user_api_key_alias"] = getattr(
            user_api_key_dict, "key_alias", None
        )
        data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
        data["litellm_metadata"]["user_api_key_team_id"] = getattr(
            user_api_key_dict, "team_id", None
        )
        data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get(
            "global_max_parallel_requests", None
        )
        data["litellm_metadata"]["user_api_key_team_alias"] = getattr(
            user_api_key_dict, "team_alias", None
        )
        data["litellm_metadata"]["endpoint"] = str(request.url)

        ### TEAM-SPECIFIC PARAMS ###
        if user_api_key_dict.team_id is not None:
            team_config = await proxy_config.load_team_config(
                team_id=user_api_key_dict.team_id
            )
            if len(team_config) == 0:
                pass
            else:
                team_id = team_config.pop("team_id", None)
                data["litellm_metadata"]["team_id"] = team_id
                data = {
                    **team_config,
                    **data,
                }  # add the team-specific configs to the completion call

        # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
        if llm_router is None:
            raise HTTPException(
                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        response = await llm_router.arun_thread(
            custom_llm_provider="openai", thread_id=thread_id, client=None, **data
        )

        ### ALERTING ###
        data["litellm_status"] = "success"  # used for alerting

        ### RESPONSE HEADERS ###
        hidden_params = getattr(response, "_hidden_params", {}) or {}
        model_id = hidden_params.get("model_id", None) or ""
        cache_key = hidden_params.get("cache_key", None) or ""
        api_base = hidden_params.get("api_base", None) or ""

        fastapi_response.headers.update(
            get_custom_headers(
                user_api_key_dict=user_api_key_dict,
                model_id=model_id,
                cache_key=cache_key,
                api_base=api_base,
                version=version,
                model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
            )
        )

        return response
    except Exception as e:
        data["litellm_status"] = "fail"  # used for alerting
        await proxy_logging_obj.post_call_failure_hook(
            user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
        )
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "message", str(e.detail)),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        else:
            error_traceback = traceback.format_exc()
            error_msg = f"{str(e)}"
            raise ProxyException(
                message=getattr(e, "message", error_msg),
                type=getattr(e, "type", "None"),
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", 500),
            )

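Example (illustrative, not part of this commit): run_thread() above is registered with @router.get but still reads the run parameters from a JSON body, so the stock OpenAI SDK (which POSTs to create a run) would not hit this route. A minimal sketch that matches the route as registered in this diff; the URL, key, thread id, and assistant id are placeholders.

# Illustrative sketch only — all values are assumed; the GET-with-body shape mirrors run_thread() above.
import requests

resp = requests.request(
    "GET",
    "http://0.0.0.0:4000/v1/threads/thread_abc123/runs",
    headers={"Authorization": "Bearer sk-1234"},
    json={"assistant_id": "asst_abc123"},
)
print(resp.status_code, resp.json())
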
######################################################################

# /v1/batches Endpoints

@@ -53,6 +53,16 @@ from litellm.types.router import (
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
from litellm.types.llms.openai import (
    AsyncCursorPage,
    Assistant,
    Thread,
    Attachment,
    OpenAIMessage,
    Run,
    AssistantToolParam,
)
from typing import Iterable


class Router:
@@ -1646,6 +1656,108 @@ class Router:
                self.fail_calls[model_name] += 1
            raise e

    #### ASSISTANTS API ####

    async def aget_assistants(
        self,
        custom_llm_provider: Literal["openai"],
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[Assistant]:
        return await litellm.aget_assistants(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def acreate_thread(
        self,
        custom_llm_provider: Literal["openai"],
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        return await litellm.acreate_thread(
            custom_llm_provider=custom_llm_provider, client=client, **kwargs
        )

    async def aget_thread(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Thread:
        return await litellm.aget_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
            **kwargs,
        )

    async def a_add_message(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        role: Literal["user", "assistant"],
        content: str,
        attachments: Optional[List[Attachment]] = None,
        metadata: Optional[dict] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> OpenAIMessage:
        return await litellm.a_add_message(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            role=role,
            content=content,
            attachments=attachments,
            metadata=metadata,
            client=client,
            **kwargs,
        )

    async def aget_messages(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> AsyncCursorPage[OpenAIMessage]:
        return await litellm.aget_messages(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            client=client,
            **kwargs,
        )

    async def arun_thread(
        self,
        custom_llm_provider: Literal["openai"],
        thread_id: str,
        assistant_id: str,
        additional_instructions: Optional[str] = None,
        instructions: Optional[str] = None,
        metadata: Optional[dict] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        tools: Optional[Iterable[AssistantToolParam]] = None,
        client: Optional[AsyncOpenAI] = None,
        **kwargs,
    ) -> Run:
        return await litellm.arun_thread(
            custom_llm_provider=custom_llm_provider,
            thread_id=thread_id,
            assistant_id=assistant_id,
            additional_instructions=additional_instructions,
            instructions=instructions,
            metadata=metadata,
            model=model,
            stream=stream,
            tools=tools,
            client=client,
            **kwargs,
        )

    #### [END] ASSISTANTS API ####

    async def async_function_with_fallbacks(self, *args, **kwargs):
        """
        Try calling the function_with_retries

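Example (illustrative, not part of this commit): the Router methods above are thin pass-throughs to the module-level litellm.a* helpers. A minimal sketch of calling them on a configured Router; the model_list entry, environment variable, and assistant id are assumptions.

# Illustrative sketch only — assumes OPENAI_API_KEY is set and an assistant already exists;
# the assistants methods simply forward to litellm.acreate_thread / a_add_message / arun_thread.
import asyncio
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "gpt-4o", "api_key": os.environ["OPENAI_API_KEY"]},
        }
    ]
)

async def main():
    thread = await router.acreate_thread(custom_llm_provider="openai")
    await router.a_add_message(
        custom_llm_provider="openai",
        thread_id=thread.id,
        role="user",
        content="Hey, how's it going?",
    )
    run = await router.arun_thread(
        custom_llm_provider="openai", thread_id=thread.id, assistant_id="asst_abc123"
    )
    print(run.status)

asyncio.run(main())
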
@@ -16,6 +16,7 @@ from litellm.llms.openai import (
    MessageData,
    Thread,
    OpenAIMessage as Message,
    AsyncCursorPage,
)

"""
@@ -26,26 +27,55 @@ V0 Scope:
"""


def test_create_thread_litellm() -> Thread:
@pytest.mark.asyncio
async def test_async_get_assistants():
    assistants = await litellm.aget_assistants(custom_llm_provider="openai")
    assert isinstance(assistants, AsyncCursorPage)


@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_create_thread_litellm(sync_mode) -> Thread:
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    new_thread = create_thread(
        custom_llm_provider="openai",
        messages=[message],  # type: ignore
    )
    if sync_mode:
        new_thread = create_thread(
            custom_llm_provider="openai",
            messages=[message],  # type: ignore
        )
    else:
        new_thread = await litellm.acreate_thread(
            custom_llm_provider="openai",
            messages=[message],  # type: ignore
        )

    assert isinstance(
        new_thread, Thread
    ), f"type of thread={type(new_thread)}. Expected Thread-type"

    return new_thread


def test_get_thread_litellm():
    new_thread = test_create_thread_litellm()
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_get_thread_litellm(sync_mode):
    new_thread = test_create_thread_litellm(sync_mode)

    received_thread = get_thread(
        custom_llm_provider="openai",
        thread_id=new_thread.id,
    )
    if asyncio.iscoroutine(new_thread):
        _new_thread = await new_thread
    else:
        _new_thread = new_thread

    if sync_mode:
        received_thread = get_thread(
            custom_llm_provider="openai",
            thread_id=_new_thread.id,
        )
    else:
        received_thread = await litellm.aget_thread(
            custom_llm_provider="openai",
            thread_id=_new_thread.id,
        )

    assert isinstance(
        received_thread, Thread
@@ -53,50 +83,90 @@ def test_get_thread_litellm():
    return new_thread


def test_add_message_litellm():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_add_message_litellm(sync_mode):
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    new_thread = test_create_thread_litellm()
    new_thread = test_create_thread_litellm(sync_mode)

    if asyncio.iscoroutine(new_thread):
        _new_thread = await new_thread
    else:
        _new_thread = new_thread
    # add message to thread
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    added_message = litellm.add_message(
        thread_id=new_thread.id, custom_llm_provider="openai", **message
    )
    if sync_mode:
        added_message = litellm.add_message(
            thread_id=_new_thread.id, custom_llm_provider="openai", **message
        )
    else:
        added_message = await litellm.a_add_message(
            thread_id=_new_thread.id, custom_llm_provider="openai", **message
        )

    print(f"added message: {added_message}")

    assert isinstance(added_message, Message)


def test_run_thread_litellm():
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_run_thread_litellm(sync_mode):
    """
    - Get Assistants
    - Create thread
    - Create run w/ Assistants + Thread
    """
    assistants = litellm.get_assistants(custom_llm_provider="openai")
    if sync_mode:
        assistants = litellm.get_assistants(custom_llm_provider="openai")
    else:
        assistants = await litellm.aget_assistants(custom_llm_provider="openai")

    ## get the first assistant ###
    assistant_id = assistants.data[0].id

    new_thread = test_create_thread_litellm()
    new_thread = test_create_thread_litellm(sync_mode=sync_mode)

    thread_id = new_thread.id
    if asyncio.iscoroutine(new_thread):
        _new_thread = await new_thread
    else:
        _new_thread = new_thread

    thread_id = _new_thread.id

    # add message to thread
    message: MessageData = {"role": "user", "content": "Hey, how's it going?"}  # type: ignore
    added_message = litellm.add_message(
        thread_id=new_thread.id, custom_llm_provider="openai", **message
    )

    run = litellm.run_thread(
        custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
    )

    if run.status == "completed":
        messages = litellm.get_messages(
            thread_id=new_thread.id, custom_llm_provider="openai"
        )
        assert isinstance(messages.data[0], Message)
    else:
        pytest.fail("An unexpected error occurred when running the thread")
    if sync_mode:
        added_message = litellm.add_message(
            thread_id=_new_thread.id, custom_llm_provider="openai", **message
        )

        run = litellm.run_thread(
            custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
        )

        if run.status == "completed":
            messages = litellm.get_messages(
                thread_id=_new_thread.id, custom_llm_provider="openai"
            )
            assert isinstance(messages.data[0], Message)
        else:
            pytest.fail("An unexpected error occurred when running the thread")
    else:
        added_message = await litellm.a_add_message(
            thread_id=_new_thread.id, custom_llm_provider="openai", **message
        )

        run = await litellm.arun_thread(
            custom_llm_provider="openai", thread_id=thread_id, assistant_id=assistant_id
        )

        if run.status == "completed":
            messages = await litellm.aget_messages(
                thread_id=_new_thread.id, custom_llm_provider="openai"
            )
            assert isinstance(messages.data[0], Message)
        else:
            pytest.fail("An unexpected error occurred when running the thread")

@@ -17,11 +17,10 @@ from openai.types.beta.thread_create_params import (
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
from openai.pagination import SyncCursorPage
from openai.pagination import SyncCursorPage, AsyncCursorPage
from os import PathLike
from openai.types import FileObject, Batch
from openai._legacy_response import HttpxBinaryResponseContent

from typing import TypedDict, List, Optional, Tuple, Mapping, IO

FileContent = Union[IO[bytes], bytes, PathLike]