mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
Merge pull request #3936 from BerriAI/litellm_assistants_api_proxy
feat(proxy_server.py): add assistants api endpoints to proxy server
This commit is contained in:
commit
08bae3185a
6 changed files with 1741 additions and 50 deletions
|
@ -6,7 +6,7 @@ from typing import (
|
|||
Literal,
|
||||
Iterable,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override, overload
|
||||
from pydantic import BaseModel
|
||||
import types, time, json, traceback
|
||||
import httpx
|
||||
|
@ -1936,8 +1936,85 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
|
||||
return openai_client
|
||||
|
||||
def async_get_openai_client(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
    """Return an ``AsyncOpenAI`` client, reusing ``client`` when one is given.

    When ``client`` is ``None`` a new ``AsyncOpenAI`` is constructed from the
    remaining arguments. Arguments whose value is ``None`` are left out of the
    constructor call, and ``api_base`` is passed under the name ``base_url``.

    Note: despite the ``async_`` prefix this is a plain (non-coroutine)
    method; callers use its return value directly without ``await``.
    """
    # A caller-supplied client always wins; nothing is built in that case.
    if client is not None:
        return client

    # Constructor keyword -> value; ``api_base`` is renamed to ``base_url``.
    candidate_kwargs = {
        "api_key": api_key,
        "base_url": api_base,
        "timeout": timeout,
        "max_retries": max_retries,
        "organization": organization,
    }
    # Only forward settings that were actually provided.
    init_kwargs = {
        name: value for name, value in candidate_kwargs.items() if value is not None
    }
    return AsyncOpenAI(**init_kwargs)  # type: ignore
|
||||
|
||||
### ASSISTANTS ###
|
||||
|
||||
async def async_get_assistants(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
    """List assistants using an async OpenAI client.

    The client comes from ``self.async_get_openai_client``, which receives
    every connection argument plus ``client``. The awaited result of
    ``beta.assistants.list()`` is returned unchanged.
    """
    connection_settings = dict(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
    )
    async_client = self.async_get_openai_client(client=client, **connection_settings)

    return await async_client.beta.assistants.list()
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[AsyncOpenAI],
|
||||
aget_assistants: Literal[True],
|
||||
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
aget_assistants: Optional[Literal[False]],
|
||||
) -> SyncCursorPage[Assistant]:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def get_assistants(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
|
@ -1945,8 +2022,18 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
) -> SyncCursorPage[Assistant]:
|
||||
client=None,
|
||||
aget_assistants=None,
|
||||
):
|
||||
if aget_assistants is not None and aget_assistants == True:
|
||||
return self.async_get_assistants(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -1962,6 +2049,72 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
|
||||
### MESSAGES ###
|
||||
|
||||
async def a_add_message(
    self,
    thread_id: str,
    message_data: MessageData,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
    """Create a message on ``thread_id`` and return it as an ``OpenAIMessage``.

    ``message_data`` is unpacked as keyword arguments to
    ``beta.threads.messages.create``. If the created message has no
    ``status`` (attribute missing or ``None``), it is set to ``"completed"``
    before the message is rebuilt as an ``OpenAIMessage`` from its ``.dict()``.
    """
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    created = await async_client.beta.threads.messages.create(  # type: ignore
        thread_id, **message_data  # type: ignore
    )

    # Backfill a missing status so the rebuilt message always carries one.
    if getattr(created, "status", None) is None:
        created.status = "completed"

    return OpenAIMessage(**created.dict())
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: MessageData,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[AsyncOpenAI],
|
||||
a_add_message: Literal[True],
|
||||
) -> Coroutine[None, None, OpenAIMessage]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
message_data: MessageData,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
a_add_message: Optional[Literal[False]],
|
||||
) -> OpenAIMessage:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
@ -1971,9 +2124,20 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI] = None,
|
||||
) -> OpenAIMessage:
|
||||
|
||||
client=None,
|
||||
a_add_message: Optional[bool] = None,
|
||||
):
|
||||
if a_add_message is not None and a_add_message == True:
|
||||
return self.a_add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=message_data,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -1995,6 +2159,61 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
response_obj = OpenAIMessage(**thread_message.dict())
|
||||
return response_obj
|
||||
|
||||
async def async_get_messages(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
    """List the messages of ``thread_id`` using an async OpenAI client.

    The client is resolved through ``self.async_get_openai_client``; the
    awaited result of ``beta.threads.messages.list(thread_id=...)`` is
    returned without further conversion.
    """
    connection_settings = dict(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
    )
    async_client = self.async_get_openai_client(client=client, **connection_settings)

    return await async_client.beta.threads.messages.list(thread_id=thread_id)
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[AsyncOpenAI],
|
||||
aget_messages: Literal[True],
|
||||
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
aget_messages: Optional[Literal[False]],
|
||||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
@ -2003,8 +2222,19 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI] = None,
|
||||
) -> SyncCursorPage[OpenAIMessage]:
|
||||
client=None,
|
||||
aget_messages=None,
|
||||
):
|
||||
if aget_messages is not None and aget_messages == True:
|
||||
return self.async_get_messages(
|
||||
thread_id=thread_id,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -2020,6 +2250,70 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
|
||||
### THREADS ###
|
||||
|
||||
async def async_create_thread(
    self,
    metadata: Optional[dict],
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
    messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
    """Create a thread using an async OpenAI client and return it as a ``Thread``.

    ``messages`` and ``metadata`` are forwarded to ``beta.threads.create``
    only when they are not ``None``. The created object is rebuilt as a
    ``Thread`` from its ``.dict()``.
    """
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    # Drop unset optional fields so they are not sent as explicit ``None``.
    optional_fields = {"messages": messages, "metadata": metadata}
    create_kwargs = {
        field: value for field, value in optional_fields.items() if value is not None
    }

    created_thread = await async_client.beta.threads.create(**create_kwargs)  # type: ignore

    return Thread(**created_thread.dict())
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client: Optional[AsyncOpenAI],
|
||||
acreate_thread: Literal[True],
|
||||
) -> Coroutine[None, None, Thread]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
client: Optional[OpenAI],
|
||||
acreate_thread: Optional[Literal[False]],
|
||||
) -> Thread:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def create_thread(
|
||||
self,
|
||||
metadata: Optional[dict],
|
||||
|
@ -2028,9 +2322,10 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
|
||||
) -> Thread:
|
||||
client=None,
|
||||
acreate_thread=None,
|
||||
):
|
||||
"""
|
||||
Here's an example:
|
||||
```
|
||||
|
@ -2041,6 +2336,17 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
openai_api.create_thread(messages=[message])
|
||||
```
|
||||
"""
|
||||
if acreate_thread is not None and acreate_thread == True:
|
||||
return self.async_create_thread(
|
||||
metadata=metadata,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
messages=messages,
|
||||
)
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -2060,6 +2366,61 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
|
||||
return Thread(**message_thread.dict())
|
||||
|
||||
async def async_get_thread(
    self,
    thread_id: str,
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
) -> Thread:
    """Retrieve ``thread_id`` using an async OpenAI client and return a ``Thread``.

    The client is resolved through ``self.async_get_openai_client``; the
    awaited result of ``beta.threads.retrieve(thread_id=...)`` is rebuilt as
    a ``Thread`` from its ``.dict()``.
    """
    connection_settings = dict(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
    )
    async_client = self.async_get_openai_client(client=client, **connection_settings)

    retrieved = await async_client.beta.threads.retrieve(thread_id=thread_id)

    return Thread(**retrieved.dict())
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[AsyncOpenAI],
|
||||
aget_thread: Literal[True],
|
||||
) -> Coroutine[None, None, Thread]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
aget_thread: Optional[Literal[False]],
|
||||
) -> Thread:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def get_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
@ -2068,8 +2429,19 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
) -> Thread:
|
||||
client=None,
|
||||
aget_thread=None,
|
||||
):
|
||||
if aget_thread is not None and aget_thread == True:
|
||||
return self.async_get_thread(
|
||||
thread_id=thread_id,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -2088,6 +2460,90 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
|
||||
### RUNS ###
|
||||
|
||||
async def arun_thread(
    self,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    stream: Optional[bool],
    tools: Optional[Iterable[AssistantToolParam]],
    api_key: Optional[str],
    api_base: Optional[str],
    timeout: Union[float, httpx.Timeout],
    max_retries: Optional[int],
    organization: Optional[str],
    client: Optional[AsyncOpenAI],
) -> Run:
    """Start a run of ``assistant_id`` on ``thread_id`` using an async client.

    The run parameters are passed to ``beta.threads.runs.create_and_poll``
    and its awaited result is returned as-is.

    Note: ``stream`` is accepted but is not forwarded to the SDK call.
    """
    async_client = self.async_get_openai_client(
        api_key=api_key,
        api_base=api_base,
        timeout=timeout,
        max_retries=max_retries,
        organization=organization,
        client=client,
    )

    run_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )

    return await async_client.beta.threads.runs.create_and_poll(**run_kwargs)  # type: ignore
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[AsyncOpenAI],
|
||||
arun_thread: Literal[True],
|
||||
) -> Coroutine[None, None, Run]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
assistant_id: str,
|
||||
additional_instructions: Optional[str],
|
||||
instructions: Optional[str],
|
||||
metadata: Optional[object],
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
tools: Optional[Iterable[AssistantToolParam]],
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
arun_thread: Optional[Literal[False]],
|
||||
) -> Run:
|
||||
...
|
||||
|
||||
# fmt: on
|
||||
|
||||
def run_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
@ -2103,8 +2559,26 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[OpenAI],
|
||||
) -> Run:
|
||||
client=None,
|
||||
arun_thread=None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread == True:
|
||||
return self.arun_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
stream=stream,
|
||||
tools=tools,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
)
|
||||
openai_client = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue