mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(assistants/main.py): add assistants api streaming support
This commit is contained in:
parent
7b474ec267
commit
f3d78532f9
9 changed files with 444 additions and 65 deletions
|
@ -31,6 +31,10 @@ from ..types.llms.openai import (
|
|||
Thread,
|
||||
AssistantToolParam,
|
||||
Run,
|
||||
AssistantEventHandler,
|
||||
AsyncAssistantEventHandler,
|
||||
AsyncAssistantStreamManager,
|
||||
AssistantStreamManager,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1975,27 +1979,67 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
)
|
||||
|
||||
response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id="thread_OHLZkEj5xJLxdk0REZ4cl9sP",
|
||||
assistant_id="asst_nIzr656D1GIVMLHOKD76bN2T",
|
||||
additional_instructions=None,
|
||||
instructions=None,
|
||||
metadata=None,
|
||||
model=None,
|
||||
tools=None,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
# response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
# thread_id=thread_id,
|
||||
# assistant_id=assistant_id,
|
||||
# additional_instructions=additional_instructions,
|
||||
# instructions=instructions,
|
||||
# metadata=metadata,
|
||||
# model=model,
|
||||
# tools=tools,
|
||||
# )
|
||||
|
||||
return response
|
||||
|
||||
async def async_run_thread_stream(
    self,
    client: AsyncAzureOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]:
    """Open a streaming run on a thread using the async client.

    The run parameters are forwarded unchanged as keyword arguments to
    ``client.beta.threads.runs.stream``. ``event_handler`` is only included
    in the call when it is not ``None``.

    The ``stream(...)`` call is not awaited here; its result is returned
    as-is. Because this method is declared ``async``, callers receive a
    coroutine that resolves to that result.

    NOTE(review): ``event_handler`` is annotated with the sync
    ``AssistantEventHandler`` while the return type is parameterised on
    ``AsyncAssistantEventHandler`` -- confirm which handler type callers
    are expected to pass.
    """
    run_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    # Leave the handler out entirely when the caller did not supply one,
    # rather than passing an explicit None through to the client.
    handler_kwargs = {} if event_handler is None else {"event_handler": event_handler}
    return client.beta.threads.runs.stream(**run_kwargs, **handler_kwargs)  # type: ignore
|
||||
|
||||
def run_thread_stream(
    self,
    client: AzureOpenAI,
    thread_id: str,
    assistant_id: str,
    additional_instructions: Optional[str],
    instructions: Optional[str],
    metadata: Optional[object],
    model: Optional[str],
    tools: Optional[Iterable[AssistantToolParam]],
    event_handler: Optional[AssistantEventHandler],
) -> AssistantStreamManager[AssistantEventHandler]:
    """Open a streaming run on a thread using the sync client.

    The run parameters are forwarded unchanged as keyword arguments to
    ``client.beta.threads.runs.stream`` and the result of that call is
    returned directly. ``event_handler`` is only included in the call when
    it is not ``None``.
    """
    run_kwargs = dict(
        thread_id=thread_id,
        assistant_id=assistant_id,
        additional_instructions=additional_instructions,
        instructions=instructions,
        metadata=metadata,
        model=model,
        tools=tools,
    )
    # Leave the handler out entirely when the caller did not supply one,
    # rather than passing an explicit None through to the client.
    handler_kwargs = {} if event_handler is None else {"event_handler": event_handler}
    return client.beta.threads.runs.stream(**run_kwargs, **handler_kwargs)  # type: ignore
|
||||
|
||||
# fmt: off
|
||||
|
||||
@overload
|
||||
|
@ -2062,8 +2106,30 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
max_retries: Optional[int],
|
||||
client=None,
|
||||
arun_thread=None,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread == True:
|
||||
if stream is not None and stream == True:
|
||||
azure_client = self.async_get_azure_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
azure_ad_token=azure_ad_token,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
client=client,
|
||||
)
|
||||
return self.async_run_thread_stream(
|
||||
client=azure_client,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
return self.arun_thread(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
@ -2091,6 +2157,19 @@ class AzureAssistantsAPI(BaseLLM):
|
|||
client=client,
|
||||
)
|
||||
|
||||
if stream is not None and stream == True:
|
||||
return self.run_thread_stream(
|
||||
client=openai_client,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
additional_instructions=additional_instructions,
|
||||
instructions=instructions,
|
||||
metadata=metadata,
|
||||
model=model,
|
||||
tools=tools,
|
||||
event_handler=event_handler,
|
||||
)
|
||||
|
||||
response = openai_client.beta.threads.runs.create_and_poll( # type: ignore
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue