feat(proxy_server.py): add assistants api endpoints to proxy server

This commit is contained in:
Krrish Dholakia 2024-05-30 22:44:43 -07:00
parent 3167bee25a
commit e2b34165e7
6 changed files with 1741 additions and 51 deletions

View file

@ -6,7 +6,7 @@ from typing import (
Literal,
Iterable,
)
from typing_extensions import override
from typing_extensions import override, overload
from pydantic import BaseModel
import types, time, json, traceback
import httpx
@ -1846,8 +1846,85 @@ class OpenAIAssistantsAPI(BaseLLM):
return openai_client
def async_get_openai_client(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncOpenAI:
received_args = locals()
if client is None:
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif k == "api_base" and v is not None:
data["base_url"] = v
elif v is not None:
data[k] = v
openai_client = AsyncOpenAI(**data) # type: ignore
else:
openai_client = client
return openai_client
### ASSISTANTS ###
async def async_get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> AsyncCursorPage[Assistant]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.assistants.list()
return response
# fmt: off
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_assistants: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[Assistant]]:
...
@overload
def get_assistants(
self,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_assistants: Optional[Literal[False]],
) -> SyncCursorPage[Assistant]:
...
# fmt: on
def get_assistants(
self,
api_key: Optional[str],
@ -1855,8 +1932,18 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> SyncCursorPage[Assistant]:
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@ -1872,6 +1959,72 @@ class OpenAIAssistantsAPI(BaseLLM):
### MESSAGES ###
async def a_add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> OpenAIMessage:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore
thread_id, **message_data # type: ignore
)
response_obj: Optional[OpenAIMessage] = None
if getattr(thread_message, "status", None) is None:
thread_message.status = "completed"
response_obj = OpenAIMessage(**thread_message.dict())
else:
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
# fmt: off
@overload
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
a_add_message: Literal[True],
) -> Coroutine[None, None, OpenAIMessage]:
...
@overload
def add_message(
self,
thread_id: str,
message_data: MessageData,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
a_add_message: Optional[Literal[False]],
) -> OpenAIMessage:
...
# fmt: on
def add_message(
self,
thread_id: str,
@ -1881,9 +2034,20 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAIMessage:
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@ -1905,6 +2069,61 @@ class OpenAIAssistantsAPI(BaseLLM):
response_obj = OpenAIMessage(**thread_message.dict())
return response_obj
async def async_get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI] = None,
) -> AsyncCursorPage[OpenAIMessage]:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.messages.list(thread_id=thread_id)
return response
# fmt: off
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_messages: Literal[True],
) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]:
...
@overload
def get_messages(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_messages: Optional[Literal[False]],
) -> SyncCursorPage[OpenAIMessage]:
...
# fmt: on
def get_messages(
self,
thread_id: str,
@ -1913,8 +2132,19 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> SyncCursorPage[OpenAIMessage]:
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@ -1930,6 +2160,70 @@ class OpenAIAssistantsAPI(BaseLLM):
### THREADS ###
async def async_create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = await openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())
# fmt: off
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[AsyncOpenAI],
acreate_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def create_thread(
self,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
client: Optional[OpenAI],
acreate_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def create_thread(
self,
metadata: Optional[dict],
@ -1938,9 +2232,10 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
client=None,
acreate_thread=None,
):
"""
Here's an example:
```
@ -1951,6 +2246,17 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
messages=messages,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@ -1970,6 +2276,61 @@ class OpenAIAssistantsAPI(BaseLLM):
return Thread(**message_thread.dict())
async def async_get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Thread:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.retrieve(thread_id=thread_id)
return Thread(**response.dict())
# fmt: off
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
aget_thread: Literal[True],
) -> Coroutine[None, None, Thread]:
...
@overload
def get_thread(
self,
thread_id: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
aget_thread: Optional[Literal[False]],
) -> Thread:
...
# fmt: on
def get_thread(
self,
thread_id: str,
@ -1978,8 +2339,19 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> Thread:
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,
@ -1998,6 +2370,90 @@ class OpenAIAssistantsAPI(BaseLLM):
### RUNS ###
async def arun_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
) -> Run:
openai_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
return response
# fmt: off
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[AsyncOpenAI],
arun_thread: Literal[True],
) -> Coroutine[None, None, Run]:
...
@overload
def run_thread(
self,
thread_id: str,
assistant_id: str,
additional_instructions: Optional[str],
instructions: Optional[str],
metadata: Optional[object],
model: Optional[str],
stream: Optional[bool],
tools: Optional[Iterable[AssistantToolParam]],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
arun_thread: Optional[Literal[False]],
) -> Run:
...
# fmt: on
def run_thread(
self,
thread_id: str,
@ -2013,8 +2469,26 @@ class OpenAIAssistantsAPI(BaseLLM):
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
) -> Run:
client=None,
arun_thread=None,
):
if arun_thread is not None and arun_thread == True:
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
stream=stream,
tools=tools,
api_key=api_key,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
organization=organization,
client=client,
)
openai_client = self.get_openai_client(
api_key=api_key,
api_base=api_base,