feat(assistants/main.py): add assistants api streaming support

This commit is contained in:
Krrish Dholakia 2024-06-04 16:30:35 -07:00
parent 7b474ec267
commit f3d78532f9
9 changed files with 444 additions and 65 deletions

View file

@ -31,6 +31,10 @@ from ..types.llms.openai import (
Thread,
AssistantToolParam,
Run,
AssistantEventHandler,
AsyncAssistantEventHandler,
AsyncAssistantStreamManager,
AssistantStreamManager,
)
@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
)
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
additional_instructions=None,
instructions=None,
metadata=None,
model=None,
tools=None,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
)
# response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
# thread_id=thread_id,
# assistant_id=assistant_id,
# additional_instructions=additional_instructions,
# instructions=instructions,
# metadata=metadata,
# model=model,
# tools=tools,
# )
return response
async def async_run_thread_stream(
    self,
    client: AsyncAzureOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
    """Open a streaming run on a thread using the async Azure client.

    Every run argument is forwarded by keyword to
    ``client.beta.threads.runs.stream``. ``event_handler`` is forwarded
    only when it is not ``None``; otherwise the keyword is left out of
    the call entirely.

    The method is declared ``async`` but does not await the ``stream``
    call, so the coroutine resolves to whatever that call returns.
    """
    stream_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    if event_handler is None:
        # No handler supplied: do not send the keyword at all.
        return client.beta.threads.runs.stream(**stream_kwargs)  # type: ignore
    return client.beta.threads.runs.stream(  # type: ignore
        **stream_kwargs, event_handler=event_handler
    )
def run_thread_stream(
    self,
    client: AzureOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
    """Open a streaming run on a thread using the synchronous Azure client.

    Every run argument is forwarded by keyword to
    ``client.beta.threads.runs.stream`` and the result of that call is
    returned unchanged. ``event_handler`` is forwarded only when it is
    not ``None``; otherwise the keyword is left out of the call entirely.
    """
    stream_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    if event_handler is None:
        # No handler supplied: do not send the keyword at all.
        return client.beta.threads.runs.stream(**stream_kwargs)  # type: ignore
    return client.beta.threads.runs.stream(  # type: ignore
        **stream_kwargs, event_handler=event_handler
    )
# fmt: off
@overload
@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
max_retries: Optional[int],
client=None,
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
azure_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
timeout=timeout,
max_retries=max_retries,
client=client,
)
return self.async_run_thread_stream(
client=azure_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
return self.arun_thread(
thread_id=thread_id,
assistant_id=assistant_id,
@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
client=client,
)
if stream is not None and stream == True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
instructions=instructions,
metadata=metadata,
model=model,
tools=tools,
event_handler=event_handler,
)
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
thread_id=thread_id,
assistant_id=assistant_id,