fix(assistants/main.py): support litellm.create_thread() call

This commit is contained in:
Krrish Dholakia 2024-05-04 19:35:37 -07:00
parent 84c31a5528
commit 681a95e37b
7 changed files with 308 additions and 94 deletions

View file

@ -27,73 +27,7 @@ import aiohttp, requests
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
from openai.types.beta.threads.message_content import MessageContent
from openai.types.beta.threads.message_create_params import Attachment
from openai.types.beta.threads.message import Message as OpenAIMessage
from openai.types.beta.thread_create_params import (
Message as OpenAICreateThreadParamsMessage,
)
from openai.types.beta.assistant_tool_param import AssistantToolParam
from openai.types.beta.threads.run import Run
from openai.types.beta.assistant import Assistant
from openai.pagination import SyncCursorPage
from typing import TypedDict, List, Optional
class NotGiven:
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response:
...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
NOT_GIVEN = NotGiven()
class MessageData(TypedDict):
role: Literal["user", "assistant"]
content: str
attachments: Optional[List[Attachment]]
metadata: Optional[dict]
class Thread(BaseModel):
id: str
"""The identifier, which can be referenced in API endpoints."""
created_at: int
"""The Unix timestamp (in seconds) for when the thread was created."""
metadata: Optional[object] = None
"""Set of 16 key-value pairs that can be attached to an object.
This can be useful for storing additional information about the object in a
structured format. Keys can be a maximum of 64 characters long and values can be
a maxium of 512 characters long.
"""
object: Literal["thread"]
"""The object type, which is always `thread`."""
from ..types.llms.openai import *
class OpenAIError(Exception):
@ -1321,22 +1255,22 @@ class OpenAIAssistantsAPI(BaseLLM):
def get_openai_client(
self,
api_key: str,
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: int,
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI] = None,
) -> OpenAI:
received_args = locals()
if client is None:
openai_client = OpenAI(
api_key=api_key,
base_url=api_base,
http_client=litellm.client_session,
timeout=timeout,
max_retries=max_retries,
organization=organization,
)
data = {}
for k, v in received_args.items():
if k == "self" or k == "client":
pass
elif v is not None:
data[k] = v
openai_client = OpenAI(**data) # type: ignore
else:
openai_client = client
@ -1428,16 +1362,14 @@ class OpenAIAssistantsAPI(BaseLLM):
def create_thread(
self,
metadata: dict,
api_key: str,
metadata: Optional[dict],
api_key: Optional[str],
api_base: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: int,
max_retries: Optional[int],
organization: Optional[str],
client: Optional[OpenAI],
messages: Union[
Iterable[OpenAICreateThreadParamsMessage], NotGiven
] = NOT_GIVEN,
messages: Optional[Iterable[OpenAICreateThreadParamsMessage]],
) -> Thread:
"""
Here's an example:
@ -1458,10 +1390,13 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client,
)
message_thread = openai_client.beta.threads.create(
messages=messages, # type: ignore
metadata=metadata,
)
data = {}
if messages is not None:
data["messages"] = messages # type: ignore
if metadata is not None:
data["metadata"] = metadata # type: ignore
message_thread = openai_client.beta.threads.create(**data) # type: ignore
return Thread(**message_thread.dict())